diff --git a/.dockerignore b/.dockerignore index 2eea525d..8e571199 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,62 @@ -.env \ No newline at end of file +# Python cache +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# Virtual environments +venv/ +ENV/ +env/ +.venv + +# IDEs +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db + +# Git +.git/ +.gitignore + +# Documentation +*.md +docs/ + +# Tests +test/ +tests/ +pytest_cache/ +.coverage +.pytest_cache/ + +# Development files +.env +.env.local +*.log + +# Build artifacts +build/ +dist/ +*.egg-info/ + +# Docker files (don't copy themselves) +Dockerfile* +docker-compose*.yml +.dockerignore + +# Bot data that should be mounted as volumes +bots/instances/* +bots/data/* +bots/credentials/* +!bots/credentials/master_account/ + +# Archives +bots/archived/ \ No newline at end of file diff --git a/.github/workflows/docker_buildx_workflow.yml b/.github/workflows/docker_buildx_workflow.yml index 156388d9..7ba8f9c6 100644 --- a/.github/workflows/docker_buildx_workflow.yml +++ b/.github/workflows/docker_buildx_workflow.yml @@ -1,4 +1,4 @@ -name: Backend-API Docker Buildx Workflow +name: Hummingbot-API Docker Buildx Workflow on: pull_request: @@ -36,7 +36,7 @@ jobs: context: . platforms: linux/amd64,linux/arm64 push: true - tags: hummingbot/backend-api:development + tags: hummingbot/hummingbot-api:development - name: Build and push Latest Image if: github.base_ref == 'main' @@ -46,7 +46,7 @@ jobs: file: ./Dockerfile platforms: linux/amd64,linux/arm64 push: true - tags: hummingbot/backend-api:latest + tags: hummingbot/hummingbot-api:latest build_release: if: github.event_name == 'release' @@ -77,4 +77,4 @@ jobs: context: . platforms: linux/amd64,linux/arm64 push: true - tags: hummingbot/backend-api:${{ steps.get_tag.outputs.VERSION }} + tags: hummingbot/hummingbot-api:${{ steps.get_tag.outputs.VERSION }} diff --git a/Dockerfile b/Dockerfile index 8b310580..e1766153 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,22 +1,56 @@ -# Start from a base image with Miniconda installed -FROM continuumio/miniconda3 +# Stage 1: Builder stage +FROM continuumio/miniconda3 AS builder -# Install system dependencies +# Install build dependencies RUN apt-get update && \ - apt-get install -y sudo libusb-1.0 python3-dev gcc && \ + apt-get install -y python3-dev gcc && \ rm -rf /var/lib/apt/lists/* -# Set the working directory in the container -WORKDIR /backend-api +# Set working directory +WORKDIR /build + +# Copy only the environment file first (for better layer caching) +COPY environment.yml . + +# Create the conda environment +RUN conda env create -f environment.yml && \ + conda clean -afy && \ + rm -rf /root/.cache/pip/* + +# Stage 2: Runtime stage +FROM continuumio/miniconda3 + +# Install only runtime dependencies +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + libusb-1.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Copy the conda environment from builder +COPY --from=builder /opt/conda/envs/hummingbot-api /opt/conda/envs/hummingbot-api + +# Set the working directory +WORKDIR /hummingbot-api + +# Copy only necessary application files +COPY main.py config.py deps.py ./ +COPY models ./models +COPY routers ./routers +COPY services ./services +COPY utils ./utils +COPY database ./database +COPY bots/controllers ./bots/controllers +COPY bots/scripts ./bots/scripts -# Copy the current directory contents and the Conda environment file into the container -COPY . . 
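+# bots/instances, bots/data and bots/credentials are excluded from the build
+# context via .dockerignore and are expected to be mounted as volumes at
+# runtime, so the directories are created empty below instead of being copied.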
+# Create necessary directories +RUN mkdir -p bots/instances bots/conf bots/credentials bots/data bots/archived -# Create the environment from the environment.yml file -RUN conda env create -f environment.yml +# Expose port +EXPOSE 8000 -# Make RUN commands use the new environment -SHELL ["conda", "run", "-n", "backend-api", "/bin/bash", "-c"] +# Set environment variables to ensure conda env is used +ENV PATH="/opt/conda/envs/hummingbot-api/bin:$PATH" +ENV CONDA_DEFAULT_ENV=hummingbot-api -# The code to run when container is started -ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "backend-api", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] +# Run the application +ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/Makefile b/Makefile index b608e236..db10e0cb 100644 --- a/Makefile +++ b/Makefile @@ -5,8 +5,8 @@ .PHONY: uninstall .PHONY: install .PHONY: install-pre-commit -.PHONY: docker_build -.PHONY: docker_run +.PHONY: build +.PHONY: deploy detect_conda_bin := $(shell bash -c 'if [ "${CONDA_EXE} " == " " ]; then \ @@ -26,25 +26,25 @@ run: uvicorn main:app --reload uninstall: - conda env remove -n backend-api -y + conda env remove -n hummingbot-api -y install: - if conda env list | grep -q '^backend-api '; then \ + if conda env list | grep -q '^hummingbot-api '; then \ echo "Environment already exists."; \ else \ conda env create -f environment.yml; \ fi - conda activate backend-api + conda activate hummingbot-api $(MAKE) install-pre-commit install-pre-commit: - /bin/bash -c 'source "${CONDA_BIN}/activate" backend-api && \ + /bin/bash -c 'source "${CONDA_BIN}/activate" hummingbot-api && \ if ! conda list pre-commit | grep pre-commit &> /dev/null; then \ pip install pre-commit; \ fi && pre-commit install' -docker_build: - docker build -t hummingbot/backend-api:latest . +build: + docker build -t hummingbot/hummingbot-api:latest . -docker_run: +deploy: docker compose up -d diff --git a/README.md b/README.md index 7d23d29b..5e5bd8dd 100644 --- a/README.md +++ b/README.md @@ -1,57 +1,271 @@ -# Backend API +# Hummingbot API -## Overview -Backend-api is a dedicated solution for managing Hummingbot instances. It offers a robust backend API to streamline the deployment, management, and interaction with Hummingbot containers. This tool is essential for administrators and developers looking to efficiently handle various aspects of Hummingbot operations. +A comprehensive RESTful API framework for managing trading operations across multiple exchanges. The Hummingbot API provides a centralized platform to aggregate all your trading functionalities, from basic account management to sophisticated automated trading strategies. -## Features -- **Deployment File Management**: Manage files necessary for deploying new Hummingbot instances. -- **Container Control**: Effortlessly start and stop Hummingbot containers. -- **Archiving Options**: Securely archive containers either locally or on Amazon S3 post-removal. -- **Direct Messaging**: Communicate with Hummingbots through the broker for effective control and coordination. +## What is Hummingbot API? 
-## Getting Started +The Hummingbot API is designed to be your central hub for trading operations, offering: + +- **Multi-Exchange Account Management**: Create and manage multiple trading accounts across different exchanges +- **Portfolio Monitoring**: Real-time balance tracking and portfolio distribution analysis +- **Trade Execution**: Execute trades, manage orders, and monitor positions across all your accounts +- **Automated Trading**: Deploy and control Hummingbot instances with automated strategies +- **Strategy Management**: Add, configure, and manage trading strategies in real-time +- **Complete Flexibility**: Build any trading product on top of this robust API framework + +Whether you're building a trading dashboard, implementing algorithmic strategies, or creating a comprehensive trading platform, the Hummingbot API provides all the tools you need. + +## System Dependencies + +The Hummingbot API requires two essential services to function properly: + +### 1. PostgreSQL Database +Stores all trading data including: +- Orders and trade history +- Account states and balances +- Positions and funding payments +- Performance metrics + +### 2. EMQX Message Broker +Enables real-time communication with trading bots: +- Receives live updates from running bots +- Sends commands to control bot execution +- Handles real-time data streaming -### Conda Installation -1. Install the environment using Conda: +## Installation & Setup + +### Prerequisites +- Docker and Docker Compose installed +- Git for cloning the repository + +### Quick Start + +1. **Clone the repository** ```bash - conda env create -f environment.yml + git clone https://github.com/hummingbot/hummingbot-api.git + cd hummingbot-api ``` -2. Activate the Conda environment: + +2. **Make setup script executable and run it** ```bash - conda activate backend-api + chmod +x setup.sh + ./setup.sh ``` -### Running the API with Conda -Run the API using uvicorn with the following command: +3. **Configure your environment** + During setup, you'll configure several important variables: + + - **Config Password**: Used to encrypt and hash API keys and credentials for security + - **Username & Password**: Basic authentication credentials for API access (used by dashboards and other systems) + - **Additional configurations**: Available in the `.env` file including: + - Broker configuration (EMQX settings) + - Database URL + - Market data cleanup settings + - AWS S3 configuration (experimental) + - Banned tokens list (for delisted tokens) + +4. **Set up monitoring (Production recommended)** + For production deployments, add observability through Logfire: ```bash - uvicorn main:app --reload + export LOGFIRE_TOKEN=your_token_here ``` + Learn more: [Logfire Documentation](https://logfire.pydantic.dev/docs/) + +After running `setup.sh`, the required Docker images (EMQX, PostgreSQL, and Hummingbot) will be running and ready. + +## Running the API + +You have two deployment options depending on your use case: + +### For Users (Production/Simple Deployment) +```bash +./run.sh +``` +This runs the API in a Docker container - simple and isolated. + +### For Developers (Development Environment) +1. **Install Conda** (if not already installed) +2. **Set up the development environment** + ```bash + make install + ``` + This creates a Conda environment with all dependencies. + +3. **Run in development mode** + ```bash + ./run.sh --dev + ``` + This starts the API from source with hot-reloading enabled. 
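+
+Whichever mode you use, you can check that everything came up correctly. The commands below are illustrative; substitute the username and password you configured during setup:
+
+```bash
+# The EMQX and PostgreSQL containers (and the API, if run via Docker) should be listed
+docker compose ps
+
+# The API answers on port 8000 once it is ready
+curl -u admin:admin http://localhost:8000/docs
+```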
+ +## Getting Started + +Once the API is running, you can access it at `http://localhost:8000` + +### First Steps +1. **Visit the API Documentation**: Go to `http://localhost:8000/docs` to explore the interactive Swagger documentation +2. **Authenticate**: Use the username and password you configured during setup +3. **Test endpoints**: Use the Swagger interface to test API functionality + +## API Overview + +The Hummingbot API is organized into several functional routers: + +### 🐳 Docker Management (`/docker`) +- Check Docker daemon status and health +- Pull new Docker images with async support +- Start, stop, and remove containers +- Monitor active and exited containers +- Clean up exited containers +- Archive container data locally or to S3 +- Track image pull status and progress + +### 💳 Account Management (`/accounts`) +- Create and delete trading accounts +- Add/remove exchange credentials +- List available credentials per account +- Basic account configuration + +### 🔌 Connector Discovery (`/connectors`) +**Provides exchange connector information and configuration** +- List available exchange connectors +- Get connector configuration requirements +- Retrieve trading rules and constraints +- Query supported order types per connector + +### 📊 Portfolio Management (`/portfolio`) +**Centralized portfolio tracking and analytics** +- **Real-time Portfolio State**: Current balances across all accounts +- **Portfolio History**: Time-series data with cursor-based pagination +- **Token Distribution**: Aggregate holdings by token across exchanges +- **Account Distribution**: Percentage-based portfolio allocation analysis +- **Advanced Filtering**: Filter by account names and connectors + +### 💹 Trading Operations (`/trading`) +**Enhanced with POST-based filtering and comprehensive order/trade management** +- **Order Placement**: Execute trades with advanced order types +- **Order Cancellation**: Cancel specific orders by ID +- **Position Tracking**: Real-time perpetual positions with PnL data +- **Active Orders**: Live order monitoring from connector in-flight orders +- **Order History**: Paginated historical orders with advanced filtering +- **Trade History**: Complete execution records with filtering +- **Funding Payments**: Historical funding payment tracking for perpetuals +- **Position Modes**: Configure HEDGE/ONEWAY modes for perpetual trading +- **Leverage Management**: Set and adjust leverage per trading pair + +### 🤖 Bot Orchestration (`/bot-orchestration`) +- Monitor bot status and MQTT connectivity +- Deploy V2 scripts and controllers +- Start/stop bots with configurable parameters +- Stop and archive bots with background task support +- Retrieve bot performance history +- Real-time bot status monitoring + +### 📋 Strategy Management +- **Controllers** (`/controllers`): Manage V2 strategy controllers + - CRUD operations on controller files + - Controller configuration management + - Bot-specific controller configurations + - Template retrieval for new configs +- **Scripts** (`/scripts`): Handle traditional Hummingbot scripts + - CRUD operations on script files + - Script configuration management + - Configuration templates + +### 📊 Market Data (`/market-data`) +**Professional market data analysis and real-time feeds** +- **Price Discovery**: Real-time prices, funding rates, mark/index prices +- **Candle Data**: Real-time and historical candles with multiple intervals +- **Order Book Analysis**: + - Live order book snapshots + - Price impact calculations + - Volume queries at specific 
price levels + - VWAP (Volume-Weighted Average Price) calculations +- **Feed Management**: Active feed monitoring with automatic cleanup + +### 🔄 Backtesting (`/backtesting`) +- Run strategy backtests against historical data +- Support for controller configurations +- Customizable trade costs and resolution + +### 📈 Archived Bot Analytics (`/archived-bots`) +**Comprehensive analysis of stopped bot performance** +- List and discover archived bot databases +- Performance metrics and trade analysis +- Historical order and trade retrieval +- Position and executor data extraction +- Controller configuration recovery +- Support for both V1 and V2 bot architectures + +## Configuration + +### Environment Variables +Key configuration options available in `.env`: + +- **CONFIG_PASSWORD**: Encrypts API keys and credentials +- **USERNAME/PASSWORD**: API authentication credentials +- **BROKER_HOST/PORT**: EMQX message broker settings +- **DATABASE_URL**: PostgreSQL connection string +- **ACCOUNT_UPDATE_INTERVAL**: Balance update frequency (minutes) +- **AWS_API_KEY/AWS_SECRET_KEY**: S3 archiving (optional) +- **BANNED_TOKENS**: Comma-separated list of tokens to exclude +- **LOGFIRE_TOKEN**: Observability and monitoring (production) + +### Bot Instance Structure +Each bot maintains its own isolated environment: +``` +bots/instances/hummingbot-{name}/ +├── conf/ # Configuration files +├── data/ # Bot databases and state +└── logs/ # Execution logs +``` + +## Development + +### Code Quality Tools +```bash +# Install pre-commit hooks +make install-pre-commit + +# Format code (runs automatically) +black --line-length 130 . +isort --line-length 130 --profile black . +``` + +### Testing +The API includes comprehensive backtesting capabilities. Test using: +- Backtesting router for strategy validation +- Swagger UI at `http://localhost:8000/docs` +- Integration testing with live containers + +## Architecture + +### Core Components +1. **FastAPI Application**: HTTP API with Basic Auth +2. **Docker Service**: Container lifecycle management +3. **Bot Orchestrator**: Strategy deployment and monitoring +4. **Accounts Service**: Multi-exchange account management +5. **Market Data Manager**: Real-time feeds and historical data +6. **MQTT Broker**: Real-time bot communication -### Docker Installation and Running the API -For running the project using Docker, follow these steps: +### Data Models +- Orders and trades with multi-account support +- Portfolio states and balance tracking +- Position management for perpetual trading +- Historical performance analytics -1. **Set up Environment Variables**: - - Execute the `set_environment.sh` script to configure the necessary environment variables in the `.env` file: - ```bash - ./set_environment.sh - ``` +## Authentication -2. **Build and Run with Docker Compose**: - - After setting up the environment variables, use Docker Compose to build and run the project: - ```bash - docker compose up --build - ``` +All API endpoints require HTTP Basic Authentication. Include your configured credentials in all requests: - - This command will build the Docker image and start the containers as defined in your `docker-compose.yml` file. 
+```bash +curl -u username:password http://localhost:8000/endpoint +``` -### Usage -This API is designed for: -- **Deploying Hummingbot instances** -- **Starting/Stopping Containers** -- **Archiving Hummingbots** -- **Messaging with Hummingbot instances** +## Support & Documentation -To test these endpoints, you can use the [Swagger UI](http://localhost:8000/docs) or [Redoc](http://localhost:8000/redoc). +- **API Documentation**: Available at `http://localhost:8000/docs` when running +- **Detailed Examples**: Check the `CLAUDE.md` file for comprehensive API usage examples +- **Issues**: Report bugs and feature requests through the project's issue tracker +--- -## Contributing -Contributions are welcome! For support or queries, please contact us on Discord. +Ready to start trading? Deploy your first account and start exploring the powerful capabilities of the Hummingbot API! \ No newline at end of file diff --git a/bots/controllers/directional_trading/ai_livestream.py b/bots/controllers/directional_trading/ai_livestream.py deleted file mode 100644 index 6cef9cfa..00000000 --- a/bots/controllers/directional_trading/ai_livestream.py +++ /dev/null @@ -1,86 +0,0 @@ -from decimal import Decimal -from typing import List - -import pandas_ta as ta # noqa: F401 -from pydantic import Field - -from hummingbot.core.data_type.common import TradeType -from hummingbot.data_feed.candles_feed.data_types import CandlesConfig -from hummingbot.remote_iface.mqtt import ExternalTopicFactory -from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( - DirectionalTradingControllerBase, - DirectionalTradingControllerConfigBase, -) -from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig - - -class AILivestreamControllerConfig(DirectionalTradingControllerConfigBase): - controller_name: str = "ai_livestream" - candles_config: List[CandlesConfig] = [] - long_threshold: float = Field(default=0.5, json_schema_extra={"is_updatable": True}) - short_threshold: float = Field(default=0.5, json_schema_extra={"is_updatable": True}) - topic: str = "hbot/predictions" - - -class AILivestreamController(DirectionalTradingControllerBase): - def __init__(self, config: AILivestreamControllerConfig, *args, **kwargs): - self.config = config - super().__init__(config, *args, **kwargs) - # Start ML signal listener - self._init_ml_signal_listener() - - def _init_ml_signal_listener(self): - """Initialize a listener for ML signals from the MQTT broker""" - try: - normalized_pair = self.config.trading_pair.replace("-", "_").lower() - topic = f"{self.config.topic}/{normalized_pair}/ML_SIGNALS" - self._ml_signal_listener = ExternalTopicFactory.create_async( - topic=topic, - callback=self._handle_ml_signal, - use_bot_prefix=False, - ) - self.logger().info("ML signal listener initialized successfully") - except Exception as e: - self.logger().error(f"Failed to initialize ML signal listener: {str(e)}") - self._ml_signal_listener = None - - def _handle_ml_signal(self, signal: dict, topic: str): - """Handle incoming ML signal""" - # self.logger().info(f"Received ML signal: {signal}") - short, neutral, long = signal["probabilities"] - if short > self.config.short_threshold: - self.processed_data["signal"] = -1 - elif long > self.config.long_threshold: - self.processed_data["signal"] = 1 - else: - self.processed_data["signal"] = 0 - self.processed_data["features"] = signal - - async def update_processed_data(self): - pass - - def get_executor_config(self, trade_type: TradeType, price: 
Decimal, amount: Decimal): - """ - Get the executor config based on the trade_type, price and amount. This method can be overridden by the - subclasses if required. - """ - return PositionExecutorConfig( - timestamp=self.market_data_provider.time(), - connector_name=self.config.connector_name, - trading_pair=self.config.trading_pair, - side=trade_type, - entry_price=price, - amount=amount, - triple_barrier_config=self.config.triple_barrier_config.new_instance_with_adjusted_volatility( - volatility_factor=self.processed_data["features"].get("target_pct", 0.01)), - leverage=self.config.leverage, - ) - - def to_format_status(self) -> List[str]: - lines = [] - features = self.processed_data.get("features", {}) - lines.append(f"Signal: {self.processed_data.get('signal', 'N/A')}") - lines.append(f"Timestamp: {features.get('timestamp', 'N/A')}") - lines.append(f"Probabilities: {features.get('probabilities', 'N/A')}") - lines.append(f"Target Pct: {features.get('target_pct', 'N/A')}") - return lines diff --git a/bots/controllers/directional_trading/dman_v3.py b/bots/controllers/directional_trading/dman_v3.py index ca648d76..8e4ee07e 100644 --- a/bots/controllers/directional_trading/dman_v3.py +++ b/bots/controllers/directional_trading/dman_v3.py @@ -3,6 +3,9 @@ from typing import List, Optional, Tuple import pandas_ta as ta # noqa: F401 +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + from hummingbot.core.data_type.common import TradeType from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( @@ -11,8 +14,6 @@ ) from hummingbot.strategy_v2.executors.dca_executor.data_types import DCAExecutorConfig, DCAMode from hummingbot.strategy_v2.executors.position_executor.data_types import TrailingStop -from pydantic import Field, field_validator -from pydantic_core.core_schema import ValidationInfo class DManV3ControllerConfig(DirectionalTradingControllerConfigBase): diff --git a/bots/controllers/generic/arbitrage_controller.py b/bots/controllers/generic/arbitrage_controller.py index ff8f6517..825a8663 100644 --- a/bots/controllers/generic/arbitrage_controller.py +++ b/bots/controllers/generic/arbitrage_controller.py @@ -1,9 +1,10 @@ from decimal import Decimal -from typing import Dict, List, Set +from typing import List import pandas as pd from hummingbot.client.ui.interface_utils import format_df_for_printout +from hummingbot.core.data_type.common import MarketDict from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase from hummingbot.strategy_v2.executors.arbitrage_executor.data_types import ArbitrageExecutorConfig @@ -23,18 +24,8 @@ class ArbitrageControllerConfig(ControllerConfigBase): rate_connector: str = "binance" quote_conversion_asset: str = "USDT" - def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]: - if self.exchange_pair_1.connector_name == self.exchange_pair_2.connector_name: - markets.update({ - self.exchange_pair_1.connector_name: {self.exchange_pair_1.trading_pair, - self.exchange_pair_2.trading_pair} - }) - else: - markets.update({ - self.exchange_pair_1.connector_name: {self.exchange_pair_1.trading_pair}, - self.exchange_pair_2.connector_name: {self.exchange_pair_2.trading_pair} - }) - return markets + def update_markets(self, markets: MarketDict) -> MarketDict: + return 
[markets.add_or_update(cp.connector_name, cp.trading_pair) for cp in [self.exchange_pair_1, self.exchange_pair_2]][-1] class ArbitrageController(ControllerBase): diff --git a/bots/controllers/generic/basic_order_example.py b/bots/controllers/generic/basic_order_example.py deleted file mode 100644 index 10368da4..00000000 --- a/bots/controllers/generic/basic_order_example.py +++ /dev/null @@ -1,55 +0,0 @@ -from decimal import Decimal -from typing import Dict, Set - -from hummingbot.core.data_type.common import PositionMode, PriceType, TradeType -from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase -from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig -from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction - - -class BasicOrderExampleConfig(ControllerConfigBase): - controller_name: str = "basic_order_example" - controller_type: str = "generic" - connector_name: str = "binance_perpetual" - trading_pair: str = "WLD-USDT" - side: TradeType = TradeType.BUY - position_mode: PositionMode = PositionMode.HEDGE - leverage: int = 50 - amount_quote: Decimal = Decimal("10") - order_frequency: int = 10 - - def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]: - if self.connector_name not in markets: - markets[self.connector_name] = set() - markets[self.connector_name].add(self.trading_pair) - return markets - - -class BasicOrderExample(ControllerBase): - def __init__(self, config: BasicOrderExampleConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - self.config = config - self.last_timestamp = 0 - - def determine_executor_actions(self) -> list[ExecutorAction]: - if (self.processed_data["n_active_executors"] == 0 and - self.market_data_provider.time() - self.last_timestamp > self.config.order_frequency): - self.last_timestamp = self.market_data_provider.time() - config = OrderExecutorConfig( - timestamp=self.market_data_provider.time(), - connector_name=self.config.connector_name, - trading_pair=self.config.trading_pair, - side=self.config.side, - amount=self.config.amount_quote / self.processed_data["mid_price"], - execution_strategy=ExecutionStrategy.MARKET, - price=self.processed_data["mid_price"], - ) - return [CreateExecutorAction( - controller_id=self.config.id, - executor_config=config)] - return [] - - async def update_processed_data(self): - mid_price = self.market_data_provider.get_price_by_type(self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) - n_active_executors = len([executor for executor in self.executors_info if executor.is_active]) - self.processed_data = {"mid_price": mid_price, "n_active_executors": n_active_executors} diff --git a/bots/controllers/generic/basic_order_open_close_example.py b/bots/controllers/generic/basic_order_open_close_example.py deleted file mode 100644 index bfeef02d..00000000 --- a/bots/controllers/generic/basic_order_open_close_example.py +++ /dev/null @@ -1,87 +0,0 @@ -from decimal import Decimal -from typing import Dict, Set - -from hummingbot.core.data_type.common import PositionAction, PositionMode, PriceType, TradeType -from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase -from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig -from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction - - -class 
BasicOrderOpenCloseExampleConfig(ControllerConfigBase): - controller_name: str = "basic_order_open_close_example" - controller_type: str = "generic" - connector_name: str = "binance_perpetual" - trading_pair: str = "WLD-USDT" - side: TradeType = TradeType.BUY - position_mode: PositionMode = PositionMode.HEDGE - leverage: int = 50 - close_order_delay: int = 10 - open_short_to_close_long: bool = False - close_partial_position: bool = False - amount_quote: Decimal = Decimal("20") - - def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]: - if self.connector_name not in markets: - markets[self.connector_name] = set() - markets[self.connector_name].add(self.trading_pair) - return markets - - -class BasicOrderOpenClose(ControllerBase): - def __init__(self, config: BasicOrderOpenCloseExampleConfig, *args, **kwargs): - super().__init__(config, *args, **kwargs) - self.config = config - self.open_order_placed = False - self.closed_order_placed = False - self.last_timestamp = 0 - self.open_side = self.config.side - self.close_side = TradeType.SELL if self.config.side == TradeType.BUY else TradeType.BUY - - def get_position(self, connector_name, trading_pair): - for position in self.positions_held: - if position.connector_name == connector_name and position.trading_pair == trading_pair: - return position - - def determine_executor_actions(self) -> list[ExecutorAction]: - mid_price = self.market_data_provider.get_price_by_type(self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) - if not self.open_order_placed: - config = OrderExecutorConfig( - timestamp=self.market_data_provider.time(), - connector_name=self.config.connector_name, - trading_pair=self.config.trading_pair, - side=self.config.side, - amount=self.config.amount_quote / mid_price, - execution_strategy=ExecutionStrategy.MARKET, - position_action=PositionAction.OPEN, - price=mid_price, - ) - self.open_order_placed = True - self.last_timestamp = self.market_data_provider.time() - return [CreateExecutorAction( - controller_id=self.config.id, - executor_config=config)] - else: - if self.market_data_provider.time() - self.last_timestamp > self.config.close_order_delay and not self.closed_order_placed: - current_position = self.get_position(self.config.connector_name, self.config.trading_pair) - if current_position is None: - self.logger().info("The original position is not found, can close the position") - else: - amount = current_position.amount / 2 if self.config.close_partial_position else current_position.amount - config = OrderExecutorConfig( - timestamp=self.market_data_provider.time(), - connector_name=self.config.connector_name, - trading_pair=self.config.trading_pair, - side=self.close_side, - amount=amount, - execution_strategy=ExecutionStrategy.MARKET, - position_action=PositionAction.OPEN if self.config.open_short_to_close_long else PositionAction.CLOSE, - price=mid_price, - ) - self.closed_order_placed = True - return [CreateExecutorAction( - controller_id=self.config.id, - executor_config=config)] - return [] - - async def update_processed_data(self): - pass diff --git a/bots/controllers/generic/grid_strike.py b/bots/controllers/generic/grid_strike.py index a45b83c7..825082c4 100644 --- a/bots/controllers/generic/grid_strike.py +++ b/bots/controllers/generic/grid_strike.py @@ -1,7 +1,9 @@ from decimal import Decimal -from typing import Dict, List, Optional, Set +from typing import List, Optional -from hummingbot.core.data_type.common import OrderType, PositionMode, PriceType, TradeType 
+from pydantic import Field + +from hummingbot.core.data_type.common import MarketDict, OrderType, PositionMode, PriceType, TradeType from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase from hummingbot.strategy_v2.executors.data_types import ConnectorPair @@ -9,7 +11,6 @@ from hummingbot.strategy_v2.executors.position_executor.data_types import TripleBarrierConfig from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction from hummingbot.strategy_v2.models.executors_info import ExecutorInfo -from pydantic import Field class GridStrikeConfig(ControllerConfigBase): @@ -51,11 +52,8 @@ class GridStrikeConfig(ControllerConfigBase): take_profit_order_type=OrderType.LIMIT_MAKER, ) - def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]: - if self.connector_name not in markets: - markets[self.connector_name] = set() - markets[self.connector_name].add(self.trading_pair) - return markets + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) class GridStrike(ControllerBase): diff --git a/bots/controllers/generic/pmm.py b/bots/controllers/generic/pmm.py index 7a66baf9..97e55135 100644 --- a/bots/controllers/generic/pmm.py +++ b/bots/controllers/generic/pmm.py @@ -1,11 +1,10 @@ from decimal import Decimal -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import List, Optional, Tuple, Union from pydantic import Field, field_validator from pydantic_core.core_schema import ValidationInfo -from hummingbot.core.data_type.common import OrderType, PositionMode, PriceType, TradeType -from hummingbot.core.data_type.trade_fee import TokenAmount +from hummingbot.core.data_type.common import MarketDict, OrderType, PositionMode, PriceType, TradeType from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase from hummingbot.strategy_v2.executors.data_types import ConnectorPair @@ -40,7 +39,7 @@ class PMMConfig(ControllerConfigBase): default=Decimal("0.05"), json_schema_extra={ "prompt_on_new": True, - "prompt": "Enter the portfolio allocation (e.g., 0.05 for 5%):", + "prompt": "Enter the maximum quote exposure percentage around mid price (e.g., 0.05 for 5% of total quote allocation):", } ) target_base_pct: Decimal = Field( @@ -136,6 +135,7 @@ class PMMConfig(ControllerConfigBase): } ) global_take_profit: Decimal = Decimal("0.02") + global_stop_loss: Decimal = Decimal("0.05") @field_validator("take_profit", mode="before") @classmethod @@ -234,11 +234,8 @@ def get_spreads_and_amounts_in_quote(self, trade_type: TradeType) -> Tuple[List[ spreads = getattr(self, f'{trade_type.name.lower()}_spreads') return spreads, [amt_pct * self.total_amount_quote * self.portfolio_allocation for amt_pct in normalized_amounts_pct] - def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]: - if self.connector_name not in markets: - markets[self.connector_name] = set() - markets[self.connector_name].add(self.trading_pair) - return markets + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) class PMM(ControllerBase): @@ -266,8 +263,20 @@ def create_actions_proposal(self) -> List[ExecutorAction]: Create actions proposal based on the current state of the controller. 
""" create_actions = [] - if self.processed_data["current_base_pct"] > self.config.target_base_pct and self.processed_data["unrealized_pnl_pct"] > self.config.global_take_profit: - # Create a global take profit executor + + # Check if a position reduction executor for TP/SL is already sent + reduction_executor_exists = any( + executor.is_active and + executor.custom_info.get("level_id") == "global_tp_sl" + for executor in self.executors_info + ) + + if (not reduction_executor_exists and + self.processed_data["current_base_pct"] > self.config.target_base_pct and + (self.processed_data["unrealized_pnl_pct"] > self.config.global_take_profit or + self.processed_data["unrealized_pnl_pct"] < -self.config.global_stop_loss)): + + # Create a global take profit or stop loss executor create_actions.append(CreateExecutorAction( controller_id=self.config.id, executor_config=OrderExecutorConfig( @@ -278,6 +287,7 @@ def create_actions_proposal(self) -> List[ExecutorAction]: amount=self.processed_data["position_amount"], execution_strategy=ExecutionStrategy.MARKET, price=self.processed_data["reference_price"], + level_id="global_tp_sl" # Use a specific level_id to identify this as a TP/SL executor ) )) return create_actions @@ -440,81 +450,198 @@ def get_not_active_levels_ids(self, active_levels_ids: List[str]) -> List[str]: return sell_ids_missing return buy_ids_missing + sell_ids_missing - def get_balance_requirements(self) -> List[TokenAmount]: - """ - Get the balance requirements for the controller. - """ - base_asset, quote_asset = self.config.trading_pair.split("-") - _, amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.BUY) - _, amounts_base = self.config.get_spreads_and_amounts_in_quote(TradeType.SELL) - return [TokenAmount(base_asset, Decimal(sum(amounts_base) / self.processed_data["reference_price"])), - TokenAmount(quote_asset, Decimal(sum(amounts_quote)))] - def to_format_status(self) -> List[str]: """ Get the status of the controller in a formatted way with ASCII visualizations. 
""" + from decimal import Decimal + from itertools import zip_longest + status = [] - status.append(f"Controller ID: {self.config.id}") - status.append(f"Connector: {self.config.connector_name}") - status.append(f"Trading Pair: {self.config.trading_pair}") - status.append(f"Portfolio Allocation: {self.config.portfolio_allocation}") - status.append(f"Reference Price: {self.processed_data['reference_price']}") - status.append(f"Spread Multiplier: {self.processed_data['spread_multiplier']}") - - # Base percentage visualization + + # Get all required data base_pct = self.processed_data['current_base_pct'] min_pct = self.config.min_base_pct max_pct = self.config.max_base_pct target_pct = self.config.target_base_pct - # Create base percentage bar - bar_width = 50 + skew = base_pct - target_pct + skew_pct = skew / target_pct if target_pct != 0 else Decimal('0') + max_skew = getattr(self.config, 'max_skew', Decimal('0.0')) + + # Fixed widths - adjusted based on screenshot analysis + outer_width = 92 # Total width including outer borders + inner_width = outer_width - 4 # Inner content width + half_width = (inner_width) // 2 - 1 # Width of each column in split sections + bar_width = inner_width - 15 # Width of visualization bars (accounting for label) + + # Header - omit ID since it's shown above in controller header + status.append("╒" + "═" * (inner_width) + "╕") + + header_line = ( + f"{self.config.connector_name}:{self.config.trading_pair} " + f"Price: {self.processed_data['reference_price']} " + f"Alloc: {self.config.portfolio_allocation:.1%} " + f"Spread Mult: {self.processed_data['spread_multiplier']} |" + ) + + status.append(f"│ {header_line:<{inner_width}} │") + + # Position and PnL sections with precise widths + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'POSITION STATUS':<{half_width - 2}} │ {'PROFIT & LOSS':<{half_width - 2}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Position data for left column + position_info = [ + f"Current: {base_pct:.2%}", + f"Target: {target_pct:.2%}", + f"Min/Max: {min_pct:.2%}/{max_pct:.2%}", + f"Skew: {skew_pct:+.2%} (max {max_skew:.2%})" + ] + + # PnL data for right column + pnl_info = [] + if 'unrealized_pnl_pct' in self.processed_data: + pnl = self.processed_data['unrealized_pnl_pct'] + pnl_sign = "+" if pnl >= 0 else "" + pnl_info = [ + f"Unrealized: {pnl_sign}{pnl:.2%}", + f"Take Profit: {self.config.global_take_profit:.2%}", + f"Stop Loss: {-self.config.global_stop_loss:.2%}", + f"Leverage: {self.config.leverage}x" + ] + + # Display position and PnL info side by side with exact spacing + for pos_line, pnl_line in zip_longest(position_info, pnl_info, fillvalue=""): + status.append(f"│ {pos_line:<{half_width - 2}} │ {pnl_line:<{half_width - 2}} │") + + # Adjust visualization section - ensure consistent spacing + status.append(f"├{'─' * (inner_width)}┤") + status.append(f"│ {'VISUALIZATIONS':<{inner_width}} │") + status.append(f"├{'─' * (inner_width)}┤") + + # Position bar with exact spacing and characters filled_width = int(base_pct * bar_width) min_pos = int(min_pct * bar_width) max_pos = int(max_pct * bar_width) target_pos = int(target_pct * bar_width) - base_bar = "Base %: [" + + # Build position bar character by character + position_bar = "" for i in range(bar_width): if i == filled_width: - base_bar += "O" # Current position + position_bar += "◆" # Current position elif i == min_pos: - base_bar += "m" # Min threshold + position_bar += "┃" # Min threshold elif i == max_pos: - base_bar += 
"M" # Max threshold + position_bar += "┃" # Max threshold elif i == target_pos: - base_bar += "T" # Target threshold + position_bar += "┇" # Target threshold elif i < filled_width: - base_bar += "=" + position_bar += "█" # Filled area else: - base_bar += " " - base_bar += f"] {base_pct:.2%}" - status.append(base_bar) - status.append(f"Min: {min_pct:.2%} | Target: {target_pct:.2%} | Max: {max_pct:.2%}") - # Skew visualization - skew = base_pct - target_pct - skew_pct = skew / target_pct if target_pct != 0 else Decimal('0') - max_skew = getattr(self.config, 'max_skew', Decimal('0.0')) - skew_bar_width = 30 - skew_bar = "Skew: " + position_bar += "░" # Empty area + + # Ensure consistent label spacing as seen in screenshot + status.append(f"│ Position: [{position_bar}] │") + + # Skew visualization with exact spacing + skew_bar_width = bar_width center = skew_bar_width // 2 skew_pos = center + int(skew_pct * center * 2) - skew_pos = max(0, min(skew_bar_width, skew_pos)) + skew_pos = max(0, min(skew_bar_width - 1, skew_pos)) + + # Build skew bar character by character + skew_bar = "" for i in range(skew_bar_width): if i == center: - skew_bar += "|" # Center line + skew_bar += "┃" # Center line elif i == skew_pos: - skew_bar += "*" # Current skew + skew_bar += "⬤" # Current skew else: - skew_bar += "-" - skew_bar += f" {skew_pct:+.2%} (max: {max_skew:.2%})" - status.append(skew_bar) - # Active executors summary - status.append("\nActive Executors:") - active_buy = sum(1 for info in self.executors_info if self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.BUY) - active_sell = sum(1 for info in self.executors_info if self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.SELL) - status.append(f"Total: {len(self.executors_info)} (Buy: {active_buy}, Sell: {active_sell})") - # Deviation info + skew_bar += "─" # Empty line + + # Match spacing from screenshot with exact character counts + status.append(f"│ Skew: [{skew_bar}] │") + + # PnL visualization if available + if 'unrealized_pnl_pct' in self.processed_data: + pnl = self.processed_data['unrealized_pnl_pct'] + take_profit = self.config.global_take_profit + stop_loss = -self.config.global_stop_loss + + pnl_bar_width = bar_width + center = pnl_bar_width // 2 + + # Calculate positions with exact scaling + max_range = max(abs(take_profit), abs(stop_loss), abs(pnl)) * Decimal("1.2") + scale = (pnl_bar_width // 2) / max_range + + pnl_pos = center + int(pnl * scale) + take_profit_pos = center + int(take_profit * scale) + stop_loss_pos = center + int(stop_loss * scale) + + # Ensure positions are within bounds + pnl_pos = max(0, min(pnl_bar_width - 1, pnl_pos)) + take_profit_pos = max(0, min(pnl_bar_width - 1, take_profit_pos)) + stop_loss_pos = max(0, min(pnl_bar_width - 1, stop_loss_pos)) + + # Build PnL bar character by character + pnl_bar = "" + for i in range(pnl_bar_width): + if i == center: + pnl_bar += "│" # Center line + elif i == pnl_pos: + pnl_bar += "⬤" # Current PnL + elif i == take_profit_pos: + pnl_bar += "T" # Take profit line + elif i == stop_loss_pos: + pnl_bar += "S" # Stop loss line + elif (pnl >= 0 and center <= i < pnl_pos) or (pnl < 0 and pnl_pos < i <= center): + pnl_bar += "█" if pnl >= 0 else "▓" + else: + pnl_bar += "─" + + # Match spacing from screenshot + status.append(f"│ PnL: [{pnl_bar}] │") + + # Executors section with precise column widths + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'EXECUTORS STATUS':<{half_width - 2}} │ {'EXECUTOR 
VISUALIZATION':<{half_width - 2}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Count active executors by type + active_buy = sum(1 for info in self.executors_info + if info.is_active and self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.BUY) + active_sell = sum(1 for info in self.executors_info + if info.is_active and self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.SELL) + total_active = sum(1 for info in self.executors_info if info.is_active) + + # Executor information with fixed formatting + executor_info = [ + f"Total Active: {total_active}", + f"Total Created: {len(self.executors_info)}", + f"Buy Executors: {active_buy}", + f"Sell Executors: {active_sell}" + ] + if 'deviation' in self.processed_data: - deviation = self.processed_data['deviation'] - status.append(f"Deviation: {deviation:.4f}") + executor_info.append(f"Target Deviation: {self.processed_data['deviation']:.4f}") + + # Visualization with consistent block characters for buy/sell representation + buy_bars = "▮" * active_buy if active_buy > 0 else "─" + sell_bars = "▮" * active_sell if active_sell > 0 else "─" + + executor_viz = [ + f"Buy: {buy_bars}", + f"Sell: {sell_bars}" + ] + + # Display with fixed width columns + for exec_line, viz_line in zip_longest(executor_info, executor_viz, fillvalue=""): + status.append(f"│ {exec_line:<{half_width - 2}} │ {viz_line:<{half_width - 2}} │") + + # Bottom border with exact width + status.append(f"╘{'═' * (inner_width)}╛") + return status diff --git a/bots/controllers/generic/pmm_adjusted.py b/bots/controllers/generic/pmm_adjusted.py new file mode 100644 index 00000000..e9bc2667 --- /dev/null +++ b/bots/controllers/generic/pmm_adjusted.py @@ -0,0 +1,669 @@ +from decimal import Decimal +from typing import List, Optional, Tuple, Union + +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + +from hummingbot.core.data_type.common import MarketDict, OrderType, PositionMode, PriceType, TradeType +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig, TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction +from hummingbot.strategy_v2.models.executors import CloseType + + +class PMMAdjustedConfig(ControllerConfigBase): + """ + This class represents the base configuration for a market making controller. 
+ """ + controller_type: str = "generic" + controller_name: str = "pmm_adjusted" + candles_config: List[CandlesConfig] = [] + connector_name: str = Field( + default="binance", + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the name of the connector to use (e.g., binance):", + } + ) + trading_pair: str = Field( + default="BTC-FDUSD", + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the trading pair to trade on (e.g., BTC-FDUSD):", + } + ) + candles_connector_name: str = Field(default="binance") + candles_trading_pair: str = Field(default="BTC-USDT") + candles_interval: str = Field(default="1s") + + portfolio_allocation: Decimal = Field( + default=Decimal("0.05"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the maximum quote exposure percentage around mid price (e.g., 0.05 for 5% of total quote allocation):", + } + ) + target_base_pct: Decimal = Field( + default=Decimal("0.2"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the target base percentage (e.g., 0.2 for 20%):", + } + ) + min_base_pct: Decimal = Field( + default=Decimal("0.1"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the minimum base percentage (e.g., 0.1 for 10%):", + } + ) + max_base_pct: Decimal = Field( + default=Decimal("0.4"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the maximum base percentage (e.g., 0.4 for 40%):", + } + ) + buy_spreads: List[float] = Field( + default="0.01,0.02", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of buy spreads (e.g., '0.01, 0.02'):", + } + ) + sell_spreads: List[float] = Field( + default="0.01,0.02", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of sell spreads (e.g., '0.01, 0.02'):", + } + ) + buy_amounts_pct: Union[List[Decimal], None] = Field( + default=None, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of buy amounts as percentages (e.g., '50, 50'), or leave blank to distribute equally:", + } + ) + sell_amounts_pct: Union[List[Decimal], None] = Field( + default=None, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of sell amounts as percentages (e.g., '50, 50'), or leave blank to distribute equally:", + } + ) + executor_refresh_time: int = Field( + default=60 * 5, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the refresh time in seconds for executors (e.g., 300 for 5 minutes):", + } + ) + cooldown_time: int = Field( + default=15, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the cooldown time in seconds between after replacing an executor that traded (e.g., 15):", + } + ) + leverage: int = Field( + default=20, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the leverage to use for trading (e.g., 20 for 20x leverage). 
Set it to 1 for spot trading:", + } + ) + position_mode: PositionMode = Field(default="HEDGE") + take_profit: Optional[Decimal] = Field( + default=Decimal("0.02"), gt=0, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the take profit as a decimal (e.g., 0.02 for 2%):", + } + ) + take_profit_order_type: Optional[OrderType] = Field( + default="LIMIT_MAKER", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the order type for take profit (e.g., LIMIT_MAKER):", + } + ) + max_skew: Decimal = Field( + default=Decimal("1.0"), + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the maximum skew factor (e.g., 1.0):", + } + ) + global_take_profit: Decimal = Decimal("0.02") + global_stop_loss: Decimal = Decimal("0.05") + + @field_validator("take_profit", mode="before") + @classmethod + def validate_target(cls, v): + if isinstance(v, str): + if v == "": + return None + return Decimal(v) + return v + + @field_validator('take_profit_order_type', mode="before") + @classmethod + def validate_order_type(cls, v) -> OrderType: + if isinstance(v, OrderType): + return v + elif v is None: + return OrderType.MARKET + elif isinstance(v, str): + if v.upper() in OrderType.__members__: + return OrderType[v.upper()] + elif isinstance(v, int): + try: + return OrderType(v) + except ValueError: + pass + raise ValueError(f"Invalid order type: {v}. Valid options are: {', '.join(OrderType.__members__)}") + + @field_validator('buy_spreads', 'sell_spreads', mode="before") + @classmethod + def parse_spreads(cls, v): + if v is None: + return [] + if isinstance(v, str): + if v == "": + return [] + return [float(x.strip()) for x in v.split(',')] + return v + + @field_validator('buy_amounts_pct', 'sell_amounts_pct', mode="before") + @classmethod + def parse_and_validate_amounts(cls, v, validation_info: ValidationInfo): + field_name = validation_info.field_name + if v is None or v == "": + spread_field = field_name.replace('amounts_pct', 'spreads') + return [1 for _ in validation_info.data[spread_field]] + if isinstance(v, str): + return [float(x.strip()) for x in v.split(',')] + elif isinstance(v, list) and len(v) != len(validation_info.data[field_name.replace('amounts_pct', 'spreads')]): + raise ValueError( + f"The number of {field_name} must match the number of {field_name.replace('amounts_pct', 'spreads')}.") + return v + + @field_validator('position_mode', mode="before") + @classmethod + def validate_position_mode(cls, v) -> PositionMode: + if isinstance(v, str): + if v.upper() in PositionMode.__members__: + return PositionMode[v.upper()] + raise ValueError(f"Invalid position mode: {v}. 
Valid options are: {', '.join(PositionMode.__members__)}") + return v + + @property + def triple_barrier_config(self) -> TripleBarrierConfig: + return TripleBarrierConfig( + take_profit=self.take_profit, + trailing_stop=None, + open_order_type=OrderType.LIMIT_MAKER, # Defaulting to LIMIT as is a Maker Controller + take_profit_order_type=self.take_profit_order_type, + stop_loss_order_type=OrderType.MARKET, # Defaulting to MARKET as per requirement + time_limit_order_type=OrderType.MARKET # Defaulting to MARKET as per requirement + ) + + def update_parameters(self, trade_type: TradeType, new_spreads: Union[List[float], str], new_amounts_pct: Optional[Union[List[int], str]] = None): + spreads_field = 'buy_spreads' if trade_type == TradeType.BUY else 'sell_spreads' + amounts_pct_field = 'buy_amounts_pct' if trade_type == TradeType.BUY else 'sell_amounts_pct' + + setattr(self, spreads_field, self.parse_spreads(new_spreads)) + if new_amounts_pct is not None: + setattr(self, amounts_pct_field, self.parse_and_validate_amounts(new_amounts_pct, self.__dict__, self.__fields__[amounts_pct_field])) + else: + setattr(self, amounts_pct_field, [1 for _ in getattr(self, spreads_field)]) + + def get_spreads_and_amounts_in_quote(self, trade_type: TradeType) -> Tuple[List[float], List[float]]: + buy_amounts_pct = getattr(self, 'buy_amounts_pct') + sell_amounts_pct = getattr(self, 'sell_amounts_pct') + + # Calculate total percentages across buys and sells + total_pct = sum(buy_amounts_pct) + sum(sell_amounts_pct) + + # Normalize amounts_pct based on total percentages + if trade_type == TradeType.BUY: + normalized_amounts_pct = [amt_pct / total_pct for amt_pct in buy_amounts_pct] + else: # TradeType.SELL + normalized_amounts_pct = [amt_pct / total_pct for amt_pct in sell_amounts_pct] + + spreads = getattr(self, f'{trade_type.name.lower()}_spreads') + return spreads, [amt_pct * self.total_amount_quote * self.portfolio_allocation for amt_pct in normalized_amounts_pct] + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class PMMAdjusted(ControllerBase): + """ + This class represents the base class for a market making controller. + """ + + def __init__(self, config: PMMAdjustedConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.market_data_provider.initialize_rate_sources([ConnectorPair( + connector_name=config.connector_name, trading_pair=config.trading_pair)]) + self.config.candles_config = [ + CandlesConfig(connector=self.config.candles_connector_name, + trading_pair=self.config.candles_trading_pair, + interval=self.config.candles_interval) + ] + + def determine_executor_actions(self) -> List[ExecutorAction]: + """ + Determine actions based on the provided executor handler report. + """ + actions = [] + actions.extend(self.create_actions_proposal()) + actions.extend(self.stop_actions_proposal()) + return actions + + def create_actions_proposal(self) -> List[ExecutorAction]: + """ + Create actions proposal based on the current state of the controller. 
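+
+        If the held position exceeds the target base percentage and the global
+        take-profit or stop-loss threshold is crossed (and no reduction executor
+        tagged "global_tp_sl" is active), a single market sell executor is
+        created to reduce the position. Otherwise, missing grid levels are
+        created with skew-adjusted amounts.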
+ """ + create_actions = [] + + # Check if a position reduction executor for TP/SL is already sent + reduction_executor_exists = any( + executor.is_active and + executor.custom_info.get("level_id") == "global_tp_sl" + for executor in self.executors_info + ) + + if (not reduction_executor_exists and + self.processed_data["current_base_pct"] > self.config.target_base_pct and + (self.processed_data["unrealized_pnl_pct"] > self.config.global_take_profit or + self.processed_data["unrealized_pnl_pct"] < -self.config.global_stop_loss)): + + # Create a global take profit or stop loss executor + create_actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=TradeType.SELL, + amount=self.processed_data["position_amount"], + execution_strategy=ExecutionStrategy.MARKET, + price=self.processed_data["reference_price"], + level_id="global_tp_sl" # Use a specific level_id to identify this as a TP/SL executor + ) + )) + return create_actions + levels_to_execute = self.get_levels_to_execute() + # Pre-calculate all spreads and amounts for buy and sell sides + buy_spreads, buy_amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.BUY) + sell_spreads, sell_amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.SELL) + reference_price = Decimal(self.processed_data["reference_price"]) + # Get current position info for skew calculation + current_pct = self.processed_data["current_base_pct"] + min_pct = self.config.min_base_pct + max_pct = self.config.max_base_pct + # Calculate skew factors (0 to 1) - how much to scale orders + if max_pct > min_pct: # Prevent division by zero + # For buys: full size at min_pct, decreasing as we approach max_pct + buy_skew = (max_pct - current_pct) / (max_pct - min_pct) + # For sells: full size at max_pct, decreasing as we approach min_pct + sell_skew = (current_pct - min_pct) / (max_pct - min_pct) + # Ensure values stay between 0.2 and 1.0 (never go below 20% of original size) + buy_skew = max(min(buy_skew, Decimal("1.0")), self.config.max_skew) + sell_skew = max(min(sell_skew, Decimal("1.0")), self.config.max_skew) + else: + buy_skew = sell_skew = Decimal("1.0") + # Create executors for each level + for level_id in levels_to_execute: + trade_type = self.get_trade_type_from_level_id(level_id) + level = self.get_level_from_level_id(level_id) + if trade_type == TradeType.BUY: + spread_in_pct = Decimal(buy_spreads[level]) * Decimal(self.processed_data["spread_multiplier"]) + amount_quote = Decimal(buy_amounts_quote[level]) + skew = buy_skew + else: # TradeType.SELL + spread_in_pct = Decimal(sell_spreads[level]) * Decimal(self.processed_data["spread_multiplier"]) + amount_quote = Decimal(sell_amounts_quote[level]) + skew = sell_skew + # Calculate price + side_multiplier = Decimal("-1") if trade_type == TradeType.BUY else Decimal("1") + price = reference_price * (Decimal("1") + side_multiplier * spread_in_pct) + # Calculate amount with skew applied + amount = self.market_data_provider.quantize_order_amount(self.config.connector_name, + self.config.trading_pair, + (amount_quote / price) * skew) + if amount == Decimal("0"): + self.logger().warning(f"The amount of the level {level_id} is 0. 
Skipping.") + executor_config = self.get_executor_config(level_id, price, amount) + if executor_config is not None: + create_actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=executor_config + )) + return create_actions + + def get_levels_to_execute(self) -> List[str]: + working_levels = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: x.is_active or (x.close_type == CloseType.STOP_LOSS and self.market_data_provider.time() - x.close_timestamp < self.config.cooldown_time) + ) + working_levels_ids = [executor.custom_info["level_id"] for executor in working_levels] + return self.get_not_active_levels_ids(working_levels_ids) + + def stop_actions_proposal(self) -> List[ExecutorAction]: + """ + Create a list of actions to stop the executors based on order refresh and early stop conditions. + """ + stop_actions = [] + stop_actions.extend(self.executors_to_refresh()) + stop_actions.extend(self.executors_to_early_stop()) + return stop_actions + + def executors_to_refresh(self) -> List[ExecutorAction]: + executors_to_refresh = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: not x.is_trading and x.is_active and self.market_data_provider.time() - x.timestamp > self.config.executor_refresh_time) + return [StopExecutorAction( + controller_id=self.config.id, + keep_position=True, + executor_id=executor.id) for executor in executors_to_refresh] + + def executors_to_early_stop(self) -> List[ExecutorAction]: + """ + Get the executors to early stop based on the current state of market data. This method can be overridden to + implement custom behavior. + """ + executors_to_early_stop = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: x.is_active and x.is_trading and self.market_data_provider.time() - x.custom_info["open_order_last_update"] > self.config.cooldown_time) + return [StopExecutorAction( + controller_id=self.config.id, + keep_position=True, + executor_id=executor.id) for executor in executors_to_early_stop] + + async def update_processed_data(self): + """ + Update the processed data for the controller. This method should be reimplemented to modify the reference price + and spread multiplier based on the market data. By default, it will update the reference price as mid price and + the spread multiplier as 1. + """ + reference_price = self.get_current_candles_price() + position_held = next((position for position in self.positions_held if + (position.trading_pair == self.config.trading_pair) & + (position.connector_name == self.config.connector_name)), None) + target_position = self.config.total_amount_quote * self.config.target_base_pct + if position_held is not None: + position_amount = position_held.amount + current_base_pct = position_held.amount_quote / self.config.total_amount_quote + deviation = (target_position - position_held.amount_quote) / target_position + unrealized_pnl_pct = position_held.unrealized_pnl_quote / position_held.amount_quote if position_held.amount_quote != 0 else Decimal("0") + else: + position_amount = 0 + current_base_pct = 0 + deviation = 1 + unrealized_pnl_pct = 0 + + self.processed_data = {"reference_price": Decimal(reference_price), "spread_multiplier": Decimal("1"), + "deviation": deviation, "current_base_pct": current_base_pct, + "unrealized_pnl_pct": unrealized_pnl_pct, "position_amount": position_amount} + + def get_current_candles_price(self) -> Decimal: + """ + Get the current price from the candles data provider. 
+ """ + candles = self.market_data_provider.get_candles_df(self.config.candles_connector_name, + self.config.candles_trading_pair, + self.config.candles_interval) + if candles is not None and not candles.empty: + last_candle = candles.iloc[-1] + return Decimal(last_candle['close']) + else: + self.logger().warning(f"No candles data available for {self.config.candles_connector_name} - {self.config.candles_trading_pair} at {self.config.candles_interval}. Using last known price.") + return Decimal(self.market_data_provider.get_price_by_type(self.config.connector_name, self.config.trading_pair, PriceType.MidPrice)) + + def get_executor_config(self, level_id: str, price: Decimal, amount: Decimal): + """ + Get the executor config for a given level id. + """ + trade_type = self.get_trade_type_from_level_id(level_id) + level_multiplier = self.get_level_from_level_id(level_id) + 1 + return PositionExecutorConfig( + timestamp=self.market_data_provider.time(), + level_id=level_id, + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + entry_price=price, + amount=amount, + triple_barrier_config=self.config.triple_barrier_config.new_instance_with_adjusted_volatility(level_multiplier), + leverage=self.config.leverage, + side=trade_type, + ) + + def get_level_id_from_side(self, trade_type: TradeType, level: int) -> str: + """ + Get the level id based on the trade type and the level. + """ + return f"{trade_type.name.lower()}_{level}" + + def get_trade_type_from_level_id(self, level_id: str) -> TradeType: + return TradeType.BUY if level_id.startswith("buy") else TradeType.SELL + + def get_level_from_level_id(self, level_id: str) -> int: + return int(level_id.split('_')[1]) + + def get_not_active_levels_ids(self, active_levels_ids: List[str]) -> List[str]: + """ + Get the levels to execute based on the current state of the controller. + """ + buy_ids_missing = [self.get_level_id_from_side(TradeType.BUY, level) for level in range(len(self.config.buy_spreads)) + if self.get_level_id_from_side(TradeType.BUY, level) not in active_levels_ids] + sell_ids_missing = [self.get_level_id_from_side(TradeType.SELL, level) for level in range(len(self.config.sell_spreads)) + if self.get_level_id_from_side(TradeType.SELL, level) not in active_levels_ids] + if self.processed_data["current_base_pct"] < self.config.min_base_pct: + return buy_ids_missing + elif self.processed_data["current_base_pct"] > self.config.max_base_pct: + return sell_ids_missing + return buy_ids_missing + sell_ids_missing + + def to_format_status(self) -> List[str]: + """ + Get the status of the controller in a formatted way with ASCII visualizations. 
+ """ + from decimal import Decimal + from itertools import zip_longest + + status = [] + + # Get all required data + base_pct = self.processed_data['current_base_pct'] + min_pct = self.config.min_base_pct + max_pct = self.config.max_base_pct + target_pct = self.config.target_base_pct + skew = base_pct - target_pct + skew_pct = skew / target_pct if target_pct != 0 else Decimal('0') + max_skew = getattr(self.config, 'max_skew', Decimal('0.0')) + + # Fixed widths - adjusted based on screenshot analysis + outer_width = 92 # Total width including outer borders + inner_width = outer_width - 4 # Inner content width + half_width = (inner_width) // 2 - 1 # Width of each column in split sections + bar_width = inner_width - 15 # Width of visualization bars (accounting for label) + + # Header - omit ID since it's shown above in controller header + status.append("╒" + "═" * (inner_width) + "╕") + + header_line = ( + f"{self.config.connector_name}:{self.config.trading_pair} " + f"Price: {self.processed_data['reference_price']} " + f"Alloc: {self.config.portfolio_allocation:.1%} " + f"Spread Mult: {self.processed_data['spread_multiplier']} |" + ) + + status.append(f"│ {header_line:<{inner_width}} │") + + # Position and PnL sections with precise widths + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'POSITION STATUS':<{half_width - 2}} │ {'PROFIT & LOSS':<{half_width - 2}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Position data for left column + position_info = [ + f"Current: {base_pct:.2%}", + f"Target: {target_pct:.2%}", + f"Min/Max: {min_pct:.2%}/{max_pct:.2%}", + f"Skew: {skew_pct:+.2%} (max {max_skew:.2%})" + ] + + # PnL data for right column + pnl_info = [] + if 'unrealized_pnl_pct' in self.processed_data: + pnl = self.processed_data['unrealized_pnl_pct'] + pnl_sign = "+" if pnl >= 0 else "" + pnl_info = [ + f"Unrealized: {pnl_sign}{pnl:.2%}", + f"Take Profit: {self.config.global_take_profit:.2%}", + f"Stop Loss: {-self.config.global_stop_loss:.2%}", + f"Leverage: {self.config.leverage}x" + ] + + # Display position and PnL info side by side with exact spacing + for pos_line, pnl_line in zip_longest(position_info, pnl_info, fillvalue=""): + status.append(f"│ {pos_line:<{half_width - 2}} │ {pnl_line:<{half_width - 2}} │") + + # Adjust visualization section - ensure consistent spacing + status.append(f"├{'─' * (inner_width)}┤") + status.append(f"│ {'VISUALIZATIONS':<{inner_width}} │") + status.append(f"├{'─' * (inner_width)}┤") + + # Position bar with exact spacing and characters + filled_width = int(base_pct * bar_width) + min_pos = int(min_pct * bar_width) + max_pos = int(max_pct * bar_width) + target_pos = int(target_pct * bar_width) + + # Build position bar character by character + position_bar = "" + for i in range(bar_width): + if i == filled_width: + position_bar += "◆" # Current position + elif i == min_pos: + position_bar += "┃" # Min threshold + elif i == max_pos: + position_bar += "┃" # Max threshold + elif i == target_pos: + position_bar += "┇" # Target threshold + elif i < filled_width: + position_bar += "█" # Filled area + else: + position_bar += "░" # Empty area + + # Ensure consistent label spacing as seen in screenshot + status.append(f"│ Position: [{position_bar}] │") + + # Skew visualization with exact spacing + skew_bar_width = bar_width + center = skew_bar_width // 2 + skew_pos = center + int(skew_pct * center * 2) + skew_pos = max(0, min(skew_bar_width - 1, skew_pos)) + + # Build skew bar character by character + 
skew_bar = "" + for i in range(skew_bar_width): + if i == center: + skew_bar += "┃" # Center line + elif i == skew_pos: + skew_bar += "⬤" # Current skew + else: + skew_bar += "─" # Empty line + + # Match spacing from screenshot with exact character counts + status.append(f"│ Skew: [{skew_bar}] │") + + # PnL visualization if available + if 'unrealized_pnl_pct' in self.processed_data: + pnl = self.processed_data['unrealized_pnl_pct'] + take_profit = self.config.global_take_profit + stop_loss = -self.config.global_stop_loss + + pnl_bar_width = bar_width + center = pnl_bar_width // 2 + + # Calculate positions with exact scaling + max_range = max(abs(take_profit), abs(stop_loss), abs(pnl)) * Decimal("1.2") + scale = (pnl_bar_width // 2) / max_range + + pnl_pos = center + int(pnl * scale) + take_profit_pos = center + int(take_profit * scale) + stop_loss_pos = center + int(stop_loss * scale) + + # Ensure positions are within bounds + pnl_pos = max(0, min(pnl_bar_width - 1, pnl_pos)) + take_profit_pos = max(0, min(pnl_bar_width - 1, take_profit_pos)) + stop_loss_pos = max(0, min(pnl_bar_width - 1, stop_loss_pos)) + + # Build PnL bar character by character + pnl_bar = "" + for i in range(pnl_bar_width): + if i == center: + pnl_bar += "│" # Center line + elif i == pnl_pos: + pnl_bar += "⬤" # Current PnL + elif i == take_profit_pos: + pnl_bar += "T" # Take profit line + elif i == stop_loss_pos: + pnl_bar += "S" # Stop loss line + elif (pnl >= 0 and center <= i < pnl_pos) or (pnl < 0 and pnl_pos < i <= center): + pnl_bar += "█" if pnl >= 0 else "▓" + else: + pnl_bar += "─" + + # Match spacing from screenshot + status.append(f"│ PnL: [{pnl_bar}] │") + + # Executors section with precise column widths + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'EXECUTORS STATUS':<{half_width - 2}} │ {'EXECUTOR VISUALIZATION':<{half_width - 2}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Count active executors by type + active_buy = sum(1 for info in self.executors_info + if info.is_active and self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.BUY) + active_sell = sum(1 for info in self.executors_info + if info.is_active and self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.SELL) + total_active = sum(1 for info in self.executors_info if info.is_active) + + # Executor information with fixed formatting + executor_info = [ + f"Total Active: {total_active}", + f"Total Created: {len(self.executors_info)}", + f"Buy Executors: {active_buy}", + f"Sell Executors: {active_sell}" + ] + + if 'deviation' in self.processed_data: + executor_info.append(f"Target Deviation: {self.processed_data['deviation']:.4f}") + + # Visualization with consistent block characters for buy/sell representation + buy_bars = "▮" * active_buy if active_buy > 0 else "─" + sell_bars = "▮" * active_sell if active_sell > 0 else "─" + + executor_viz = [ + f"Buy: {buy_bars}", + f"Sell: {sell_bars}" + ] + + # Display with fixed width columns + for exec_line, viz_line in zip_longest(executor_info, executor_viz, fillvalue=""): + status.append(f"│ {exec_line:<{half_width - 2}} │ {viz_line:<{half_width - 2}} │") + + # Bottom border with exact width + status.append(f"╘{'═' * (inner_width)}╛") + + return status diff --git a/bots/controllers/generic/stat_arb.py b/bots/controllers/generic/stat_arb.py new file mode 100644 index 00000000..527db07a --- /dev/null +++ b/bots/controllers/generic/stat_arb.py @@ -0,0 +1,475 @@ +from decimal import Decimal +from 
typing import List + +import numpy as np +from sklearn.linear_model import LinearRegression + +from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, PriceType, TradeType +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair, PositionSummary +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig, TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction + + +class StatArbConfig(ControllerConfigBase): + """ + Configuration for a statistical arbitrage controller that trades two cointegrated assets. + """ + controller_type: str = "generic" + controller_name: str = "stat_arb" + candles_config: List[CandlesConfig] = [] + connector_pair_dominant: ConnectorPair = ConnectorPair(connector_name="binance_perpetual", trading_pair="SOL-USDT") + connector_pair_hedge: ConnectorPair = ConnectorPair(connector_name="binance_perpetual", trading_pair="POPCAT-USDT") + interval: str = "1m" + lookback_period: int = 300 + entry_threshold: Decimal = Decimal("2.0") + take_profit: Decimal = Decimal("0.0008") + tp_global: Decimal = Decimal("0.01") + sl_global: Decimal = Decimal("0.05") + min_amount_quote: Decimal = Decimal("10") + quoter_spread: Decimal = Decimal("0.0001") + quoter_cooldown: int = 30 + quoter_refresh: int = 10 + max_orders_placed_per_side: int = 2 + max_orders_filled_per_side: int = 2 + max_position_deviation: Decimal = Decimal("0.1") + pos_hedge_ratio: Decimal = Decimal("1.0") + leverage: int = 20 + position_mode: PositionMode = PositionMode.HEDGE + + @property + def triple_barrier_config(self) -> TripleBarrierConfig: + return TripleBarrierConfig( + take_profit=self.take_profit, + open_order_type=OrderType.LIMIT_MAKER, + take_profit_order_type=OrderType.LIMIT_MAKER, + ) + + def update_markets(self, markets: dict) -> dict: + """Update markets dictionary with both trading pairs""" + # Add dominant pair + if self.connector_pair_dominant.connector_name not in markets: + markets[self.connector_pair_dominant.connector_name] = set() + markets[self.connector_pair_dominant.connector_name].add(self.connector_pair_dominant.trading_pair) + + # Add hedge pair + if self.connector_pair_hedge.connector_name not in markets: + markets[self.connector_pair_hedge.connector_name] = set() + markets[self.connector_pair_hedge.connector_name].add(self.connector_pair_hedge.trading_pair) + + return markets + + +class StatArb(ControllerBase): + """ + Statistical arbitrage controller that trades two cointegrated assets. 
+ """ + + def __init__(self, config: StatArbConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.theoretical_dominant_quote = self.config.total_amount_quote * (1 / (1 + self.config.pos_hedge_ratio)) + self.theoretical_hedge_quote = self.config.total_amount_quote * (self.config.pos_hedge_ratio / (1 + self.config.pos_hedge_ratio)) + + # Initialize processed data dictionary + self.processed_data = { + "dominant_price": None, + "hedge_price": None, + "spread": None, + "z_score": None, + "hedge_ratio": None, + "position_dominant": Decimal("0"), + "position_hedge": Decimal("0"), + "active_orders_dominant": [], + "active_orders_hedge": [], + "pair_pnl": Decimal("0"), + "signal": 0 # 0: no signal, 1: long dominant/short hedge, -1: short dominant/long hedge + } + + # Setup candles config if not already set + if len(self.config.candles_config) == 0: + max_records = self.config.lookback_period + 20 # extra records for safety + self.max_records = max_records + self.config.candles_config = [ + CandlesConfig( + connector=self.config.connector_pair_dominant.connector_name, + trading_pair=self.config.connector_pair_dominant.trading_pair, + interval=self.config.interval, + max_records=max_records + ), + CandlesConfig( + connector=self.config.connector_pair_hedge.connector_name, + trading_pair=self.config.connector_pair_hedge.trading_pair, + interval=self.config.interval, + max_records=max_records + ) + ] + if "_perpetual" in self.config.connector_pair_dominant.connector_name: + connector = self.market_data_provider.get_connector(self.config.connector_pair_dominant.connector_name) + connector.set_position_mode(self.config.position_mode) + connector.set_leverage(self.config.connector_pair_dominant.trading_pair, self.config.leverage) + if "_perpetual" in self.config.connector_pair_hedge.connector_name: + connector = self.market_data_provider.get_connector(self.config.connector_pair_hedge.connector_name) + connector.set_position_mode(self.config.position_mode) + connector.set_leverage(self.config.connector_pair_hedge.trading_pair, self.config.leverage) + + def determine_executor_actions(self) -> List[ExecutorAction]: + """ + The execution logic for the statistical arbitrage strategy. + Market Data Conditions: Signal is generated based on the z-score of the spread between the two assets. + If signal == 1 --> long dominant/short hedge + If signal == -1 --> short dominant/long hedge + Execution Conditions: If the signal is generated add position executors to quote from the dominant and hedge markets. + We compare the current position with the theoretical position for the dominant and hedge assets. + If the current position + the active placed amount is greater than the theoretical position, can't place more orders. + If the imbalance scaled pct is greater than the threshold, we avoid placing orders in the market passed on filtered_connector_pair. + If the pnl of total position is greater than the take profit or lower than the stop loss, we close the position. 
+ """ + actions: List[ExecutorAction] = [] + # Check global take profit and stop loss + if self.processed_data["pair_pnl_pct"] > self.config.tp_global or self.processed_data["pair_pnl_pct"] < -self.config.sl_global: + # Close all positions + for position in self.positions_held: + actions.extend(self.get_executors_to_reduce_position(position)) + return actions + # Check the signal + elif self.processed_data["signal"] != 0: + actions.extend(self.get_executors_to_quote()) + actions.extend(self.get_executors_to_reduce_position_on_opposite_signal()) + + # Get the executors to keep position after a cooldown is reached + actions.extend(self.get_executors_to_keep_position()) + actions.extend(self.get_executors_to_refresh()) + + return actions + + def get_executors_to_reduce_position_on_opposite_signal(self) -> List[ExecutorAction]: + if self.processed_data["signal"] == 1: + dominant_side, hedge_side = TradeType.SELL, TradeType.BUY + elif self.processed_data["signal"] == -1: + dominant_side, hedge_side = TradeType.BUY, TradeType.SELL + else: + return [] + # Get executors to stop + dominant_active_executors_to_stop = self.filter_executors(self.executors_info, filter_func=lambda e: e.connector_name == self.config.connector_pair_dominant.connector_name and e.trading_pair == self.config.connector_pair_dominant.trading_pair and e.side == dominant_side) + hedge_active_executors_to_stop = self.filter_executors(self.executors_info, filter_func=lambda e: e.connector_name == self.config.connector_pair_hedge.connector_name and e.trading_pair == self.config.connector_pair_hedge.trading_pair and e.side == hedge_side) + stop_actions = [StopExecutorAction(controller_id=self.config.id, executor_id=executor.id, keep_position=False) for executor in dominant_active_executors_to_stop + hedge_active_executors_to_stop] + + # Get order executors to reduce positions + reduce_actions: List[ExecutorAction] = [] + for position in self.positions_held: + if position.connector_name == self.config.connector_pair_dominant.connector_name and position.trading_pair == self.config.connector_pair_dominant.trading_pair and position.side == dominant_side: + reduce_actions.extend(self.get_executors_to_reduce_position(position)) + elif position.connector_name == self.config.connector_pair_hedge.connector_name and position.trading_pair == self.config.connector_pair_hedge.trading_pair and position.side == hedge_side: + reduce_actions.extend(self.get_executors_to_reduce_position(position)) + return stop_actions + reduce_actions + + def get_executors_to_keep_position(self) -> List[ExecutorAction]: + stop_actions: List[ExecutorAction] = [] + for executor in self.processed_data["executors_dominant_filled"] + self.processed_data["executors_hedge_filled"]: + if self.market_data_provider.time() - executor.timestamp >= self.config.quoter_cooldown: + # Create a new executor to keep the position + stop_actions.append(StopExecutorAction(controller_id=self.config.id, executor_id=executor.id, keep_position=True)) + return stop_actions + + def get_executors_to_refresh(self) -> List[ExecutorAction]: + refresh_actions: List[ExecutorAction] = [] + for executor in self.processed_data["executors_dominant_placed"] + self.processed_data["executors_hedge_placed"]: + if self.market_data_provider.time() - executor.timestamp >= self.config.quoter_refresh: + # Create a new executor to refresh the position + refresh_actions.append(StopExecutorAction(controller_id=self.config.id, executor_id=executor.id, keep_position=False)) + return refresh_actions + + def 
+    def get_executors_to_quote(self) -> List[ExecutorAction]:
+        """
+        Create Position Executors that quote in the dominant and hedge markets.
+        """
+        actions: List[ExecutorAction] = []
+        trade_type_dominant = TradeType.BUY if self.processed_data["signal"] == 1 else TradeType.SELL
+        trade_type_hedge = TradeType.SELL if self.processed_data["signal"] == 1 else TradeType.BUY
+
+        # Analyze the dominant side's active orders, max deviation and imbalance before creating a new executor
+        if self.processed_data["dominant_gap"] > Decimal("0") and \
+                self.processed_data["filter_connector_pair"] != self.config.connector_pair_dominant and \
+                len(self.processed_data["executors_dominant_placed"]) < self.config.max_orders_placed_per_side and \
+                len(self.processed_data["executors_dominant_filled"]) < self.config.max_orders_filled_per_side:
+            # Create a Position Executor for the dominant asset
+            if trade_type_dominant == TradeType.BUY:
+                price = self.processed_data["min_price_dominant"] * (1 - self.config.quoter_spread)
+            else:
+                price = self.processed_data["max_price_dominant"] * (1 + self.config.quoter_spread)
+            dominant_executor_config = PositionExecutorConfig(
+                timestamp=self.market_data_provider.time(),
+                connector_name=self.config.connector_pair_dominant.connector_name,
+                trading_pair=self.config.connector_pair_dominant.trading_pair,
+                side=trade_type_dominant,
+                entry_price=price,
+                amount=self.config.min_amount_quote / self.processed_data["dominant_price"],
+                triple_barrier_config=self.config.triple_barrier_config,
+                leverage=self.config.leverage,
+            )
+            actions.append(CreateExecutorAction(controller_id=self.config.id, executor_config=dominant_executor_config))
+
+        # Analyze the hedge side's active orders, max deviation and imbalance before creating a new executor
+        if self.processed_data["hedge_gap"] > Decimal("0") and \
+                self.processed_data["filter_connector_pair"] != self.config.connector_pair_hedge and \
+                len(self.processed_data["executors_hedge_placed"]) < self.config.max_orders_placed_per_side and \
+                len(self.processed_data["executors_hedge_filled"]) < self.config.max_orders_filled_per_side:
+            # Create a Position Executor for the hedge asset
+            if trade_type_hedge == TradeType.BUY:
+                price = self.processed_data["min_price_hedge"] * (1 - self.config.quoter_spread)
+            else:
+                price = self.processed_data["max_price_hedge"] * (1 + self.config.quoter_spread)
+            hedge_executor_config = PositionExecutorConfig(
+                timestamp=self.market_data_provider.time(),
+                connector_name=self.config.connector_pair_hedge.connector_name,
+                trading_pair=self.config.connector_pair_hedge.trading_pair,
+                side=trade_type_hedge,
+                entry_price=price,
+                amount=self.config.min_amount_quote / self.processed_data["hedge_price"],
+                triple_barrier_config=self.config.triple_barrier_config,
+                leverage=self.config.leverage,
+            )
+            actions.append(CreateExecutorAction(controller_id=self.config.id, executor_config=hedge_executor_config))
+        return actions
+
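The entry prices above ladder away from the best already-placed quote by `quoter_spread`; for example:

```python
from decimal import Decimal

quoter_spread = Decimal("0.0001")
min_price_placed = Decimal("100.00")  # lowest buy quote currently resting

# The next buy quote is placed quoter_spread below the best placed price,
# so successive executors step away from the market instead of stacking.
next_buy_price = min_price_placed * (1 - quoter_spread)
print(next_buy_price)  # 99.990000
```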
+ """ + if position.amount > Decimal("0"): + # Close position + config = OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=position.connector_name, + trading_pair=position.trading_pair, + side=TradeType.BUY if position.side == TradeType.SELL else TradeType.SELL, + amount=position.amount, + position_action=PositionAction.CLOSE, + execution_strategy=ExecutionStrategy.MARKET, + leverage=self.config.leverage, + ) + return [CreateExecutorAction(controller_id=self.config.id, executor_config=config)] + return [] + + async def update_processed_data(self): + """ + Update processed data with the latest market information and statistical calculations + needed for the statistical arbitrage strategy. + """ + # Stat arb analysis + spread, z_score = self.get_spread_and_z_score() + + # Generate trading signal based on z-score + entry_threshold = float(self.config.entry_threshold) + if z_score > entry_threshold: + # Spread is too high, expect it to revert: long dominant, short hedge + signal = 1 + dominant_side, hedge_side = TradeType.BUY, TradeType.SELL + elif z_score < -entry_threshold: + # Spread is too low, expect it to revert: short dominant, long hedge + signal = -1 + dominant_side, hedge_side = TradeType.SELL, TradeType.BUY + else: + # No signal + signal = 0 + dominant_side, hedge_side = None, None + + # Current prices + dominant_price, hedge_price = self.get_pairs_prices() + + # Get current positions stats by signal + positions_dominant = next((position for position in self.positions_held if position.connector_name == self.config.connector_pair_dominant.connector_name and position.trading_pair == self.config.connector_pair_dominant.trading_pair and (position.side == dominant_side or dominant_side is None)), None) + positions_hedge = next((position for position in self.positions_held if position.connector_name == self.config.connector_pair_hedge.connector_name and position.trading_pair == self.config.connector_pair_hedge.trading_pair and (position.side == hedge_side or hedge_side is None)), None) + # Get position stats + position_dominant_quote = positions_dominant.amount_quote if positions_dominant else Decimal("0") + position_hedge_quote = positions_hedge.amount_quote if positions_hedge else Decimal("0") + position_dominant_pnl_quote = positions_dominant.global_pnl_quote if positions_dominant else Decimal("0") + position_hedge_pnl_quote = positions_hedge.global_pnl_quote if positions_hedge else Decimal("0") + pair_pnl_pct = (position_dominant_pnl_quote + position_hedge_pnl_quote) / (position_dominant_quote + position_hedge_quote) if (position_dominant_quote + position_hedge_quote) != 0 else Decimal("0") + # Get active executors + executors_dominant_placed, executors_dominant_filled = self.get_executors_dominant() + executors_hedge_placed, executors_hedge_filled = self.get_executors_hedge() + min_price_dominant = Decimal(str(min([executor.config.entry_price for executor in executors_dominant_placed]))) if executors_dominant_placed else None + max_price_dominant = Decimal(str(max([executor.config.entry_price for executor in executors_dominant_placed]))) if executors_dominant_placed else None + min_price_hedge = Decimal(str(min([executor.config.entry_price for executor in executors_hedge_placed]))) if executors_hedge_placed else None + max_price_hedge = Decimal(str(max([executor.config.entry_price for executor in executors_hedge_placed]))) if executors_hedge_placed else None + + active_amount_dominant = Decimal(str(sum([executor.filled_amount_quote for executor in 
+    async def update_processed_data(self):
+        """
+        Update the processed data with the latest market information and the statistical calculations
+        needed for the statistical arbitrage strategy.
+        """
+        # Stat arb analysis
+        spread, z_score = self.get_spread_and_z_score()
+
+        # Generate the trading signal based on the z-score
+        entry_threshold = float(self.config.entry_threshold)
+        if z_score > entry_threshold:
+            # Spread is too high, expect it to revert: long dominant, short hedge
+            signal = 1
+            dominant_side, hedge_side = TradeType.BUY, TradeType.SELL
+        elif z_score < -entry_threshold:
+            # Spread is too low, expect it to revert: short dominant, long hedge
+            signal = -1
+            dominant_side, hedge_side = TradeType.SELL, TradeType.BUY
+        else:
+            # No signal
+            signal = 0
+            dominant_side, hedge_side = None, None
+
+        # Current prices
+        dominant_price, hedge_price = self.get_pairs_prices()
+
+        # Get the current position stats filtered by the signal side
+        positions_dominant = next((position for position in self.positions_held
+                                   if position.connector_name == self.config.connector_pair_dominant.connector_name
+                                   and position.trading_pair == self.config.connector_pair_dominant.trading_pair
+                                   and (position.side == dominant_side or dominant_side is None)), None)
+        positions_hedge = next((position for position in self.positions_held
+                                if position.connector_name == self.config.connector_pair_hedge.connector_name
+                                and position.trading_pair == self.config.connector_pair_hedge.trading_pair
+                                and (position.side == hedge_side or hedge_side is None)), None)
+        # Get the position stats
+        position_dominant_quote = positions_dominant.amount_quote if positions_dominant else Decimal("0")
+        position_hedge_quote = positions_hedge.amount_quote if positions_hedge else Decimal("0")
+        position_dominant_pnl_quote = positions_dominant.global_pnl_quote if positions_dominant else Decimal("0")
+        position_hedge_pnl_quote = positions_hedge.global_pnl_quote if positions_hedge else Decimal("0")
+        pair_pnl_pct = (position_dominant_pnl_quote + position_hedge_pnl_quote) / (position_dominant_quote + position_hedge_quote) if (position_dominant_quote + position_hedge_quote) != 0 else Decimal("0")
+        # Get the active executors
+        executors_dominant_placed, executors_dominant_filled = self.get_executors_dominant()
+        executors_hedge_placed, executors_hedge_filled = self.get_executors_hedge()
+        min_price_dominant = Decimal(str(min([executor.config.entry_price for executor in executors_dominant_placed]))) if executors_dominant_placed else None
+        max_price_dominant = Decimal(str(max([executor.config.entry_price for executor in executors_dominant_placed]))) if executors_dominant_placed else None
+        min_price_hedge = Decimal(str(min([executor.config.entry_price for executor in executors_hedge_placed]))) if executors_hedge_placed else None
+        max_price_hedge = Decimal(str(max([executor.config.entry_price for executor in executors_hedge_placed]))) if executors_hedge_placed else None
+
+        active_amount_dominant = Decimal(str(sum([executor.filled_amount_quote for executor in executors_dominant_filled])))
+        active_amount_hedge = Decimal(str(sum([executor.filled_amount_quote for executor in executors_hedge_filled])))
+
+        # Compute the imbalance based on the hedge ratio
+        dominant_gap = self.theoretical_dominant_quote - position_dominant_quote - active_amount_dominant
+        hedge_gap = self.theoretical_hedge_quote - position_hedge_quote - active_amount_hedge
+        imbalance = position_dominant_quote - position_hedge_quote
+        imbalance_scaled = position_dominant_quote - position_hedge_quote * self.config.pos_hedge_ratio
+        imbalance_scaled_pct = imbalance_scaled / position_dominant_quote if position_dominant_quote != Decimal("0") else Decimal("0")
+        filter_connector_pair = None
+        if imbalance_scaled_pct > self.config.max_position_deviation:
+            # Avoid placing orders in the dominant market
+            filter_connector_pair = self.config.connector_pair_dominant
+        elif imbalance_scaled_pct < -self.config.max_position_deviation:
+            # Avoid placing orders in the hedge market
+            filter_connector_pair = self.config.connector_pair_hedge
+
+        # Update the processed data
+        self.processed_data.update({
+            "dominant_price": Decimal(str(dominant_price)),
+            "hedge_price": Decimal(str(hedge_price)),
+            "spread": Decimal(str(spread)),
+            "z_score": Decimal(str(z_score)),
+            "dominant_gap": Decimal(str(dominant_gap)),
+            "hedge_gap": Decimal(str(hedge_gap)),
+            "position_dominant_quote": position_dominant_quote,
+            "position_hedge_quote": position_hedge_quote,
+            "active_amount_dominant": active_amount_dominant,
+            "active_amount_hedge": active_amount_hedge,
+            "signal": signal,
+            "imbalance": Decimal(str(imbalance)),
+            "imbalance_scaled_pct": Decimal(str(imbalance_scaled_pct)),
+            "filter_connector_pair": filter_connector_pair,
+            "min_price_dominant": min_price_dominant if min_price_dominant is not None else Decimal(str(dominant_price)),
+            "max_price_dominant": max_price_dominant if max_price_dominant is not None else Decimal(str(dominant_price)),
+            "min_price_hedge": min_price_hedge if min_price_hedge is not None else Decimal(str(hedge_price)),
+            "max_price_hedge": max_price_hedge if max_price_hedge is not None else Decimal(str(hedge_price)),
+            "executors_dominant_filled": executors_dominant_filled,
+            "executors_hedge_filled": executors_hedge_filled,
+            "executors_dominant_placed": executors_dominant_placed,
+            "executors_hedge_placed": executors_hedge_placed,
+            "pair_pnl_pct": pair_pnl_pct,
+        })
+
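A quick check of the imbalance gating computed above (made-up quote positions):

```python
from decimal import Decimal

pos_hedge_ratio = Decimal("1.0")
max_position_deviation = Decimal("0.1")

position_dominant_quote = Decimal("500")
position_hedge_quote = Decimal("420")

imbalance_scaled = position_dominant_quote - position_hedge_quote * pos_hedge_ratio  # 80
imbalance_scaled_pct = imbalance_scaled / position_dominant_quote                    # 0.16

# 0.16 > 0.1, so the dominant market is flagged in filter_connector_pair and
# new dominant-side quotes are skipped until the hedge leg catches up.
print(imbalance_scaled_pct > max_position_deviation)  # True
```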
+    def get_spread_and_z_score(self):
+        # Fetch the candle data for both assets
+        dominant_df = self.market_data_provider.get_candles_df(
+            connector_name=self.config.connector_pair_dominant.connector_name,
+            trading_pair=self.config.connector_pair_dominant.trading_pair,
+            interval=self.config.interval,
+            max_records=self.max_records
+        )
+
+        hedge_df = self.market_data_provider.get_candles_df(
+            connector_name=self.config.connector_pair_hedge.connector_name,
+            trading_pair=self.config.connector_pair_hedge.trading_pair,
+            interval=self.config.interval,
+            max_records=self.max_records
+        )
+
+        if dominant_df.empty or hedge_df.empty:
+            self.logger().warning("Not enough candle data available for statistical analysis")
+            return 0.0, 0.0  # Neutral spread and z-score so no signal is generated
+
+        # Extract the close prices
+        dominant_prices = dominant_df['close'].values
+        hedge_prices = hedge_df['close'].values
+
+        # Ensure we have enough data and that both series have the same length
+        min_length = min(len(dominant_prices), len(hedge_prices))
+        if min_length < self.config.lookback_period:
+            self.logger().warning(
+                f"Not enough data points for analysis. Required: {self.config.lookback_period}, Available: {min_length}")
+            return 0.0, 0.0  # Neutral spread and z-score so no signal is generated
+
+        # Use the most recent data points
+        dominant_prices = dominant_prices[-self.config.lookback_period:]
+        hedge_prices = hedge_prices[-self.config.lookback_period:]
+
+        # Convert to numpy arrays
+        dominant_prices_np = np.array(dominant_prices, dtype=float)
+        hedge_prices_np = np.array(hedge_prices, dtype=float)
+
+        # Calculate the percentage returns
+        dominant_pct_change = np.diff(dominant_prices_np) / dominant_prices_np[:-1]
+        hedge_pct_change = np.diff(hedge_prices_np) / hedge_prices_np[:-1]
+
+        # Convert to cumulative returns
+        dominant_cum_returns = np.cumprod(dominant_pct_change + 1)
+        hedge_cum_returns = np.cumprod(hedge_pct_change + 1)
+
+        # Normalize to start at 1
+        dominant_cum_returns = dominant_cum_returns / dominant_cum_returns[0] if len(dominant_cum_returns) > 0 else np.array([1.0])
+        hedge_cum_returns = hedge_cum_returns / hedge_cum_returns[0] if len(hedge_cum_returns) > 0 else np.array([1.0])
+
+        # Perform a linear regression of the hedge returns on the dominant returns
+        dominant_cum_returns_reshaped = dominant_cum_returns.reshape(-1, 1)
+        reg = LinearRegression().fit(dominant_cum_returns_reshaped, hedge_cum_returns)
+        alpha = reg.intercept_
+        beta = reg.coef_[0]
+        self.processed_data.update({
+            "alpha": alpha,
+            "beta": beta,
+        })
+
+        # Calculate the spread as the percentage difference from the predicted value
+        y_pred = alpha + beta * dominant_cum_returns
+        spread_pct = (hedge_cum_returns - y_pred) / y_pred * 100
+
+        # Calculate the z-score
+        mean_spread = np.mean(spread_pct)
+        std_spread = np.std(spread_pct)
+        if std_spread == 0:
+            self.logger().warning("Standard deviation of spread is zero, cannot calculate z-score")
+            return 0.0, 0.0  # Neutral spread and z-score so no signal is generated
+
+        current_spread = spread_pct[-1]
+        current_z_score = (current_spread - mean_spread) / std_spread
+
+        return current_spread, current_z_score
+
+    def get_pairs_prices(self):
+        current_dominant_price = self.market_data_provider.get_price_by_type(
+            connector_name=self.config.connector_pair_dominant.connector_name,
+            trading_pair=self.config.connector_pair_dominant.trading_pair, price_type=PriceType.MidPrice)
+
+        current_hedge_price = self.market_data_provider.get_price_by_type(
+            connector_name=self.config.connector_pair_hedge.connector_name,
+            trading_pair=self.config.connector_pair_hedge.trading_pair, price_type=PriceType.MidPrice)
+        return current_dominant_price, current_hedge_price
+
+    def get_executors_dominant(self):
+        active_executors_dominant_placed = self.filter_executors(
+            self.executors_info,
+            filter_func=lambda e: e.connector_name == self.config.connector_pair_dominant.connector_name and e.trading_pair == self.config.connector_pair_dominant.trading_pair and e.is_active and not e.is_trading and e.type == "position_executor"
+        )
+        active_executors_dominant_filled = self.filter_executors(
+            self.executors_info,
+            filter_func=lambda e: e.connector_name == self.config.connector_pair_dominant.connector_name and e.trading_pair == self.config.connector_pair_dominant.trading_pair and e.is_active and e.is_trading and e.type == "position_executor"
+        )
+        return active_executors_dominant_placed, active_executors_dominant_filled
+
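The regression and z-score computation in `get_spread_and_z_score` above can be reproduced in isolation; a minimal sketch on synthetic, roughly cointegrated series (same numpy/sklearn calls as the controller):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(42)
dominant = 100 * np.cumprod(1 + rng.normal(0, 0.001, 300))
hedge = 2 * dominant * (1 + rng.normal(0, 0.0005, 300))  # scaled + noisy copy

# Cumulative returns normalized to start at 1, as in the controller
dom_cum = np.cumprod(np.diff(dominant) / dominant[:-1] + 1)
hdg_cum = np.cumprod(np.diff(hedge) / hedge[:-1] + 1)
dom_cum, hdg_cum = dom_cum / dom_cum[0], hdg_cum / hdg_cum[0]

reg = LinearRegression().fit(dom_cum.reshape(-1, 1), hdg_cum)
y_pred = reg.intercept_ + reg.coef_[0] * dom_cum
spread_pct = (hdg_cum - y_pred) / y_pred * 100
z_score = (spread_pct[-1] - spread_pct.mean()) / spread_pct.std()
print(z_score)  # the last point's deviation from the mean spread, in std units
```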
+    def get_executors_hedge(self):
+        active_executors_hedge_placed = self.filter_executors(
+            self.executors_info,
+            filter_func=lambda e: e.connector_name == self.config.connector_pair_hedge.connector_name and e.trading_pair == self.config.connector_pair_hedge.trading_pair and e.is_active and not e.is_trading and e.type == "position_executor"
+        )
+        active_executors_hedge_filled = self.filter_executors(
+            self.executors_info,
+            filter_func=lambda e: e.connector_name == self.config.connector_pair_hedge.connector_name and e.trading_pair == self.config.connector_pair_hedge.trading_pair and e.is_active and e.is_trading and e.type == "position_executor"
+        )
+        return active_executors_hedge_placed, active_executors_hedge_filled
+
+    def to_format_status(self) -> List[str]:
+        """
+        Format the status of the controller for display.
+        """
+        status_lines = []
+        status_lines.append(f"""
+Dominant Pair: {self.config.connector_pair_dominant} | Hedge Pair: {self.config.connector_pair_hedge} |
+Timeframe: {self.config.interval} | Lookback Period: {self.config.lookback_period} | Entry Threshold: {self.config.entry_threshold}
+
+Position targets:
+Theoretical Dominant: {self.theoretical_dominant_quote} | Theoretical Hedge: {self.theoretical_hedge_quote} | Position Hedge Ratio: {self.config.pos_hedge_ratio}
+Position Dominant: {self.processed_data['position_dominant_quote']:.2f} | Position Hedge: {self.processed_data['position_hedge_quote']:.2f} | Imbalance: {self.processed_data['imbalance']:.2f} | Imbalance Scaled: {self.processed_data['imbalance_scaled_pct']:.2f} %
+
+Current Executors:
+Active Orders Dominant: {len(self.processed_data['executors_dominant_placed'])} | Active Orders Hedge: {len(self.processed_data['executors_hedge_placed'])} |
+Active Orders Dominant Filled: {len(self.processed_data['executors_dominant_filled'])} | Active Orders Hedge Filled: {len(self.processed_data['executors_hedge_filled'])}
+
+Signal: {self.processed_data['signal']:.2f} | Z-Score: {self.processed_data['z_score']:.2f} | Spread: {self.processed_data['spread']:.2f}
+Alpha: {self.processed_data['alpha']:.2f} | Beta: {self.processed_data['beta']:.2f}
+Pair PnL PCT: {self.processed_data['pair_pnl_pct'] * 100:.2f} %
+""")
+        return status_lines
diff --git a/bots/controllers/market_making/dman_maker_v2.py b/bots/controllers/market_making/dman_maker_v2.py
index 6cba442e..2002fddd 100644
--- a/bots/controllers/market_making/dman_maker_v2.py
+++ b/bots/controllers/market_making/dman_maker_v2.py
@@ -2,6 +2,8 @@
 from typing import List, Optional
 
 import pandas_ta as ta  # noqa: F401
+from pydantic import Field, field_validator
+
 from hummingbot.core.data_type.common import TradeType
 from hummingbot.data_feed.candles_feed.data_types import CandlesConfig
 from hummingbot.strategy_v2.controllers.market_making_controller_base import (
@@ -10,7 +12,6 @@
 )
 from hummingbot.strategy_v2.executors.dca_executor.data_types import DCAExecutorConfig, DCAMode
 from hummingbot.strategy_v2.models.executor_actions import ExecutorAction, StopExecutorAction
-from pydantic import Field, field_validator
 
 
 class DManMakerV2Config(MarketMakingControllerConfigBase):
diff --git a/bots/credentials/master_account/.password_verification b/bots/credentials/master_account/.password_verification
deleted file mode 100644
index b8c76184..00000000
--- a/bots/credentials/master_account/.password_verification
+++ /dev/null
@@ -1 +0,0 @@
-7b2263727970746f223a207b22636970686572223a20226165732d3132382d637472222c2022636970686572706172616d73223a207b226976223a20223864336365306436393461623131396334363135663935366464653839363063227d2c202263697068657274657874223a20223836333266323430613563306131623665353664222c20226b6466223a202270626b646632222c20226b6466706172616d73223a207b2263223a20313030303030302c2022646b6c656e223a2033322c2022707266223a2022686d61632d736861323536222c202273616c74223a20226566373330376531636464373964376132303338323534656139343433663930227d2c20226d6163223a202266393439383534613530633138363633386363353962336133363665633962353333386633613964373266636635343066313034333361353431636232306438227d2c202276657273696f6e223a20337d \ No newline at end of file diff --git a/bots/credentials/master_account/conf_client.yml b/bots/credentials/master_account/conf_client.yml index 2ad547af..7093ae65 100644 --- a/bots/credentials/master_account/conf_client.yml +++ b/bots/credentials/master_account/conf_client.yml @@ -42,9 +42,6 @@ mqtt_bridge: # Error log sharing send_error_logs: true -# Can store the previous strategy ran for quick retrieval. -previous_strategy: some-strategy.yml - # Advanced database options, currently supports SQLAlchemy's included dialects # Reference: https://docs.sqlalchemy.org/en/13/dialects/ # To use an instance of SQLite DB the required configuration is @@ -114,19 +111,6 @@ certs_path: /Users/dardonacci/Documents/work/hummingbot/certs anonymized_metrics_mode: anonymized_metrics_interval_min: 15.0 -# Command Shortcuts -# Define abbreviations for often used commands -# or batch grouped commands together -command_shortcuts: -- command: spreads - help: Set bid and ask spread - arguments: - - Bid Spread - - Ask Spread - output: - - config bid_spread $1 - - config ask_spread $2 - # A source for rate oracle, currently ascend_ex, binance, coin_gecko, coin_cap, kucoin, gate_io rate_oracle_source: name: binance diff --git a/bots/scripts/v2_with_controllers.py b/bots/scripts/v2_with_controllers.py index c62d585c..0dd0d8ce 100644 --- a/bots/scripts/v2_with_controllers.py +++ b/bots/scripts/v2_with_controllers.py @@ -1,33 +1,26 @@ import os -import time from decimal import Decimal from typing import Dict, List, Optional, Set from hummingbot.client.hummingbot_application import HummingbotApplication from hummingbot.connector.connector_base import ConnectorBase -from hummingbot.core.clock import Clock -from hummingbot.core.data_type.common import OrderType, TradeType + +from hummingbot.core.event.events import MarketOrderFailureEvent from hummingbot.data_feed.candles_feed.data_types import CandlesConfig -from hummingbot.remote_iface.mqtt import ETopicPublisher from hummingbot.strategy.strategy_v2_base import StrategyV2Base, StrategyV2ConfigBase from hummingbot.strategy_v2.models.base import RunnableStatus from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, StopExecutorAction -class GenericV2StrategyWithCashOutConfig(StrategyV2ConfigBase): +class V2WithControllersConfig(StrategyV2ConfigBase): script_file_name: str = os.path.basename(__file__) candles_config: List[CandlesConfig] = [] markets: Dict[str, Set[str]] = {} - time_to_cash_out: Optional[int] = None - max_global_drawdown: Optional[float] = None - max_controller_drawdown: Optional[float] = None - rebalance_interval: Optional[int] = None - extra_inventory: Optional[float] = 0.02 - min_amount_to_rebalance_usd: Decimal = Decimal("8") - asset_to_rebalance: str = "USDT" + max_global_drawdown_quote: Optional[float] = None + 
max_controller_drawdown_quote: Optional[float] = None
 
 
-class GenericV2StrategyWithCashOut(StrategyV2Base):
+class V2WithControllers(StrategyV2Base):
     """
-    This script runs a generic strategy with cash out feature. Will also check if the controllers configs have been
-    updated and apply the new settings.
+    This script runs a generic strategy with drawdown protection and a manual kill switch. It will also check if the
+    controller configs have been updated and apply the new settings.
@@ -40,131 +33,43 @@ class GenericV2StrategyWithCashOut(StrategyV2Base):
     """
     performance_report_interval: int = 1
 
-    def __init__(self, connectors: Dict[str, ConnectorBase], config: GenericV2StrategyWithCashOutConfig):
+    def __init__(self, connectors: Dict[str, ConnectorBase], config: V2WithControllersConfig):
         super().__init__(connectors, config)
         self.config = config
-        self.cashing_out = False
         self.max_pnl_by_controller = {}
-        self.performance_reports = {}
         self.max_global_pnl = Decimal("0")
         self.drawdown_exited_controllers = []
         self.closed_executors_buffer: int = 30
-        self.rebalance_interval: int = self.config.rebalance_interval
         self._last_performance_report_timestamp = 0
-        self._last_rebalance_check_timestamp = 0
-        hb_app = HummingbotApplication.main_application()
-        self.mqtt_enabled = hb_app._mqtt is not None
-        self._pub: Optional[ETopicPublisher] = None
-        if self.config.time_to_cash_out:
-            self.cash_out_time = self.config.time_to_cash_out + time.time()
-        else:
-            self.cash_out_time = None
-
-    def start(self, clock: Clock, timestamp: float) -> None:
-        """
-        Start the strategy.
-        :param clock: Clock to use.
-        :param timestamp: Current time.
-        """
-        self._last_timestamp = timestamp
-        self.apply_initial_setting()
-        if self.mqtt_enabled:
-            self._pub = ETopicPublisher("performance", use_bot_prefix=True)
-
-    async def on_stop(self):
-        await super().on_stop()
-        if self.mqtt_enabled:
-            self._pub({controller_id: {} for controller_id in self.controllers.keys()})
-            self._pub = None
 
     def on_tick(self):
         super().on_tick()
-        self.performance_reports = {controller_id: self.executor_orchestrator.generate_performance_report(controller_id=controller_id).dict() for controller_id in self.controllers.keys()}
-        self.control_rebalance()
-        self.control_cash_out()
-        self.control_max_drawdown()
-        self.send_performance_report()
-
-    def control_rebalance(self):
-        if self.rebalance_interval and self._last_rebalance_check_timestamp + self.rebalance_interval <= self.current_timestamp:
-            balance_required = {}
-            for controller_id, controller in self.controllers.items():
-                connector_name = controller.config.dict().get("connector_name")
-                if connector_name and "perpetual" in connector_name:
-                    continue
-                if connector_name not in balance_required:
-                    balance_required[connector_name] = {}
-                tokens_required = controller.get_balance_requirements()
-                for token, amount in tokens_required:
-                    if token not in balance_required[connector_name]:
-                        balance_required[connector_name][token] = amount
-                    else:
-                        balance_required[connector_name][token] += amount
-            for connector_name, balance_requirements in balance_required.items():
-                connector = self.connectors[connector_name]
-                for token, amount in balance_requirements.items():
-                    if token == self.config.asset_to_rebalance:
-                        continue
-                    balance = connector.get_balance(token)
-                    trading_pair = f"{token}-{self.config.asset_to_rebalance}"
-                    mid_price = connector.get_mid_price(trading_pair)
-                    trading_rule = connector.trading_rules[trading_pair]
-                    amount_with_safe_margin = amount * (1 + Decimal(self.config.extra_inventory))
-                    active_executors_for_pair = self.filter_executors(
-                        executors=self.get_all_executors(),
-                        filter_func=lambda x: x.is_active and x.trading_pair == trading_pair and x.connector_name == 
connector_name - ) - unmatched_amount = sum([executor.filled_amount_quote for executor in active_executors_for_pair if executor.side == TradeType.SELL]) - sum([executor.filled_amount_quote for executor in active_executors_for_pair if executor.side == TradeType.BUY]) - balance += unmatched_amount / mid_price - base_balance_diff = balance - amount_with_safe_margin - abs_balance_diff = abs(base_balance_diff) - trading_rules_condition = abs_balance_diff > trading_rule.min_order_size and abs_balance_diff * mid_price > trading_rule.min_notional_size and abs_balance_diff * mid_price > self.config.min_amount_to_rebalance_usd - order_type = OrderType.MARKET - if base_balance_diff > 0: - if trading_rules_condition: - self.logger().info(f"Rebalance: Selling {amount_with_safe_margin} {token} to {self.config.asset_to_rebalance}. Balance: {balance} | Executors unmatched balance {unmatched_amount / mid_price}") - connector.sell( - trading_pair=trading_pair, - amount=abs_balance_diff, - order_type=order_type, - price=mid_price) - else: - self.logger().info("Skipping rebalance due a low amount to sell that may cause future imbalance") - else: - if not trading_rules_condition: - amount = max([self.config.min_amount_to_rebalance_usd / mid_price, trading_rule.min_order_size, trading_rule.min_notional_size / mid_price]) - self.logger().info(f"Rebalance: Buying for a higher value to avoid future imbalance {amount} {token} to {self.config.asset_to_rebalance}. Balance: {balance} | Executors unmatched balance {unmatched_amount}") - else: - amount = abs_balance_diff - self.logger().info(f"Rebalance: Buying {amount} {token} to {self.config.asset_to_rebalance}. Balance: {balance} | Executors unmatched balance {unmatched_amount}") - connector.buy( - trading_pair=trading_pair, - amount=amount, - order_type=order_type, - price=mid_price) - self._last_rebalance_check_timestamp = self.current_timestamp + if not self._is_stop_triggered: + self.check_manual_kill_switch() + self.control_max_drawdown() + self.send_performance_report() def control_max_drawdown(self): - if self.config.max_controller_drawdown: + if self.config.max_controller_drawdown_quote: self.check_max_controller_drawdown() - if self.config.max_global_drawdown: + if self.config.max_global_drawdown_quote: self.check_max_global_drawdown() def check_max_controller_drawdown(self): for controller_id, controller in self.controllers.items(): if controller.status != RunnableStatus.RUNNING: continue - controller_pnl = self.performance_reports[controller_id]["global_pnl_quote"] + controller_pnl = self.get_performance_report(controller_id).global_pnl_quote last_max_pnl = self.max_pnl_by_controller[controller_id] if controller_pnl > last_max_pnl: self.max_pnl_by_controller[controller_id] = controller_pnl else: current_drawdown = last_max_pnl - controller_pnl - if current_drawdown > self.config.max_controller_drawdown: + if current_drawdown > self.config.max_controller_drawdown_quote: self.logger().info(f"Controller {controller_id} reached max drawdown. 
Stopping the controller.") controller.stop() executors_order_placed = self.filter_executors( - executors=self.executors_info[controller_id], + executors=self.get_executors_by_controller(controller_id), filter_func=lambda x: x.is_active and not x.is_trading, ) self.executor_orchestrator.execute_actions( @@ -173,38 +78,24 @@ def check_max_controller_drawdown(self): self.drawdown_exited_controllers.append(controller_id) def check_max_global_drawdown(self): - current_global_pnl = sum([report["global_pnl_quote"] for report in self.performance_reports.values()]) + current_global_pnl = sum([self.get_performance_report(controller_id).global_pnl_quote for controller_id in self.controllers.keys()]) if current_global_pnl > self.max_global_pnl: self.max_global_pnl = current_global_pnl else: current_global_drawdown = self.max_global_pnl - current_global_pnl - if current_global_drawdown > self.config.max_global_drawdown: + if current_global_drawdown > self.config.max_global_drawdown_quote: self.drawdown_exited_controllers.extend(list(self.controllers.keys())) self.logger().info("Global drawdown reached. Stopping the strategy.") + self._is_stop_triggered = True HummingbotApplication.main_application().stop() def send_performance_report(self): - if self.current_timestamp - self._last_performance_report_timestamp >= self.performance_report_interval and self.mqtt_enabled: - self._pub(self.performance_reports) + if self.current_timestamp - self._last_performance_report_timestamp >= self.performance_report_interval and self._pub: + performance_reports = {controller_id: self.get_performance_report(controller_id).dict() for controller_id in self.controllers.keys()} + self._pub(performance_reports) self._last_performance_report_timestamp = self.current_timestamp - def control_cash_out(self): - self.evaluate_cash_out_time() - if self.cashing_out: - self.check_executors_status() - else: - self.check_manual_cash_out() - - def evaluate_cash_out_time(self): - if self.cash_out_time and self.current_timestamp >= self.cash_out_time and not self.cashing_out: - self.logger().info("Cash out time reached. Stopping the controllers.") - for controller_id, controller in self.controllers.items(): - if controller.status == RunnableStatus.RUNNING: - self.logger().info(f"Cash out for controller {controller_id}.") - controller.stop() - self.cashing_out = True - - def check_manual_cash_out(self): + def check_manual_kill_switch(self): for controller_id, controller in self.controllers.items(): if controller.config.manual_kill_switch and controller.status == RunnableStatus.RUNNING: self.logger().info(f"Manual cash out for controller {controller_id}.") @@ -246,7 +137,7 @@ def apply_initial_setting(self): connectors_position_mode = {} for controller_id, controller in self.controllers.items(): self.max_pnl_by_controller[controller_id] = Decimal("0") - config_dict = controller.config.dict() + config_dict = controller.config.model_dump() if "connector_name" in config_dict: if self.is_perpetual(config_dict["connector_name"]): if "position_mode" in config_dict: @@ -256,3 +147,18 @@ def apply_initial_setting(self): trading_pair=config_dict["trading_pair"]) for connector_name, position_mode in connectors_position_mode.items(): self.connectors[connector_name].set_position_mode(position_mode) + + def did_fail_order(self, order_failed_event: MarketOrderFailureEvent): + """ + Handle order failure events by logging the error and stopping the strategy if necessary. 
+ """ + if "position side" in order_failed_event.error_message.lower(): + connectors_position_mode = {} + for controller_id, controller in self.controllers.items(): + config_dict = controller.config.model_dump() + if "connector_name" in config_dict: + if self.is_perpetual(config_dict["connector_name"]): + if "position_mode" in config_dict: + connectors_position_mode[config_dict["connector_name"]] = config_dict["position_mode"] + for connector_name, position_mode in connectors_position_mode.items(): + self.connectors[connector_name].set_position_mode(position_mode) diff --git a/config.py b/config.py index f37dae8f..97abdd45 100644 --- a/config.py +++ b/config.py @@ -1,15 +1,120 @@ -import os +from typing import List +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict -from dotenv import load_dotenv -load_dotenv() +class BrokerSettings(BaseSettings): + """MQTT Broker configuration for bot communication.""" + + host: str = Field(default="localhost", description="MQTT broker host") + port: int = Field(default=1883, description="MQTT broker port") + username: str = Field(default="admin", description="MQTT broker username") + password: str = Field(default="password", description="MQTT broker password") -CONTROLLERS_PATH = "bots/conf/controllers" -CONTROLLERS_MODULE = "bots.controllers" -CONFIG_PASSWORD = os.getenv("CONFIG_PASSWORD", "a") -BROKER_HOST = os.getenv("BROKER_HOST", "localhost") -BROKER_PORT = int(os.getenv("BROKER_PORT", 1883)) -BROKER_USERNAME = os.getenv("BROKER_USERNAME", "admin") -BROKER_PASSWORD = os.getenv("BROKER_PASSWORD", "password") -PASSWORD_VERIFICATION_PATH = "bots/credentials/master_account/.password_verification" -BANNED_TOKENS = os.getenv("BANNED_TOKENS", "NAV,ARS,ETHW,ETHF").split(",") \ No newline at end of file + model_config = SettingsConfigDict(env_prefix="BROKER_", extra="ignore") + + +class DatabaseSettings(BaseSettings): + """Database configuration.""" + + url: str = Field( + default="postgresql+asyncpg://hbot:hummingbot-api@localhost:5432/hummingbot_api", + description="Database connection URL" + ) + + model_config = SettingsConfigDict(env_prefix="DATABASE_", extra="ignore") + + +class MarketDataSettings(BaseSettings): + """Market data feed manager configuration.""" + + cleanup_interval: int = Field( + default=300, + description="How often to run feed cleanup in seconds" + ) + feed_timeout: int = Field( + default=600, + description="How long to keep unused feeds alive in seconds" + ) + + model_config = SettingsConfigDict(env_prefix="MARKET_DATA_", extra="ignore") + + +class SecuritySettings(BaseSettings): + """Security and authentication configuration.""" + + username: str = Field(default="admin", description="API basic auth username") + password: str = Field(default="admin", description="API basic auth password") + debug_mode: bool = Field(default=False, description="Enable debug mode (disables auth)") + config_password: str = Field(default="a", description="Bot configuration encryption password") + + model_config = SettingsConfigDict( + env_prefix="", + extra="ignore" # Ignore extra environment variables + ) + + +class AWSSettings(BaseSettings): + """AWS configuration for S3 archiving.""" + + api_key: str = Field(default="", description="AWS API key") + secret_key: str = Field(default="", description="AWS secret key") + s3_default_bucket_name: str = Field(default="", description="Default S3 bucket for archiving") + + model_config = SettingsConfigDict(env_prefix="AWS_", extra="ignore") + + +class 
AppSettings(BaseSettings):
+    """Main application settings."""
+
+    # Static paths
+    controllers_path: str = "bots/conf/controllers"
+    controllers_module: str = "bots.controllers"
+    password_verification_path: str = "credentials/master_account/.password_verification"
+
+    # Environment-configurable settings
+    logfire_environment: str = Field(
+        default="dev",
+        description="Logfire environment name"
+    )
+
+    # Account state update interval
+    account_update_interval: int = Field(
+        default=5,
+        description="How often to update account states in minutes"
+    )
+
+    model_config = SettingsConfigDict(
+        env_file=".env",
+        env_file_encoding="utf-8",
+        case_sensitive=False,
+        extra="ignore"
+    )
+
+
+class Settings(BaseSettings):
+    """Combined application settings."""
+
+    broker: BrokerSettings = Field(default_factory=BrokerSettings)
+    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
+    market_data: MarketDataSettings = Field(default_factory=MarketDataSettings)
+    security: SecuritySettings = Field(default_factory=SecuritySettings)
+    aws: AWSSettings = Field(default_factory=AWSSettings)
+    app: AppSettings = Field(default_factory=AppSettings)
+
+    # Top-level banned_tokens field so it can be parsed directly from the environment
+    banned_tokens: List[str] = Field(
+        default=["NAV", "ARS", "ETHW", "ETHF"],
+        description="List of banned trading tokens"
+    )
+
+    model_config = SettingsConfigDict(
+        env_file=".env",
+        env_file_encoding="utf-8",
+        env_prefix="",
+        extra="ignore"
+    )
+
+
+# Create the global settings instance
+settings = Settings()
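Because each sub-model declares its own `env_prefix`, the nested settings can be driven entirely from environment variables or the `.env` file; a small usage sketch (values illustrative):

```python
# Sketch: pydantic-settings resolves prefixed variables per sub-model,
# e.g. BROKER_HOST for BrokerSettings and DATABASE_URL for DatabaseSettings.
import os

os.environ["BROKER_HOST"] = "emqx"  # e.g. a docker-compose service name
os.environ["DATABASE_URL"] = "postgresql+asyncpg://hbot:hummingbot-api@db:5432/hummingbot_api"

from config import Settings

settings = Settings()
print(settings.broker.host)   # emqx
print(settings.database.url)  # the overridden asyncpg URL
```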
"server_settings": {"application_name": "hummingbot-api"}, + "command_timeout": 60, + } + ) + self.async_session = async_sessionmaker( + self.engine, + class_=AsyncSession, + expire_on_commit=False + ) + + async def create_tables(self): + """Create all tables defined in the models.""" + try: + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Drop Hummingbot's native tables since we use our custom orders/trades tables + await self._drop_hummingbot_tables(conn) + + logger.info("Database tables created successfully") + except Exception as e: + logger.error(f"Failed to create database tables: {e}") + raise + + async def _drop_hummingbot_tables(self, conn): + """Drop Hummingbot's native database tables since we use custom ones.""" + hummingbot_tables = [ + "hummingbot_orders", + "hummingbot_trade_fills", + "hummingbot_order_status" + ] + + for table_name in hummingbot_tables: + try: + await conn.execute(text(f"DROP TABLE IF EXISTS {table_name}")) + logger.info(f"Dropped Hummingbot table: {table_name}") + except Exception as e: + logger.debug(f"Could not drop table {table_name}: {e}") # Use debug since table might not exist + + async def close(self): + """Close all database connections.""" + await self.engine.dispose() + logger.info("Database connections closed") + + def get_session(self) -> AsyncSession: + """Get a new database session.""" + return self.async_session() + + @asynccontextmanager + async def get_session_context(self) -> AsyncGenerator[AsyncSession, None]: + """ + Get a database session with automatic error handling and cleanup. + + Usage: + async with db_manager.get_session_context() as session: + # Use session here + """ + async with self.async_session() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + + async def health_check(self) -> bool: + """ + Check if the database connection is healthy. + + Returns: + bool: True if connection is healthy, False otherwise. 
+ """ + try: + async with self.engine.connect() as conn: + await conn.execute(text("SELECT 1")) + return True + except Exception as e: + logger.error(f"Database health check failed: {e}") + return False \ No newline at end of file diff --git a/database/models.py b/database/models.py new file mode 100644 index 00000000..e95b07b0 --- /dev/null +++ b/database/models.py @@ -0,0 +1,211 @@ +from sqlalchemy import ( + TIMESTAMP, + Column, + ForeignKey, + Integer, + Numeric, + String, + Text, + func, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + +Base = declarative_base() + + +class AccountState(Base): + __tablename__ = "account_states" + + id = Column(Integer, primary_key=True, index=True) + timestamp = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + + token_states = relationship("TokenState", back_populates="account_state", cascade="all, delete-orphan") + + +class TokenState(Base): + __tablename__ = "token_states" + + id = Column(Integer, primary_key=True, index=True) + account_state_id = Column(Integer, ForeignKey("account_states.id"), nullable=False) + token = Column(String, nullable=False, index=True) + units = Column(Numeric(precision=30, scale=18), nullable=False) + price = Column(Numeric(precision=30, scale=18), nullable=False) + value = Column(Numeric(precision=30, scale=18), nullable=False) + available_units = Column(Numeric(precision=30, scale=18), nullable=False) + + account_state = relationship("AccountState", back_populates="token_states") + + +class Order(Base): + __tablename__ = "orders" + + id = Column(Integer, primary_key=True, index=True) + # Order identification + client_order_id = Column(String, nullable=False, unique=True, index=True) + exchange_order_id = Column(String, nullable=True, index=True) + + # Timestamps + created_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + + # Account and connector info + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + + # Order details + trading_pair = Column(String, nullable=False, index=True) + trade_type = Column(String, nullable=False) # BUY, SELL + order_type = Column(String, nullable=False) # LIMIT, MARKET, LIMIT_MAKER + amount = Column(Numeric(precision=30, scale=18), nullable=False) + price = Column(Numeric(precision=30, scale=18), nullable=True) # Null for market orders + + # Order status and execution + status = Column(String, nullable=False, default="SUBMITTED", index=True) # SUBMITTED, OPEN, FILLED, CANCELLED, FAILED + filled_amount = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + average_fill_price = Column(Numeric(precision=30, scale=18), nullable=True) + + # Fee information + fee_paid = Column(Numeric(precision=30, scale=18), default=0, nullable=True) + fee_currency = Column(String, nullable=True) + + # Additional metadata + error_message = Column(Text, nullable=True) + + # Relationships for future enhancements + trades = relationship("Trade", back_populates="order", cascade="all, delete-orphan") + + +class Trade(Base): + __tablename__ = "trades" + + id = Column(Integer, primary_key=True, index=True) + order_id = Column(Integer, ForeignKey("orders.id"), 
nullable=False) + + # Trade identification + trade_id = Column(String, nullable=False, unique=True, index=True) + + # Timestamps + timestamp = Column(TIMESTAMP(timezone=True), nullable=False, index=True) + + # Trade details + trading_pair = Column(String, nullable=False, index=True) + trade_type = Column(String, nullable=False) # BUY, SELL + amount = Column(Numeric(precision=30, scale=18), nullable=False) + price = Column(Numeric(precision=30, scale=18), nullable=False) + + # Fee information + fee_paid = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + fee_currency = Column(String, nullable=True) + + # Relationship + order = relationship("Order", back_populates="trades") + + +class PositionSnapshot(Base): + __tablename__ = "position_snapshots" + + id = Column(Integer, primary_key=True, index=True) + + # Position identification + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + trading_pair = Column(String, nullable=False, index=True) + + # Timestamps + timestamp = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + + # Real-time exchange data (from connector.account_positions) + side = Column(String, nullable=False) # LONG, SHORT + exchange_size = Column(Numeric(precision=30, scale=18), nullable=False) # Size from exchange + entry_price = Column(Numeric(precision=30, scale=18), nullable=True) # Average entry price + mark_price = Column(Numeric(precision=30, scale=18), nullable=True) # Current mark price + + # Real-time PnL data (can't be derived from trades alone) + unrealized_pnl = Column(Numeric(precision=30, scale=18), nullable=True) # From exchange + percentage_pnl = Column(Numeric(precision=10, scale=6), nullable=True) # PnL percentage + + # Leverage and margin info + leverage = Column(Numeric(precision=10, scale=2), nullable=True) # Position leverage + initial_margin = Column(Numeric(precision=30, scale=18), nullable=True) # Initial margin + maintenance_margin = Column(Numeric(precision=30, scale=18), nullable=True) # Maintenance margin + + # Fee tracking (exchange provides cumulative data) + cumulative_funding_fees = Column(Numeric(precision=30, scale=18), nullable=False, default=0) # Funding fees + fee_currency = Column(String, nullable=True) # Fee currency (usually USDT) + + # Reconciliation fields (calculated from our trade data) + calculated_size = Column(Numeric(precision=30, scale=18), nullable=True) # Size from our trades + calculated_entry_price = Column(Numeric(precision=30, scale=18), nullable=True) # Entry from our trades + size_difference = Column(Numeric(precision=30, scale=18), nullable=True) # Difference for reconciliation + + # Additional metadata + exchange_position_id = Column(String, nullable=True, index=True) # Exchange position ID + is_reconciled = Column(String, nullable=False, default="PENDING") # RECONCILED, MISMATCH, PENDING + + +class FundingPayment(Base): + __tablename__ = "funding_payments" + + id = Column(Integer, primary_key=True, index=True) + + # Payment identification + funding_payment_id = Column(String, nullable=False, unique=True, index=True) + + # Timestamps + timestamp = Column(TIMESTAMP(timezone=True), nullable=False, index=True) + + # Account and connector info + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + + # Funding details + trading_pair = Column(String, nullable=False, index=True) + funding_rate = Column(Numeric(precision=20, 
scale=18), nullable=False) # Funding rate + funding_payment = Column(Numeric(precision=30, scale=18), nullable=False) # Payment amount + fee_currency = Column(String, nullable=False) # Payment currency (usually USDT) + + # Position association + position_size = Column(Numeric(precision=30, scale=18), nullable=True) # Position size at time of payment + position_side = Column(String, nullable=True) # LONG, SHORT + + # Additional metadata + exchange_funding_id = Column(String, nullable=True, index=True) # Exchange funding ID + + +class BotRun(Base): + __tablename__ = "bot_runs" + + id = Column(Integer, primary_key=True, index=True) + + # Bot identification + bot_name = Column(String, nullable=False, index=True) + instance_name = Column(String, nullable=False, index=True) + + # Deployment info + deployed_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + strategy_type = Column(String, nullable=False, index=True) # 'script' or 'controller' + strategy_name = Column(String, nullable=False, index=True) + config_name = Column(String, nullable=True, index=True) + + # Runtime tracking + stopped_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True) + + # Status tracking + deployment_status = Column(String, nullable=False, default="DEPLOYED", index=True) # DEPLOYED, FAILED, ARCHIVED + run_status = Column(String, nullable=False, default="CREATED", index=True) # CREATED, RUNNING, STOPPED, ERROR + + # Configuration and final state + deployment_config = Column(Text, nullable=True) # JSON of full deployment config + final_status = Column(Text, nullable=True) # JSON of final bot state, performance, etc. + + # Account info + account_name = Column(String, nullable=False, index=True) + + # Metadata + image_version = Column(String, nullable=True, index=True) + error_message = Column(Text, nullable=True) + + diff --git a/database/repositories/__init__.py b/database/repositories/__init__.py new file mode 100644 index 00000000..362ea052 --- /dev/null +++ b/database/repositories/__init__.py @@ -0,0 +1,4 @@ +from .account_repository import AccountRepository +from .bot_run_repository import BotRunRepository + +__all__ = ["AccountRepository", "BotRunRepository"] \ No newline at end of file diff --git a/database/repositories/account_repository.py b/database/repositories/account_repository.py new file mode 100644 index 00000000..1e3438ce --- /dev/null +++ b/database/repositories/account_repository.py @@ -0,0 +1,374 @@ +from datetime import datetime +from decimal import Decimal +from typing import Dict, List, Optional, Tuple +import base64 +import json + +from sqlalchemy import desc, select, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from database import AccountState, TokenState + + +class AccountRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def save_account_state(self, account_name: str, connector_name: str, tokens_info: List[Dict], + snapshot_timestamp: Optional[datetime] = None) -> AccountState: + """ + Save account state with token information to the database. + If snapshot_timestamp is provided, use it instead of server default. 
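+
+        Each tokens_info entry is expected to carry the keys read below, e.g.
+        (illustrative values only):
+            {"token": "USDT", "units": 100.0, "price": 1.0,
+             "value": 100.0, "available_units": 80.0}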
+ """ + account_state_data = { + "account_name": account_name, + "connector_name": connector_name + } + + # If a specific timestamp is provided, use it instead of server default + if snapshot_timestamp: + account_state_data["timestamp"] = snapshot_timestamp + + account_state = AccountState(**account_state_data) + + self.session.add(account_state) + await self.session.flush() # Get the ID + + for token_info in tokens_info: + token_state = TokenState( + account_state_id=account_state.id, + token=token_info["token"], + units=Decimal(str(token_info["units"])), + price=Decimal(str(token_info["price"])), + value=Decimal(str(token_info["value"])), + available_units=Decimal(str(token_info["available_units"])) + ) + self.session.add(token_state) + + await self.session.commit() + return account_state + + async def get_latest_account_states(self) -> Dict[str, Dict[str, List[Dict]]]: + """ + Get the latest account states for all accounts and connectors. + """ + # Get the latest timestamp for each account-connector combination + subquery = ( + select( + AccountState.account_name, + AccountState.connector_name, + func.max(AccountState.timestamp).label("max_timestamp") + ) + .group_by(AccountState.account_name, AccountState.connector_name) + .subquery() + ) + + # Get the full records for the latest timestamps + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .join( + subquery, + (AccountState.account_name == subquery.c.account_name) & + (AccountState.connector_name == subquery.c.connector_name) & + (AccountState.timestamp == subquery.c.max_timestamp) + ) + ) + + result = await self.session.execute(query) + account_states = result.unique().scalars().all() + + # Convert to the expected format + accounts_state = {} + for account_state in account_states: + if account_state.account_name not in accounts_state: + accounts_state[account_state.account_name] = {} + + token_info = [] + for token_state in account_state.token_states: + token_info.append({ + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + + accounts_state[account_state.account_name][account_state.connector_name] = token_info + + return accounts_state + + async def get_account_state_history(self, + limit: Optional[int] = None, + account_name: Optional[str] = None, + connector_name: Optional[str] = None, + cursor: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None) -> Tuple[List[Dict], Optional[str], bool]: + """ + Get historical account states with cursor-based pagination. 
+ + Returns: + Tuple of (data, next_cursor, has_more) + """ + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .order_by(desc(AccountState.timestamp)) + ) + + # Apply filters + if account_name: + query = query.filter(AccountState.account_name == account_name) + if connector_name: + query = query.filter(AccountState.connector_name == connector_name) + if start_time: + query = query.filter(AccountState.timestamp >= start_time) + if end_time: + query = query.filter(AccountState.timestamp <= end_time) + + # Handle cursor-based pagination + if cursor: + try: + cursor_time = datetime.fromisoformat(cursor.replace('Z', '+00:00')) + query = query.filter(AccountState.timestamp < cursor_time) + except (ValueError, TypeError): + # Invalid cursor, ignore it + pass + + # Fetch limit + 1 to check if there are more records + fetch_limit = limit + 1 if limit else 101 + query = query.limit(fetch_limit) + + result = await self.session.execute(query) + account_states = result.unique().scalars().all() + + # Check if there are more records + has_more = len(account_states) == fetch_limit + if has_more: + account_states = account_states[:-1] # Remove the extra record + + # Generate next cursor + next_cursor = None + if has_more and account_states: + next_cursor = account_states[-1].timestamp.isoformat() + + # Format response - Group by minute to aggregate account/connector states + minute_groups = {} + for account_state in account_states: + token_info = [] + for token_state in account_state.token_states: + token_info.append({ + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + + # Round timestamp to the nearest minute for grouping + minute_timestamp = account_state.timestamp.replace(second=0, microsecond=0) + minute_key = minute_timestamp.isoformat() + + # Initialize minute group if it doesn't exist + if minute_key not in minute_groups: + minute_groups[minute_key] = { + "timestamp": minute_key, + "state": {} + } + + # Add account/connector to the minute group + if account_state.account_name not in minute_groups[minute_key]["state"]: + minute_groups[minute_key]["state"][account_state.account_name] = {} + + minute_groups[minute_key]["state"][account_state.account_name][account_state.connector_name] = token_info + + # Convert to list and maintain chronological order (most recent first) + history = list(minute_groups.values()) + history.sort(key=lambda x: x["timestamp"], reverse=True) + + return history, next_cursor, has_more + + async def get_account_current_state(self, account_name: str) -> Dict[str, List[Dict]]: + """ + Get the current state for a specific account. 
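+
+        Shape of the result (illustrative values only):
+            {"binance": [{"token": "BTC", "units": 0.5, "price": 60000.0,
+                          "value": 30000.0, "available_units": 0.5}]}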
+ """ + subquery = ( + select( + AccountState.connector_name, + func.max(AccountState.timestamp).label("max_timestamp") + ) + .filter(AccountState.account_name == account_name) + .group_by(AccountState.connector_name) + .subquery() + ) + + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .join( + subquery, + (AccountState.connector_name == subquery.c.connector_name) & + (AccountState.timestamp == subquery.c.max_timestamp) + ) + .filter(AccountState.account_name == account_name) + ) + + result = await self.session.execute(query) + account_states = result.unique().scalars().all() + + state = {} + for account_state in account_states: + token_info = [] + for token_state in account_state.token_states: + token_info.append({ + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + state[account_state.connector_name] = token_info + + return state + + async def get_connector_current_state(self, account_name: str, connector_name: str) -> List[Dict]: + """ + Get the current state for a specific connector. + """ + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .filter( + AccountState.account_name == account_name, + AccountState.connector_name == connector_name + ) + .order_by(desc(AccountState.timestamp)) + .limit(1) + ) + + result = await self.session.execute(query) + account_state = result.unique().scalar_one_or_none() + + if not account_state: + return [] + + token_info = [] + for token_state in account_state.token_states: + token_info.append({ + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + + return token_info + + async def get_all_unique_tokens(self) -> List[str]: + """ + Get all unique tokens across all accounts and connectors. + """ + query = ( + select(TokenState.token) + .distinct() + .order_by(TokenState.token) + ) + + result = await self.session.execute(query) + tokens = result.scalars().all() + + return list(tokens) + + async def get_token_current_state(self, token: str) -> List[Dict]: + """ + Get current state of a specific token across all accounts. 
+ """ + # Get latest timestamps for each account-connector combination + subquery = ( + select( + AccountState.id, + AccountState.account_name, + AccountState.connector_name, + func.max(AccountState.timestamp).label("max_timestamp") + ) + .group_by(AccountState.account_name, AccountState.connector_name, AccountState.id) + .subquery() + ) + + query = ( + select(TokenState, AccountState.account_name, AccountState.connector_name) + .join(AccountState) + .join( + subquery, + (AccountState.id == subquery.c.id) & + (AccountState.timestamp == subquery.c.max_timestamp) + ) + .filter(TokenState.token == token) + ) + + result = await self.session.execute(query) + token_states = result.all() + + states = [] + for token_state, account_name, connector_name in token_states: + states.append({ + "account_name": account_name, + "connector_name": connector_name, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + + return states + + async def get_portfolio_value(self, account_name: Optional[str] = None) -> Dict: + """ + Get total portfolio value, optionally filtered by account. + """ + # Get latest timestamps + subquery = ( + select( + AccountState.account_name, + AccountState.connector_name, + func.max(AccountState.timestamp).label("max_timestamp") + ) + .group_by(AccountState.account_name, AccountState.connector_name) + ) + + if account_name: + subquery = subquery.filter(AccountState.account_name == account_name) + + subquery = subquery.subquery() + + # Get token values + query = ( + select( + AccountState.account_name, + func.sum(TokenState.value).label("total_value") + ) + .join(TokenState) + .join( + subquery, + (AccountState.account_name == subquery.c.account_name) & + (AccountState.connector_name == subquery.c.connector_name) & + (AccountState.timestamp == subquery.c.max_timestamp) + ) + .group_by(AccountState.account_name) + ) + + result = await self.session.execute(query) + values = result.all() + + portfolio = { + "accounts": {}, + "total_value": 0 + } + + for account, value in values: + portfolio["accounts"][account] = float(value or 0) + portfolio["total_value"] += float(value or 0) + + return portfolio \ No newline at end of file diff --git a/database/repositories/bot_run_repository.py b/database/repositories/bot_run_repository.py new file mode 100644 index 00000000..3999389e --- /dev/null +++ b/database/repositories/bot_run_repository.py @@ -0,0 +1,191 @@ +import json +from datetime import datetime, timezone +from typing import Dict, List, Optional, Any + +from sqlalchemy import desc, select, and_, or_, func +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import BotRun + + +class BotRunRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_bot_run( + self, + bot_name: str, + instance_name: str, + strategy_type: str, # 'script' or 'controller' + strategy_name: str, + account_name: str, + config_name: Optional[str] = None, + image_version: Optional[str] = None, + deployment_config: Optional[Dict[str, Any]] = None + ) -> BotRun: + """Create a new bot run record.""" + bot_run = BotRun( + bot_name=bot_name, + instance_name=instance_name, + strategy_type=strategy_type, + strategy_name=strategy_name, + config_name=config_name, + account_name=account_name, + image_version=image_version, + deployment_config=json.dumps(deployment_config) if deployment_config else None, + deployment_status="DEPLOYED", + 
run_status="CREATED" + ) + + self.session.add(bot_run) + await self.session.flush() + await self.session.refresh(bot_run) + return bot_run + + + async def update_bot_run_stopped( + self, + bot_name: str, + final_status: Optional[Dict[str, Any]] = None, + error_message: Optional[str] = None + ) -> Optional[BotRun]: + """Mark a bot run as stopped and save final status.""" + stmt = select(BotRun).where( + and_( + BotRun.bot_name == bot_name, + or_(BotRun.run_status == "RUNNING", BotRun.run_status == "CREATED") + ) + ).order_by(desc(BotRun.deployed_at)) + + result = await self.session.execute(stmt) + bot_run = result.scalar_one_or_none() + + if bot_run: + bot_run.run_status = "STOPPED" if not error_message else "ERROR" + bot_run.stopped_at = datetime.utcnow() + bot_run.final_status = json.dumps(final_status) if final_status else None + bot_run.error_message = error_message + await self.session.flush() + await self.session.refresh(bot_run) + + return bot_run + + async def update_bot_run_archived(self, bot_name: str) -> Optional[BotRun]: + """Mark a bot run as archived.""" + stmt = select(BotRun).where( + BotRun.bot_name == bot_name + ).order_by(desc(BotRun.deployed_at)) + + result = await self.session.execute(stmt) + bot_run = result.scalar_one_or_none() + + if bot_run: + bot_run.deployment_status = "ARCHIVED" + bot_run.stopped_at = datetime.now(timezone.utc) + await self.session.flush() + await self.session.refresh(bot_run) + + return bot_run + + async def get_bot_runs( + self, + bot_name: Optional[str] = None, + account_name: Optional[str] = None, + strategy_type: Optional[str] = None, + strategy_name: Optional[str] = None, + run_status: Optional[str] = None, + deployment_status: Optional[str] = None, + limit: int = 100, + offset: int = 0 + ) -> List[BotRun]: + """Get bot runs with optional filters.""" + stmt = select(BotRun) + + conditions = [] + if bot_name: + conditions.append(BotRun.bot_name == bot_name) + if account_name: + conditions.append(BotRun.account_name == account_name) + if strategy_type: + conditions.append(BotRun.strategy_type == strategy_type) + if strategy_name: + conditions.append(BotRun.strategy_name == strategy_name) + if run_status: + conditions.append(BotRun.run_status == run_status) + if deployment_status: + conditions.append(BotRun.deployment_status == deployment_status) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(BotRun.deployed_at)).limit(limit).offset(offset) + + result = await self.session.execute(stmt) + return result.scalars().all() + + async def get_bot_run_by_id(self, bot_run_id: int) -> Optional[BotRun]: + """Get a specific bot run by ID.""" + stmt = select(BotRun).where(BotRun.id == bot_run_id) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def get_latest_bot_run(self, bot_name: str) -> Optional[BotRun]: + """Get the latest bot run for a specific bot.""" + stmt = select(BotRun).where( + BotRun.bot_name == bot_name + ).order_by(desc(BotRun.deployed_at)) + + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def get_active_bot_runs(self) -> List[BotRun]: + """Get all currently active (running) bot runs.""" + stmt = select(BotRun).where( + and_( + BotRun.run_status == "RUNNING", + BotRun.deployment_status == "DEPLOYED" + ) + ).order_by(desc(BotRun.deployed_at)) + + result = await self.session.execute(stmt) + return result.scalars().all() + + async def get_bot_run_stats(self) -> Dict[str, Any]: + """Get statistics about bot runs.""" + # 
Total runs + total_stmt = select(func.count(BotRun.id)) + total_result = await self.session.execute(total_stmt) + total_runs = total_result.scalar() + + # Active runs + active_stmt = select(func.count(BotRun.id)).where( + and_( + BotRun.run_status == "RUNNING", + BotRun.deployment_status == "DEPLOYED" + ) + ) + active_result = await self.session.execute(active_stmt) + active_runs = active_result.scalar() + + # Runs by strategy type + strategy_stmt = select( + BotRun.strategy_type, + func.count(BotRun.id).label('count') + ).group_by(BotRun.strategy_type) + strategy_result = await self.session.execute(strategy_stmt) + strategy_counts = {row.strategy_type: row.count for row in strategy_result} + + # Runs by status + status_stmt = select( + BotRun.run_status, + func.count(BotRun.id).label('count') + ).group_by(BotRun.run_status) + status_result = await self.session.execute(status_stmt) + status_counts = {row.run_status: row.count for row in status_result} + + return { + "total_runs": total_runs, + "active_runs": active_runs, + "strategy_type_counts": strategy_counts, + "status_counts": status_counts + } \ No newline at end of file diff --git a/database/repositories/funding_repository.py b/database/repositories/funding_repository.py new file mode 100644 index 00000000..e9b8dd42 --- /dev/null +++ b/database/repositories/funding_repository.py @@ -0,0 +1,84 @@ +from datetime import datetime +from typing import Dict, List, Optional +from decimal import Decimal + +from sqlalchemy import desc, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import FundingPayment + + +class FundingRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_funding_payment(self, funding_data: Dict) -> FundingPayment: + """Create a new funding payment record.""" + funding = FundingPayment(**funding_data) + self.session.add(funding) + await self.session.flush() # Get the ID + return funding + + async def get_funding_payments(self, account_name: str, connector_name: str = None, + trading_pair: str = None, limit: int = 100) -> List[FundingPayment]: + """Get funding payments with optional filters.""" + query = select(FundingPayment).where(FundingPayment.account_name == account_name) + + if connector_name: + query = query.where(FundingPayment.connector_name == connector_name) + if trading_pair: + query = query.where(FundingPayment.trading_pair == trading_pair) + + query = query.order_by(FundingPayment.timestamp.desc()).limit(limit) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_total_funding_fees(self, account_name: str, connector_name: str, + trading_pair: str) -> Dict: + """Get total funding fees for a specific trading pair.""" + query = select(FundingPayment).where( + FundingPayment.account_name == account_name, + FundingPayment.connector_name == connector_name, + FundingPayment.trading_pair == trading_pair + ) + + result = await self.session.execute(query) + payments = result.scalars().all() + + total_funding = Decimal('0') + payment_count = 0 + + for payment in payments: + total_funding += Decimal(str(payment.funding_payment)) + payment_count += 1 + + return { + "total_funding_fees": float(total_funding), + "payment_count": payment_count, + "fee_currency": payments[0].fee_currency if payments else None + } + + async def funding_payment_exists(self, funding_payment_id: str) -> bool: + """Check if a funding payment already exists.""" + result = await self.session.execute( + 
select(FundingPayment).where(FundingPayment.funding_payment_id == funding_payment_id) + ) + return result.scalar_one_or_none() is not None + + def to_dict(self, funding: FundingPayment) -> Dict: + """Convert FundingPayment model to dictionary format.""" + return { + "id": funding.id, + "funding_payment_id": funding.funding_payment_id, + "timestamp": funding.timestamp.isoformat(), + "account_name": funding.account_name, + "connector_name": funding.connector_name, + "trading_pair": funding.trading_pair, + "funding_rate": float(funding.funding_rate), + "funding_payment": float(funding.funding_payment), + "fee_currency": funding.fee_currency, + "position_size": float(funding.position_size) if funding.position_size else None, + "position_side": funding.position_side, + "exchange_funding_id": funding.exchange_funding_id, + } \ No newline at end of file diff --git a/database/repositories/order_repository.py b/database/repositories/order_repository.py new file mode 100644 index 00000000..3bf7ee21 --- /dev/null +++ b/database/repositories/order_repository.py @@ -0,0 +1,178 @@ +from datetime import datetime +from typing import Dict, List, Optional +from decimal import Decimal + +from sqlalchemy import desc, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import Order + + +class OrderRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_order(self, order_data: Dict) -> Order: + """Create a new order record.""" + order = Order(**order_data) + self.session.add(order) + await self.session.flush() # Get the ID + return order + + async def get_order_by_client_id(self, client_order_id: str) -> Optional[Order]: + """Get an order by its client order ID.""" + result = await self.session.execute( + select(Order).where(Order.client_order_id == client_order_id) + ) + return result.scalar_one_or_none() + + async def update_order_status(self, client_order_id: str, status: str, + error_message: Optional[str] = None) -> Optional[Order]: + """Update order status and optional error message.""" + result = await self.session.execute( + select(Order).where(Order.client_order_id == client_order_id) + ) + order = result.scalar_one_or_none() + if order: + order.status = status + if error_message: + order.error_message = error_message + await self.session.flush() + return order + + async def update_order_fill(self, client_order_id: str, filled_amount: Decimal, + average_fill_price: Decimal, fee_paid: Decimal = None, + fee_currency: str = None, exchange_order_id: str = None) -> Optional[Order]: + """Update order with fill information.""" + result = await self.session.execute( + select(Order).where(Order.client_order_id == client_order_id) + ) + order = result.scalar_one_or_none() + if order: + # Add to existing filled amount instead of replacing + previous_filled = Decimal(str(order.filled_amount or 0)) + order.filled_amount = float(previous_filled + filled_amount) + + # Update average price (simplified - use latest fill price) + order.average_fill_price = float(average_fill_price) + + # Add to existing fees + if fee_paid is not None: + previous_fee = Decimal(str(order.fee_paid or 0)) + order.fee_paid = float(previous_fee + fee_paid) + if fee_currency: + order.fee_currency = fee_currency + if exchange_order_id: + order.exchange_order_id = exchange_order_id + + # Update status based on total filled amount + total_filled = Decimal(str(order.filled_amount)) + if total_filled >= Decimal(str(order.amount)): + order.status = "FILLED" + elif 
total_filled > 0: + order.status = "PARTIALLY_FILLED" + + await self.session.flush() + return order + + async def get_orders(self, account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + status: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Order]: + """Get orders with filtering and pagination.""" + query = select(Order) + + # Apply filters + if account_name: + query = query.where(Order.account_name == account_name) + if connector_name: + query = query.where(Order.connector_name == connector_name) + if trading_pair: + query = query.where(Order.trading_pair == trading_pair) + if status: + query = query.where(Order.status == status) + if start_time: + start_dt = datetime.fromtimestamp(start_time / 1000) + query = query.where(Order.created_at >= start_dt) + if end_time: + end_dt = datetime.fromtimestamp(end_time / 1000) + query = query.where(Order.created_at <= end_dt) + + # Apply ordering and pagination + query = query.order_by(Order.created_at.desc()) + query = query.limit(limit).offset(offset) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_active_orders(self, account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None) -> List[Order]: + """Get active orders (SUBMITTED, OPEN, PARTIALLY_FILLED).""" + query = select(Order).where( + Order.status.in_(["SUBMITTED", "OPEN", "PARTIALLY_FILLED"]) + ) + + # Apply filters + if account_name: + query = query.where(Order.account_name == account_name) + if connector_name: + query = query.where(Order.connector_name == connector_name) + if trading_pair: + query = query.where(Order.trading_pair == trading_pair) + + query = query.order_by(Order.created_at.desc()).limit(1000) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_orders_summary(self, account_name: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None) -> Dict: + """Get order summary statistics.""" + orders = await self.get_orders( + account_name=account_name, + start_time=start_time, + end_time=end_time, + limit=10000 # Get all for summary + ) + + total_orders = len(orders) + filled_orders = sum(1 for o in orders if o.status == "FILLED") + cancelled_orders = sum(1 for o in orders if o.status == "CANCELLED") + failed_orders = sum(1 for o in orders if o.status == "FAILED") + active_orders = sum(1 for o in orders if o.status in ["SUBMITTED", "OPEN", "PARTIALLY_FILLED"]) + + return { + "total_orders": total_orders, + "filled_orders": filled_orders, + "cancelled_orders": cancelled_orders, + "failed_orders": failed_orders, + "active_orders": active_orders, + "fill_rate": filled_orders / total_orders if total_orders > 0 else 0, + } + + def to_dict(self, order: Order) -> Dict: + """Convert Order model to dictionary format.""" + return { + "order_id": order.client_order_id, + "account_name": order.account_name, + "connector_name": order.connector_name, + "trading_pair": order.trading_pair, + "trade_type": order.trade_type, + "order_type": order.order_type, + "amount": float(order.amount), + "price": float(order.price) if order.price else None, + "status": order.status, + "filled_amount": float(order.filled_amount), + "average_fill_price": float(order.average_fill_price) if order.average_fill_price else None, + "fee_paid": float(order.fee_paid) if order.fee_paid else None, + 
"fee_currency": order.fee_currency, + "created_at": order.created_at.isoformat(), + "updated_at": order.updated_at.isoformat(), + "exchange_order_id": order.exchange_order_id, + "error_message": order.error_message, + } \ No newline at end of file diff --git a/database/repositories/trade_repository.py b/database/repositories/trade_repository.py new file mode 100644 index 00000000..d9f10ad2 --- /dev/null +++ b/database/repositories/trade_repository.py @@ -0,0 +1,103 @@ +from datetime import datetime +from typing import Dict, List, Optional + +from sqlalchemy import desc, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import Trade, Order + + +class TradeRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_trade(self, trade_data: Dict) -> Trade: + """Create a new trade record.""" + trade = Trade(**trade_data) + self.session.add(trade) + await self.session.flush() # Get the ID + return trade + + async def get_trades(self, account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + trade_type: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Trade]: + """Get trades with filtering and pagination.""" + # Join trades with orders to get account information + query = select(Trade).join(Order, Trade.order_id == Order.id) + + # Apply filters + if account_name: + query = query.where(Order.account_name == account_name) + if connector_name: + query = query.where(Order.connector_name == connector_name) + if trading_pair: + query = query.where(Trade.trading_pair == trading_pair) + if trade_type: + query = query.where(Trade.trade_type == trade_type) + if start_time: + start_dt = datetime.fromtimestamp(start_time / 1000) + query = query.where(Trade.timestamp >= start_dt) + if end_time: + end_dt = datetime.fromtimestamp(end_time / 1000) + query = query.where(Trade.timestamp <= end_dt) + + # Apply ordering and pagination + query = query.order_by(Trade.timestamp.desc()) + query = query.limit(limit).offset(offset) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_trades_with_orders(self, account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + trade_type: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[tuple]: + """Get trades with their associated order information.""" + # Join trades with orders to get complete information + query = select(Trade, Order).join(Order, Trade.order_id == Order.id) + + # Apply filters + if account_name: + query = query.where(Order.account_name == account_name) + if connector_name: + query = query.where(Order.connector_name == connector_name) + if trading_pair: + query = query.where(Trade.trading_pair == trading_pair) + if trade_type: + query = query.where(Trade.trade_type == trade_type) + if start_time: + start_dt = datetime.fromtimestamp(start_time / 1000) + query = query.where(Trade.timestamp >= start_dt) + if end_time: + end_dt = datetime.fromtimestamp(end_time / 1000) + query = query.where(Trade.timestamp <= end_dt) + + # Apply ordering and pagination + query = query.order_by(Trade.timestamp.desc()) + query = query.limit(limit).offset(offset) + + result = await self.session.execute(query) + return result.all() # Returns tuples of (Trade, Order) + + def to_dict(self, 
trade: Trade, order: Optional[Order] = None) -> Dict: + """Convert Trade model to dictionary format.""" + return { + "trade_id": trade.trade_id, + "order_id": order.client_order_id if order else None, + "account_name": order.account_name if order else None, + "connector_name": order.connector_name if order else None, + "trading_pair": trade.trading_pair, + "trade_type": trade.trade_type, + "amount": float(trade.amount), + "price": float(trade.price), + "fee_paid": float(trade.fee_paid), + "fee_currency": trade.fee_currency, + "timestamp": trade.timestamp.isoformat(), + } \ No newline at end of file diff --git a/deps.py b/deps.py new file mode 100644 index 00000000..accf6f40 --- /dev/null +++ b/deps.py @@ -0,0 +1,37 @@ +from fastapi import Request +from services.bots_orchestrator import BotsOrchestrator +from services.accounts_service import AccountsService +from services.docker_service import DockerService +from services.market_data_feed_manager import MarketDataFeedManager +from utils.bot_archiver import BotArchiver +from database import AsyncDatabaseManager + + +def get_bots_orchestrator(request: Request) -> BotsOrchestrator: + """Get BotsOrchestrator service from app state.""" + return request.app.state.bots_orchestrator + + +def get_accounts_service(request: Request) -> AccountsService: + """Get AccountsService from app state.""" + return request.app.state.accounts_service + + +def get_docker_service(request: Request) -> DockerService: + """Get DockerService from app state.""" + return request.app.state.docker_service + + +def get_market_data_feed_manager(request: Request) -> MarketDataFeedManager: + """Get MarketDataFeedManager from app state.""" + return request.app.state.market_data_feed_manager + + +def get_bot_archiver(request: Request) -> BotArchiver: + """Get BotArchiver from app state.""" + return request.app.state.bot_archiver + + +def get_database_manager(request: Request) -> AsyncDatabaseManager: + """Get AsyncDatabaseManager from app state.""" + return request.app.state.accounts_service.db_manager \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 5758652f..c6167528 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,21 +1,22 @@ -version: "3.9" - services: - backend-api: - container_name: backend-api - image: hummingbot/backend-api:latest + hummingbot-api: + container_name: hummingbot-api + image: hummingbot/hummingbot-api:latest ports: - "8000:8000" volumes: - - ./bots:/backend-api/bots + - ./bots:/hummingbot-api/bots - /var/run/docker.sock:/var/run/docker.sock + env_file: + - .env environment: + # Override specific values for Docker networking - BROKER_HOST=emqx - - BROKER_PORT=1883 - - USERNAME=admin - - PASSWORD=admin + - DATABASE_URL=postgresql+asyncpg://hbot:hummingbot-api@postgres:5432/hummingbot_api networks: - emqx-bridge + depends_on: + - postgres emqx: container_name: hummingbot-broker image: emqx:5 @@ -47,6 +48,25 @@ services: interval: 5s timeout: 25s retries: 5 + postgres: + container_name: hummingbot-postgres + image: postgres:15 + restart: unless-stopped + environment: + - POSTGRES_DB=hummingbot_api + - POSTGRES_USER=hbot + - POSTGRES_PASSWORD=hummingbot-api + volumes: + - postgres-data:/var/lib/postgresql/data + ports: + - "5432:5432" + networks: + - emqx-bridge + healthcheck: + test: ["CMD-SHELL", "pg_isready -U hbot -d hummingbot_api"] + interval: 10s + timeout: 5s + retries: 5 networks: emqx-bridge: @@ -56,3 +76,4 @@ volumes: emqx-data: { } emqx-log: { } emqx-etc: { } + postgres-data: { } diff --git 
a/environment.yml b/environment.yml
index d41035c1..7fc7b209 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,4 +1,4 @@
-name: backend-api
+name: hummingbot-api
 channels:
   - conda-forge
   - defaults
@@ -13,7 +13,15 @@ dependencies:
   - pip
   - pip:
     - hummingbot
-    - git+https://github.com/hummingbot/hbot-remote-client-py.git
     - flake8
     - isort
     - pre-commit
+    - logfire
+    - logfire[fastapi]
+    - logfire[system-metrics]
+    - aiomqtt>=2.0.0
+    - sqlalchemy>=2.0.0
+    - asyncpg
+    - psycopg2-binary
+    - greenlet
+    - pydantic-settings
diff --git a/main.py b/main.py
index f3bf1947..5ed4b379 100644
--- a/main.py
+++ b/main.py
@@ -1,35 +1,219 @@
-import os
 import secrets
+from contextlib import asynccontextmanager
 from typing import Annotated
+import logfire
+import logging
 
 from dotenv import load_dotenv
+
+# Load environment variables early
+load_dotenv()
+
+# Monkey patch save_to_yml to prevent writes to library directory
+def patched_save_to_yml(yml_path, cm):
+    """Patched version of save_to_yml that prevents writes to library directory"""
+    import logging
+    logger = logging.getLogger(__name__)
+    logger.debug(f"Skipping config write to {yml_path} (patched for API mode)")
+    # Do nothing - this prevents the original function from trying to write to the library directory
+
+# Apply the patch before importing hummingbot components
+from hummingbot.client.config import config_helpers
+config_helpers.save_to_yml = patched_save_to_yml
+
+# Monkey patch start_network to conditionally start order book tracker
+# async def patched_start_network(self):
+#     """
+#     Patched version of start_network that conditionally starts the order book tracker.
+#     Only starts order book tracker when trading pairs are configured to avoid issues.
+#     """
+#     import logging
+#     from hummingbot.core.utils.async_utils import safe_ensure_future
+#
+#     logger = logging.getLogger(__name__)
+#     logger.debug(f"Starting network for {self.__class__.__name__} (patched)")
+#
+#     # Stop any existing network first
+#     self._stop_network()
+#
+#     # Check if we have trading pairs configured
+#     has_trading_pairs = hasattr(self, '_trading_pairs') and len(self._trading_pairs) > 0
+#
+#     # Start order book tracker only if we have trading pairs
+#     if has_trading_pairs:
+#         logger.debug(f"Starting order book tracker for {self.__class__.__name__} with {len(self._trading_pairs)} trading pairs")
+#         self.order_book_tracker.start()
+#     else:
+#         logger.debug(f"Skipping order book tracker for {self.__class__.__name__} - no trading pairs configured")
+#
+#     # Start the essential polling tasks if trading is required
+#     if self.is_trading_required:
+#         try:
+#             self._trading_rules_polling_task = safe_ensure_future(self._trading_rules_polling_loop())
+#             self._trading_fees_polling_task = safe_ensure_future(self._trading_fees_polling_loop())
+#             self._status_polling_task = safe_ensure_future(self._status_polling_loop())
+#             self._user_stream_tracker_task = self._create_user_stream_tracker_task()
+#             self._user_stream_event_listener_task = safe_ensure_future(self._user_stream_event_listener())
+#             self._lost_orders_update_task = safe_ensure_future(self._lost_orders_update_polling_loop())
+#
+#             logger.debug(f"Started network tasks for {self.__class__.__name__}")
+#         except Exception as e:
+#             logger.error(f"Error starting network for {self.__class__.__name__}: {e}")
+#     else:
+#         logger.debug(f"Trading not required for {self.__class__.__name__}, skipping network start")
+#
+# # Apply the start_network patch - this will be applied to ExchangePyBase after import
+# from 
hummingbot.connector.exchange_py_base import ExchangePyBase +# ExchangePyBase.start_network = patched_start_network + +from hummingbot.core.rate_oracle.rate_oracle import RateOracle + from fastapi import Depends, FastAPI, HTTPException, status from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.middleware.cors import CORSMiddleware +from hummingbot.data_feed.market_data_provider import MarketDataProvider +from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger +from utils.security import BackendAPISecurity +from services.bots_orchestrator import BotsOrchestrator +from services.accounts_service import AccountsService +from services.docker_service import DockerService +from services.market_data_feed_manager import MarketDataFeedManager +from utils.bot_archiver import BotArchiver from routers import ( - manage_accounts, - manage_backtesting, - manage_broker_messages, - manage_databases, - manage_docker, - manage_files, - manage_market_data, - manage_performance, + accounts, + archived_bots, + backtesting, + bot_orchestration, + connectors, + controllers, + docker, + market_data, + portfolio, + scripts, + trading ) -load_dotenv() +from config import settings + + +# Set up logging configuration +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +# Enable debug logging for MQTT manager +logging.getLogger('services.mqtt_manager').setLevel(logging.DEBUG) + + +# Get settings from Pydantic Settings +username = settings.security.username +password = settings.security.password +debug_mode = settings.security.debug_mode + +# Security setup security = HTTPBasic() -username = os.getenv("USERNAME", "admin") -password = os.getenv("PASSWORD", "admin") -debug_mode = os.getenv("DEBUG_MODE", False) -app = FastAPI() +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager for the FastAPI application. + Handles startup and shutdown events. 
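+
+    Startup (as implemented below): verify the master password file, build
+    the market data and bot orchestration services, initialize the database,
+    store everything on app.state, then start the services. Shutdown reverses
+    this: stop the services, clean up Docker, and close database connections.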
+ """ + # Ensure password verification file exists + if BackendAPISecurity.new_password_required(): + # Create secrets manager with CONFIG_PASSWORD + secrets_manager = ETHKeyFileSecretManger(password=settings.security.config_password) + BackendAPISecurity.store_password_verification(secrets_manager) + logging.info("Created password verification file for master_account") + + # Initialize MarketDataProvider with empty connectors (will use non-trading connectors) + market_data_provider = MarketDataProvider(connectors={}) + + # Initialize MarketDataFeedManager with lifecycle management + market_data_feed_manager = MarketDataFeedManager( + market_data_provider=market_data_provider, + rate_oracle=RateOracle.get_instance(), + cleanup_interval=settings.market_data.cleanup_interval, + feed_timeout=settings.market_data.feed_timeout + ) + + # Initialize services + bots_orchestrator = BotsOrchestrator( + broker_host=settings.broker.host, + broker_port=settings.broker.port, + broker_username=settings.broker.username, + broker_password=settings.broker.password + ) + + accounts_service = AccountsService( + account_update_interval=settings.app.account_update_interval, + market_data_feed_manager=market_data_feed_manager + ) + docker_service = DockerService() + bot_archiver = BotArchiver( + settings.aws.api_key, + settings.aws.secret_key, + settings.aws.s3_default_bucket_name + ) + + # Initialize database + await accounts_service.ensure_db_initialized() + + # Store services in app state + app.state.bots_orchestrator = bots_orchestrator + app.state.accounts_service = accounts_service + app.state.docker_service = docker_service + app.state.bot_archiver = bot_archiver + app.state.market_data_feed_manager = market_data_feed_manager + + # Start services + bots_orchestrator.start() + accounts_service.start() + market_data_feed_manager.start() + + yield + + # Shutdown services + bots_orchestrator.stop() + await accounts_service.stop() + + # Stop market data feed manager (which will stop all feeds) + market_data_feed_manager.stop() + + # Clean up docker service + docker_service.cleanup() + + # Close database connections + await accounts_service.db_manager.close() + + +# Initialize FastAPI with metadata and lifespan +app = FastAPI( + title="Hummingbot API", + description="API for managing Hummingbot trading instances", + version="1.0.0", + lifespan=lifespan, +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Modify in production to specific origins + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +logfire.configure(send_to_logfire="if-token-present", environment=settings.app.logfire_environment, service_name="hummingbot-api") +logfire.instrument_fastapi(app) def auth_user( credentials: Annotated[HTTPBasicCredentials, Depends(security)], ): + """Authenticate user using HTTP Basic Auth""" current_username_bytes = credentials.username.encode("utf8") correct_username_bytes = f"{username}".encode("utf8") is_correct_username = secrets.compare_digest( @@ -46,13 +230,26 @@ def auth_user( detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}, ) + return credentials.username +# Include all routers with authentication +app.include_router(docker.router, dependencies=[Depends(auth_user)]) +app.include_router(accounts.router, dependencies=[Depends(auth_user)]) +app.include_router(connectors.router, dependencies=[Depends(auth_user)]) +app.include_router(portfolio.router, dependencies=[Depends(auth_user)]) +app.include_router(trading.router, 
dependencies=[Depends(auth_user)])
+app.include_router(bot_orchestration.router, dependencies=[Depends(auth_user)])
+app.include_router(controllers.router, dependencies=[Depends(auth_user)])
+app.include_router(scripts.router, dependencies=[Depends(auth_user)])
+app.include_router(market_data.router, dependencies=[Depends(auth_user)])
+app.include_router(backtesting.router, dependencies=[Depends(auth_user)])
+app.include_router(archived_bots.router, dependencies=[Depends(auth_user)])
 
-app.include_router(manage_docker.router, dependencies=[Depends(auth_user)])
-app.include_router(manage_broker_messages.router, dependencies=[Depends(auth_user)])
-app.include_router(manage_files.router, dependencies=[Depends(auth_user)])
-app.include_router(manage_market_data.router, dependencies=[Depends(auth_user)])
-app.include_router(manage_backtesting.router, dependencies=[Depends(auth_user)])
-app.include_router(manage_databases.router, dependencies=[Depends(auth_user)])
-app.include_router(manage_performance.router, dependencies=[Depends(auth_user)])
-app.include_router(manage_accounts.router, dependencies=[Depends(auth_user)])
+@app.get("/")
+async def root():
+    """API root endpoint returning basic information."""
+    return {
+        "name": "Hummingbot API",
+        "version": "1.0.0",
+        "status": "running",
+    }
\ No newline at end of file
diff --git a/models.py b/models.py
deleted file mode 100644
index 94aa10b3..00000000
--- a/models.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from typing import Any, Dict, Optional
-
-from pydantic import BaseModel
-
-
-class HummingbotInstanceConfig(BaseModel):
-    instance_name: str
-    credentials_profile: str
-    image: str = "hummingbot/hummingbot:latest"
-    script: Optional[str] = None
-    script_config: Optional[str] = None
-
-
-class ImageName(BaseModel):
-    image_name: str
-
-
-class Script(BaseModel):
-    name: str
-    content: str
-
-
-class ScriptConfig(BaseModel):
-    name: str
-    content: Dict[str, Any]  # YAML content represented as a dictionary
-
-
-class BotAction(BaseModel):
-    bot_name: str
-
-
-class StartBotAction(BotAction):
-    log_level: str = None
-    script: str = None
-    conf: str = None
-    async_backend: bool = False
-
-
-class StopBotAction(BotAction):
-    skip_order_cancellation: bool = False
-    async_backend: bool = False
-
-
-class ImportStrategyAction(BotAction):
-    strategy: str
-
-
-class ConfigureBotAction(BotAction):
-    params: dict
-
-
-class ShortcutAction(BotAction):
-    params: list
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 00000000..04734461
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,263 @@
+"""
+Model definitions for the Hummingbot API.
+
+Each model file corresponds to a router file with the same name.
+Models are organized by functional domain to match the API structure.
+""" + +# Bot orchestration models (bot lifecycle management) +from .bot_orchestration import ( + BotAction, + StartBotAction, + StopBotAction, + ImportStrategyAction, + ConfigureBotAction, + ShortcutAction, + BotStatus, + BotHistoryRequest, + BotHistoryResponse, + MQTTStatus, + AllBotsStatusResponse, + StopAndArchiveRequest, + StopAndArchiveResponse, + V2ScriptDeployment, + V2ControllerDeployment, +) + +# Trading models +from .trading import ( + TradeRequest, + TradeResponse, + TokenInfo, + ConnectorBalance, + AccountBalance, + PortfolioState, + OrderInfo, + ActiveOrdersResponse, + OrderSummary, + TradeInfo, + TradingRulesInfo, + OrderTypesResponse, + OrderFilterRequest, + ActiveOrderFilterRequest, + PositionFilterRequest, + FundingPaymentFilterRequest, + TradeFilterRequest, +) + +# Controller models +from .controllers import ( + ControllerType, + Controller, + ControllerResponse, + ControllerConfig, + ControllerConfigResponse, +) + +# Script models +from .scripts import ( + Script, + ScriptResponse, + ScriptConfig, + ScriptConfigResponse, +) + + +# Market data models +from .market_data import ( + CandleData, + CandlesResponse, + ActiveFeedInfo, + ActiveFeedsResponse, + MarketDataSettings, + TradingRulesResponse, + SupportedOrderTypesResponse, + # New enhanced market data models + PriceRequest, + PriceData, + PricesResponse, + FundingInfoRequest, + FundingInfoResponse, + OrderBookRequest, + OrderBookLevel, + OrderBookResponse, + OrderBookQueryRequest, + VolumeForPriceRequest, + PriceForVolumeRequest, + QuoteVolumeForPriceRequest, + PriceForQuoteVolumeRequest, + VWAPForVolumeRequest, + OrderBookQueryResult, +) + +# Account models +from .accounts import ( + LeverageRequest, + PositionModeRequest, + CredentialRequest, +) + + +# Docker models +from .docker import DockerImage + +# Backtesting models +from .backtesting import BacktestingConfig + +# Pagination models +from .pagination import PaginatedResponse, PaginationParams, TimeRangePaginationParams + +# Connector models +from .connectors import ( + ConnectorInfo, + ConnectorConfigMapResponse, + TradingRule, + ConnectorTradingRulesResponse, + ConnectorOrderTypesResponse, + ConnectorListResponse, +) + +# Portfolio models +from .portfolio import ( + TokenBalance, + ConnectorBalances, + AccountPortfolioState, + PortfolioStateResponse, + TokenDistribution, + PortfolioDistributionResponse, + AccountDistribution, + AccountsDistributionResponse, + HistoricalPortfolioState, + PortfolioHistoryFilters, +) + +# Archived bots models +from .archived_bots import ( + OrderStatus, + DatabaseStatus, + BotSummary, + PerformanceMetrics, + TradeDetail, + OrderDetail, + ExecutorInfo, + ArchivedBotListResponse, + BotPerformanceResponse, + TradeHistoryResponse, + OrderHistoryResponse, + ExecutorsResponse, +) + +__all__ = [ + # Bot orchestration models + "BotAction", + "StartBotAction", + "StopBotAction", + "ImportStrategyAction", + "ConfigureBotAction", + "ShortcutAction", + "BotStatus", + "BotHistoryRequest", + "BotHistoryResponse", + "MQTTStatus", + "AllBotsStatusResponse", + "StopAndArchiveRequest", + "StopAndArchiveResponse", + "V2ScriptDeployment", + "V2ControllerDeployment", + # Trading models + "TradeRequest", + "TradeResponse", + "TokenInfo", + "ConnectorBalance", + "AccountBalance", + "PortfolioState", + "OrderInfo", + "ActiveOrdersResponse", + "OrderSummary", + "TradeInfo", + "TradingRulesInfo", + "OrderTypesResponse", + "OrderFilterRequest", + "ActiveOrderFilterRequest", + "PositionFilterRequest", + "FundingPaymentFilterRequest", + "TradeFilterRequest", 
+ # Controller models + "ControllerType", + "Controller", + "ControllerResponse", + "ControllerConfig", + "ControllerConfigResponse", + # Script models + "Script", + "ScriptResponse", + "ScriptConfig", + "ScriptConfigResponse", + # Market data models + "CandleData", + "CandlesResponse", + "ActiveFeedInfo", + "ActiveFeedsResponse", + "MarketDataSettings", + "TradingRulesResponse", + "SupportedOrderTypesResponse", + # New enhanced market data models + "PriceRequest", + "PriceData", + "PricesResponse", + "FundingInfoRequest", + "FundingInfoResponse", + "OrderBookRequest", + "OrderBookLevel", + "OrderBookResponse", + "OrderBookQueryRequest", + "VolumeForPriceRequest", + "PriceForVolumeRequest", + "QuoteVolumeForPriceRequest", + "PriceForQuoteVolumeRequest", + "VWAPForVolumeRequest", + "OrderBookQueryResult", + # Account models + "LeverageRequest", + "PositionModeRequest", + "CredentialRequest", + # Docker models + "DockerImage", + # Backtesting models + "BacktestingConfig", + # Pagination models + "PaginatedResponse", + "PaginationParams", + "TimeRangePaginationParams", + # Connector models + "ConnectorInfo", + "ConnectorConfigMapResponse", + "TradingRule", + "ConnectorTradingRulesResponse", + "ConnectorOrderTypesResponse", + "ConnectorListResponse", + # Portfolio models + "TokenBalance", + "ConnectorBalances", + "AccountPortfolioState", + "PortfolioStateResponse", + "TokenDistribution", + "PortfolioDistributionResponse", + "AccountDistribution", + "AccountsDistributionResponse", + "HistoricalPortfolioState", + "PortfolioHistoryFilters", + # Archived bots models + "OrderStatus", + "DatabaseStatus", + "BotSummary", + "PerformanceMetrics", + "TradeDetail", + "OrderDetail", + "ExecutorInfo", + "ArchivedBotListResponse", + "BotPerformanceResponse", + "TradeHistoryResponse", + "OrderHistoryResponse", + "ExecutorsResponse", +] \ No newline at end of file diff --git a/models/accounts.py b/models/accounts.py new file mode 100644 index 00000000..4f7b8c2a --- /dev/null +++ b/models/accounts.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel, Field +from typing import Dict, Any + + +class LeverageRequest(BaseModel): + """Request model for setting leverage on perpetual connectors""" + trading_pair: str = Field(description="Trading pair (e.g., BTC-USDT)") + leverage: int = Field(description="Leverage value (typically 1-125)", ge=1, le=125) + + +class PositionModeRequest(BaseModel): + """Request model for setting position mode on perpetual connectors""" + position_mode: str = Field(description="Position mode (HEDGE or ONEWAY)") + + +class CredentialRequest(BaseModel): + """Request model for adding connector credentials""" + credentials: Dict[str, Any] = Field(description="Connector credentials dictionary") \ No newline at end of file diff --git a/models/archived_bots.py b/models/archived_bots.py new file mode 100644 index 00000000..cea57962 --- /dev/null +++ b/models/archived_bots.py @@ -0,0 +1,134 @@ +""" +Pydantic models for the archived bots router. + +These models define the request/response schemas for archived bot analysis endpoints. 
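+
+Example (illustrative only; the numbers below are hypothetical):
+
+    from models.archived_bots import PerformanceMetrics
+
+    metrics = PerformanceMetrics(total_pnl=12.5, total_volume=10_000.0,
+                                 avg_return=0.1, win_rate=55.0, total_trades=42)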
+""" + +from typing import Dict, List, Optional, Any +from datetime import datetime +from pydantic import BaseModel, Field +from enum import Enum + + +class OrderStatus(str, Enum): + """Order status enumeration""" + OPEN = "OPEN" + FILLED = "FILLED" + CANCELLED = "CANCELLED" + FAILED = "FAILED" + + +class DatabaseStatus(BaseModel): + """Database status information""" + db_path: str = Field(description="Path to the database file") + status: Dict[str, Any] = Field(description="Database health status") + healthy: bool = Field(description="Whether the database is healthy") + + +class BotSummary(BaseModel): + """Summary information for an archived bot""" + bot_name: str = Field(description="Name of the bot") + start_time: Optional[datetime] = Field(default=None, description="Bot start time") + end_time: Optional[datetime] = Field(default=None, description="Bot end time") + total_trades: int = Field(default=0, description="Total number of trades") + total_orders: int = Field(default=0, description="Total number of orders") + markets: List[str] = Field(default_factory=list, description="List of traded markets") + strategies: List[str] = Field(default_factory=list, description="List of strategies used") + + +class PerformanceMetrics(BaseModel): + """Performance metrics for an archived bot""" + total_pnl: float = Field(description="Total profit and loss") + total_volume: float = Field(description="Total trading volume") + avg_return: float = Field(description="Average return per trade") + win_rate: float = Field(description="Percentage of winning trades") + sharpe_ratio: Optional[float] = Field(default=None, description="Sharpe ratio") + max_drawdown: Optional[float] = Field(default=None, description="Maximum drawdown") + total_trades: int = Field(description="Total number of trades") + + +class TradeDetail(BaseModel): + """Detailed trade information""" + id: Optional[int] = Field(default=None, description="Trade ID") + config_file_path: str = Field(description="Configuration file path") + strategy: str = Field(description="Strategy name") + connector_name: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + base_asset: str = Field(description="Base asset") + quote_asset: str = Field(description="Quote asset") + timestamp: datetime = Field(description="Trade timestamp") + order_id: str = Field(description="Order ID") + trade_type: str = Field(description="Trade type (BUY/SELL)") + price: float = Field(description="Trade price") + amount: float = Field(description="Trade amount") + trade_fee: Dict[str, float] = Field(description="Trade fees") + exchange_trade_id: str = Field(description="Exchange trade ID") + leverage: Optional[int] = Field(default=None, description="Leverage used") + position: Optional[str] = Field(default=None, description="Position type") + + +class OrderDetail(BaseModel): + """Detailed order information""" + id: Optional[int] = Field(default=None, description="Order ID") + client_order_id: str = Field(description="Client order ID") + exchange_order_id: Optional[str] = Field(default=None, description="Exchange order ID") + trading_pair: str = Field(description="Trading pair") + status: OrderStatus = Field(description="Order status") + order_type: str = Field(description="Order type") + amount: float = Field(description="Order amount") + price: Optional[float] = Field(default=None, description="Order price") + creation_timestamp: datetime = Field(description="Order creation time") + last_update_timestamp: Optional[datetime] = 
Field(default=None, description="Last update time") + filled_amount: Optional[float] = Field(default=None, description="Filled amount") + leverage: Optional[int] = Field(default=None, description="Leverage used") + position: Optional[str] = Field(default=None, description="Position type") + + +class ExecutorInfo(BaseModel): + """Executor information""" + controller_id: str = Field(description="Controller ID") + timestamp: datetime = Field(description="Timestamp") + type: str = Field(description="Executor type") + controller_config: Dict[str, Any] = Field(description="Controller configuration") + net_pnl_flat: float = Field(description="Net PnL in flat terms") + net_pnl_pct: float = Field(description="Net PnL percentage") + total_executors: int = Field(description="Total number of executors") + total_amount: float = Field(description="Total amount") + total_spent: float = Field(description="Total spent") + + +class ArchivedBotListResponse(BaseModel): + """Response for listing archived bots""" + bots: List[str] = Field(description="List of archived bot database paths") + count: int = Field(description="Total number of archived bots") + + +class BotPerformanceResponse(BaseModel): + """Response for bot performance analysis""" + bot_name: str = Field(description="Bot name") + metrics: PerformanceMetrics = Field(description="Performance metrics") + period_start: Optional[datetime] = Field(default=None, description="Analysis period start") + period_end: Optional[datetime] = Field(default=None, description="Analysis period end") + + +class TradeHistoryResponse(BaseModel): + """Response for trade history""" + trades: List[TradeDetail] = Field(description="List of trades") + total: int = Field(description="Total number of trades") + page: int = Field(description="Current page") + page_size: int = Field(description="Page size") + + +class OrderHistoryResponse(BaseModel): + """Response for order history""" + orders: List[OrderDetail] = Field(description="List of orders") + total: int = Field(description="Total number of orders") + page: int = Field(description="Current page") + page_size: int = Field(description="Page size") + filtered_by_status: Optional[OrderStatus] = Field(default=None, description="Status filter applied") + + +class ExecutorsResponse(BaseModel): + """Response for executors information""" + executors: List[ExecutorInfo] = Field(description="List of executors") + total: int = Field(description="Total number of executors") \ No newline at end of file diff --git a/models/backtesting.py b/models/backtesting.py new file mode 100644 index 00000000..c3cb5bd5 --- /dev/null +++ b/models/backtesting.py @@ -0,0 +1,10 @@ +from typing import Dict, Union +from pydantic import BaseModel + + +class BacktestingConfig(BaseModel): + start_time: int = 1735689600 # 2025-01-01 00:00:00 + end_time: int = 1738368000 # 2025-02-01 00:00:00 + backtesting_resolution: str = "1m" + trade_cost: float = 0.0006 + config: Union[Dict, str] \ No newline at end of file diff --git a/models/bot_orchestration.py b/models/bot_orchestration.py new file mode 100644 index 00000000..90a18323 --- /dev/null +++ b/models/bot_orchestration.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, Optional, List +from pydantic import BaseModel, Field +from enum import Enum + + +class BotAction(BaseModel): + """Base class for bot actions""" + bot_name: str = Field(description="Name of the bot instance to act upon") + + +class StartBotAction(BotAction): + """Action to start a bot""" + log_level: Optional[str] = Field(default=None, 
description="Logging level (DEBUG, INFO, WARNING, ERROR)") + script: Optional[str] = Field(default=None, description="Script name to run (without .py extension)") + conf: Optional[str] = Field(default=None, description="Configuration file name (without .yml extension)") + async_backend: bool = Field(default=False, description="Whether to run in async backend mode") + + +class StopBotAction(BotAction): + """Action to stop a bot""" + skip_order_cancellation: bool = Field(default=False, description="Whether to skip cancelling open orders when stopping") + async_backend: bool = Field(default=False, description="Whether to run in async backend mode") + + +class ImportStrategyAction(BotAction): + """Action to import a strategy for a bot""" + strategy: str = Field(description="Name of the strategy to import") + + +class ConfigureBotAction(BotAction): + """Action to configure bot parameters""" + params: dict = Field(description="Configuration parameters to update") + + +class ShortcutAction(BotAction): + """Action to execute bot shortcuts""" + params: list = Field(description="List of shortcut parameters") + + +class BotStatus(BaseModel): + """Status information for a bot""" + bot_name: str = Field(description="Bot name") + status: str = Field(description="Bot status (running, stopped, etc.)") + uptime: Optional[float] = Field(None, description="Bot uptime in seconds") + performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics") + + +class BotHistoryRequest(BaseModel): + """Request for bot trading history""" + bot_name: str = Field(description="Bot name") + days: int = Field(default=0, description="Number of days of history (0 for all)") + verbose: bool = Field(default=False, description="Include verbose information") + precision: Optional[int] = Field(None, description="Decimal precision for numbers") + timeout: float = Field(default=30.0, description="Request timeout in seconds") + + +class BotHistoryResponse(BaseModel): + """Response for bot trading history""" + bot_name: str = Field(description="Bot name") + history: Dict[str, Any] = Field(description="Trading history data") + status: str = Field(description="Response status") + + +class MQTTStatus(BaseModel): + """MQTT connection status""" + mqtt_connected: bool = Field(description="Whether MQTT is connected") + discovered_bots: List[str] = Field(description="List of discovered bots") + active_bots: List[str] = Field(description="List of active bots") + broker_host: str = Field(description="MQTT broker host") + broker_port: int = Field(description="MQTT broker port") + broker_username: Optional[str] = Field(None, description="MQTT broker username") + client_state: str = Field(description="MQTT client state") + + +class AllBotsStatusResponse(BaseModel): + """Response for all bots status""" + bots: List[BotStatus] = Field(description="List of bot statuses") + + +class StopAndArchiveRequest(BaseModel): + """Request for stopping and archiving a bot""" + skip_order_cancellation: bool = Field(default=True, description="Skip order cancellation") + async_backend: bool = Field(default=True, description="Use async backend") + archive_locally: bool = Field(default=True, description="Archive locally") + s3_bucket: Optional[str] = Field(None, description="S3 bucket for archiving") + timeout: float = Field(default=30.0, description="Operation timeout") + + +class StopAndArchiveResponse(BaseModel): + """Response for stop and archive operation""" + status: str = Field(description="Operation status") + message: str = 
Field(description="Status message") + details: Dict[str, Any] = Field(description="Operation details") + + +# Bot deployment models +class V2ScriptDeployment(BaseModel): + """Configuration for deploying a bot with a script""" + instance_name: str = Field(description="Unique name for the bot instance") + credentials_profile: str = Field(description="Name of the credentials profile to use") + image: str = Field(default="hummingbot/hummingbot:latest", description="Docker image for the Hummingbot instance") + script: Optional[str] = Field(default=None, description="Name of the script to run (without .py extension)") + script_config: Optional[str] = Field(default=None, description="Name of the script configuration file (without .yml extension)") + headless: bool = Field(default=False, description="Run in headless mode (no UI)") + + +class V2ControllerDeployment(BaseModel): + """Configuration for deploying a bot with controllers""" + instance_name: str = Field(description="Unique name for the bot instance") + credentials_profile: str = Field(description="Name of the credentials profile to use") + controllers_config: List[str] = Field(description="List of controller configuration files to use (without .yml extension)") + max_global_drawdown_quote: Optional[float] = Field(default=None, description="Maximum allowed global drawdown in quote usually USDT") + max_controller_drawdown_quote: Optional[float] = Field(default=None, description="Maximum allowed per-controller drawdown in quote usually USDT") + image: str = Field(default="hummingbot/hummingbot:latest", description="Docker image for the Hummingbot instance") + headless: bool = Field(default=False, description="Run in headless mode (no UI)") \ No newline at end of file diff --git a/models/connectors.py b/models/connectors.py new file mode 100644 index 00000000..eea431fe --- /dev/null +++ b/models/connectors.py @@ -0,0 +1,56 @@ +""" +Pydantic models for the connectors router. + +These models define the request/response schemas for connector-related endpoints. 
+""" + +from typing import Dict, List, Any, Optional +from pydantic import BaseModel, Field + + +class ConnectorInfo(BaseModel): + """Information about a connector""" + name: str = Field(description="Connector name") + is_perpetual: bool = Field(default=False, description="Whether the connector supports perpetual trading") + supported_order_types: Optional[List[str]] = Field(default=None, description="Supported order types") + + +class ConnectorConfigMapResponse(BaseModel): + """Response for connector configuration requirements""" + connector_name: str = Field(description="Name of the connector") + config_fields: List[str] = Field(description="List of required configuration fields") + + +class TradingRule(BaseModel): + """Trading rules for a specific trading pair""" + min_order_size: float = Field(description="Minimum order size") + max_order_size: float = Field(description="Maximum order size") + min_price_increment: float = Field(description="Minimum price increment") + min_base_amount_increment: float = Field(description="Minimum base amount increment") + min_quote_amount_increment: float = Field(description="Minimum quote amount increment") + min_notional_size: float = Field(description="Minimum notional size") + min_order_value: float = Field(description="Minimum order value") + max_price_significant_digits: float = Field(description="Maximum price significant digits") + supports_limit_orders: bool = Field(description="Whether limit orders are supported") + supports_market_orders: bool = Field(description="Whether market orders are supported") + buy_order_collateral_token: str = Field(description="Collateral token for buy orders") + sell_order_collateral_token: str = Field(description="Collateral token for sell orders") + + +class ConnectorTradingRulesResponse(BaseModel): + """Response for connector trading rules""" + connector: str = Field(description="Connector name") + trading_pairs: Optional[List[str]] = Field(default=None, description="Filtered trading pairs if provided") + rules: Dict[str, TradingRule] = Field(description="Trading rules by trading pair") + + +class ConnectorOrderTypesResponse(BaseModel): + """Response for supported order types""" + connector: str = Field(description="Connector name") + supported_order_types: List[str] = Field(description="List of supported order types") + + +class ConnectorListResponse(BaseModel): + """Response for list of available connectors""" + connectors: List[str] = Field(description="List of available connector names") + count: int = Field(description="Total number of connectors") \ No newline at end of file diff --git a/models/controllers.py b/models/controllers.py new file mode 100644 index 00000000..a2eeff82 --- /dev/null +++ b/models/controllers.py @@ -0,0 +1,52 @@ +from typing import Dict, List, Optional, Any +from pydantic import BaseModel, Field +from enum import Enum + + +class ControllerType(str, Enum): + """Types of controllers available""" + DIRECTIONAL_TRADING = "directional_trading" + MARKET_MAKING = "market_making" + GENERIC = "generic" + + +# Controller file operations +class Controller(BaseModel): + """Controller file content""" + content: str = Field(description="Controller source code") + type: Optional[ControllerType] = Field(None, description="Controller type (optional for flexibility)") + + +class ControllerResponse(BaseModel): + """Response for getting a controller""" + name: str = Field(description="Controller name") + type: str = Field(description="Controller type") + content: str = Field(description="Controller 
source code") + + +# Controller configuration operations +class ControllerConfig(BaseModel): + """Controller configuration""" + controller_name: str = Field(description="Controller name") + controller_type: str = Field(description="Controller type") + connector_name: Optional[str] = Field(None, description="Connector name") + trading_pair: Optional[str] = Field(None, description="Trading pair") + total_amount_quote: Optional[float] = Field(None, description="Total amount in quote currency") + + +class ControllerConfigResponse(BaseModel): + """Response for controller configuration with metadata""" + config_name: str = Field(description="Configuration name") + controller_name: str = Field(description="Controller name") + controller_type: str = Field(description="Controller type") + connector_name: Optional[str] = Field(None, description="Connector name") + trading_pair: Optional[str] = Field(None, description="Trading pair") + total_amount_quote: Optional[float] = Field(None, description="Total amount in quote currency") + error: Optional[str] = Field(None, description="Error message if config is malformed") + + +# Bot-specific controller configurations +class BotControllerConfig(BaseModel): + """Controller configuration for a specific bot""" + config_name: str = Field(description="Configuration name") + config_data: Dict[str, Any] = Field(description="Configuration data") \ No newline at end of file diff --git a/models/docker.py b/models/docker.py new file mode 100644 index 00000000..b18fb768 --- /dev/null +++ b/models/docker.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel, Field + + +class DockerImage(BaseModel): + image_name: str = Field(description="Docker image name with optional tag (e.g., 'hummingbot/hummingbot:latest')") \ No newline at end of file diff --git a/models/market_data.py b/models/market_data.py new file mode 100644 index 00000000..1f618ec2 --- /dev/null +++ b/models/market_data.py @@ -0,0 +1,168 @@ +from datetime import datetime +from decimal import Decimal +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class CandleData(BaseModel): + """Single candle data point""" + timestamp: datetime = Field(description="Candle timestamp") + open: float = Field(description="Opening price") + high: float = Field(description="Highest price") + low: float = Field(description="Lowest price") + close: float = Field(description="Closing price") + volume: float = Field(description="Trading volume") + +class CandlesConfigRequest(BaseModel): + """ + The CandlesConfig class is a data class that stores the configuration of a Candle object. 
+ It has the following attributes: + - connector: str + - trading_pair: str + - interval: str + - max_records: int + """ + connector_name: str + trading_pair: str + interval: str = "1m" + max_records: int = 500 + +class CandlesResponse(BaseModel): + """Response for candles data""" + candles: List[CandleData] = Field(description="List of candle data") + + +class ActiveFeedInfo(BaseModel): + """Information about an active market data feed""" + connector: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + interval: str = Field(description="Candle interval") + last_access: datetime = Field(description="Last access time") + expires_at: datetime = Field(description="Expiration time") + + +class ActiveFeedsResponse(BaseModel): + """Response for active market data feeds""" + feeds: List[ActiveFeedInfo] = Field(description="List of active feeds") + + +class MarketDataSettings(BaseModel): + """Market data configuration settings""" + cleanup_interval: int = Field(description="Cleanup interval in seconds") + feed_timeout: int = Field(description="Feed timeout in seconds") + description: str = Field(description="Settings description") + + +class TradingRulesResponse(BaseModel): + """Response for trading rules""" + trading_pairs: Dict[str, Dict[str, Any]] = Field(description="Trading rules by pair") + + +class SupportedOrderTypesResponse(BaseModel): + """Response for supported order types""" + connector: str = Field(description="Connector name") + supported_order_types: List[str] = Field(description="List of supported order types") + + +# New models for enhanced market data functionality + +class PriceRequest(BaseModel): + """Request model for getting prices""" + connector_name: str = Field(description="Name of the connector") + trading_pairs: List[str] = Field(description="List of trading pairs to get prices for") + + +class PriceData(BaseModel): + """Price data for a trading pair""" + trading_pair: str = Field(description="Trading pair") + price: float = Field(description="Current price") + timestamp: float = Field(description="Price timestamp") + + +class PricesResponse(BaseModel): + """Response for prices data""" + connector: str = Field(description="Connector name") + prices: Dict[str, float] = Field(description="Trading pair to price mapping") + timestamp: float = Field(description="Response timestamp") + + +class FundingInfoRequest(BaseModel): + """Request model for getting funding info""" + connector_name: str = Field(description="Name of the connector") + trading_pair: str = Field(description="Trading pair to get funding info for") + + +class FundingInfoResponse(BaseModel): + """Response for funding info""" + trading_pair: str = Field(description="Trading pair") + funding_rate: Optional[float] = Field(description="Current funding rate") + next_funding_time: Optional[float] = Field(description="Next funding time timestamp") + mark_price: Optional[float] = Field(description="Mark price") + index_price: Optional[float] = Field(description="Index price") + + +class OrderBookRequest(BaseModel): + """Request model for getting order book data""" + connector_name: str = Field(description="Name of the connector") + trading_pair: str = Field(description="Trading pair") + depth: int = Field(default=10, ge=1, le=1000, description="Number of price levels to return") + + +class OrderBookLevel(BaseModel): + """Single order book level""" + price: float = Field(description="Price level") + amount: float = Field(description="Amount at this price level") + + +class 
OrderBookResponse(BaseModel): + """Response for order book data""" + trading_pair: str = Field(description="Trading pair") + bids: List[OrderBookLevel] = Field(description="Bid levels (highest to lowest)") + asks: List[OrderBookLevel] = Field(description="Ask levels (lowest to highest)") + timestamp: float = Field(description="Snapshot timestamp") + + +class OrderBookQueryRequest(BaseModel): + """Request model for order book queries""" + connector_name: str = Field(description="Name of the connector") + trading_pair: str = Field(description="Trading pair") + is_buy: bool = Field(description="True for buy side, False for sell side") + + +class VolumeForPriceRequest(OrderBookQueryRequest): + """Request model for getting volume at a specific price""" + price: float = Field(description="Price to query volume for") + + +class PriceForVolumeRequest(OrderBookQueryRequest): + """Request model for getting price for a specific volume""" + volume: float = Field(description="Volume to query price for") + + +class QuoteVolumeForPriceRequest(OrderBookQueryRequest): + """Request model for getting quote volume at a specific price""" + price: float = Field(description="Price to query quote volume for") + + +class PriceForQuoteVolumeRequest(OrderBookQueryRequest): + """Request model for getting price for a specific quote volume""" + quote_volume: float = Field(description="Quote volume to query price for") + + +class VWAPForVolumeRequest(OrderBookQueryRequest): + """Request model for getting VWAP for a specific volume""" + volume: float = Field(description="Volume to calculate VWAP for") + + +class OrderBookQueryResult(BaseModel): + """Response for order book query operations""" + trading_pair: str = Field(description="Trading pair") + is_buy: bool = Field(description="Query side (buy/sell)") + query_volume: Optional[float] = Field(default=None, description="Queried volume") + query_price: Optional[float] = Field(default=None, description="Queried price") + result_price: Optional[float] = Field(default=None, description="Resulting price") + result_volume: Optional[float] = Field(default=None, description="Resulting volume") + result_quote_volume: Optional[float] = Field(default=None, description="Resulting quote volume") + average_price: Optional[float] = Field(default=None, description="Average/VWAP price") + timestamp: float = Field(description="Query timestamp") \ No newline at end of file diff --git a/models/pagination.py b/models/pagination.py new file mode 100644 index 00000000..32309218 --- /dev/null +++ b/models/pagination.py @@ -0,0 +1,37 @@ +from datetime import datetime +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field, ConfigDict + + +class PaginationParams(BaseModel): + """Common pagination parameters.""" + limit: int = Field(default=100, ge=1, le=1000, description="Number of items per page") + cursor: Optional[str] = Field(None, description="Cursor for next page") + + +class TimeRangePaginationParams(BaseModel): + """Time-based pagination parameters for trading endpoints using integer timestamps.""" + limit: int = Field(default=100, ge=1, le=1000, description="Number of items per page") + start_time: Optional[int] = Field(None, description="Start time as Unix timestamp in milliseconds") + end_time: Optional[int] = Field(None, description="End time as Unix timestamp in milliseconds") + cursor: Optional[str] = Field(None, description="Cursor for next page") + + +class PaginatedResponse(BaseModel): + """Generic paginated response.""" + model_config = ConfigDict( 
+ json_schema_extra={ + "example": { + "data": [], + "pagination": { + "limit": 100, + "has_more": True, + "next_cursor": "2024-01-10T12:00:00", + "total_count": 500 + } + } + } + ) + + data: List[Dict[str, Any]] + pagination: Dict[str, Any] \ No newline at end of file diff --git a/models/portfolio.py b/models/portfolio.py new file mode 100644 index 00000000..d2d179fc --- /dev/null +++ b/models/portfolio.py @@ -0,0 +1,97 @@ +""" +Pydantic models for the portfolio router. + +These models define the request/response schemas for portfolio-related endpoints. +""" + +from typing import Dict, List, Optional, Any +from datetime import datetime +from pydantic import BaseModel, Field + + +class TokenBalance(BaseModel): + """Token balance information""" + token: str = Field(description="Token symbol") + units: float = Field(description="Number of units held") + price: float = Field(description="Current price per unit") + value: float = Field(description="Total value (units * price)") + available_units: float = Field(description="Available units (not locked in orders)") + + +class ConnectorBalances(BaseModel): + """Balances for a specific connector""" + connector_name: str = Field(description="Name of the connector") + balances: List[TokenBalance] = Field(description="List of token balances") + total_value: float = Field(description="Total value across all tokens") + + +class AccountPortfolioState(BaseModel): + """Portfolio state for a single account""" + account_name: str = Field(description="Name of the account") + connectors: Dict[str, List[TokenBalance]] = Field(description="Balances by connector") + total_value: float = Field(description="Total account value across all connectors") + last_updated: Optional[datetime] = Field(default=None, description="Last update timestamp") + + +class PortfolioStateResponse(BaseModel): + """Response for portfolio state endpoint""" + accounts: Dict[str, Dict[str, List[Dict[str, Any]]]] = Field( + description="Portfolio state by account and connector" + ) + total_portfolio_value: Optional[float] = Field(default=None, description="Total value across all accounts") + timestamp: datetime = Field(default_factory=datetime.utcnow, description="Response timestamp") + + +class TokenDistribution(BaseModel): + """Token distribution information""" + token: str = Field(description="Token symbol") + total_value: float = Field(description="Total value of this token") + total_units: float = Field(description="Total units of this token") + percentage: float = Field(description="Percentage of total portfolio") + accounts: Dict[str, Dict[str, Any]] = Field( + description="Breakdown by account and connector" + ) + + +class PortfolioDistributionResponse(BaseModel): + """Response for portfolio distribution endpoint""" + total_portfolio_value: float = Field(description="Total portfolio value") + token_count: int = Field(description="Number of unique tokens") + distribution: List[TokenDistribution] = Field(description="Token distribution list") + account_filter: str = Field( + default="all_accounts", + description="Applied account filter (all_accounts or specific accounts)" + ) + + +class AccountDistribution(BaseModel): + """Account distribution information""" + account: str = Field(description="Account name") + total_value: float = Field(description="Total value in this account") + percentage: float = Field(description="Percentage of total portfolio") + connectors: Dict[str, Dict[str, float]] = Field( + description="Value breakdown by connector" + ) + + +class 
AccountsDistributionResponse(BaseModel): + """Response for accounts distribution endpoint""" + total_portfolio_value: float = Field(description="Total portfolio value") + account_count: int = Field(description="Number of accounts") + distribution: List[AccountDistribution] = Field(description="Account distribution list") + + +class HistoricalPortfolioState(BaseModel): + """Historical portfolio state entry""" + timestamp: datetime = Field(description="State timestamp") + state: Dict[str, Dict[str, List[Dict[str, Any]]]] = Field( + description="Portfolio state snapshot" + ) + total_value: Optional[float] = Field(default=None, description="Total value at this point") + + +class PortfolioHistoryFilters(BaseModel): + """Filters applied to portfolio history query""" + account_names: Optional[List[str]] = Field(default=None, description="Filtered account names") + start_time: Optional[datetime] = Field(default=None, description="Start time filter") + end_time: Optional[datetime] = Field(default=None, description="End time filter") \ No newline at end of file diff --git a/models/scripts.py b/models/scripts.py new file mode 100644 index 00000000..fd60b07d --- /dev/null +++ b/models/scripts.py @@ -0,0 +1,34 @@ +from typing import Dict, List, Optional, Any +from pydantic import BaseModel, Field + + +# Script file operations +class Script(BaseModel): + """Script file content""" + content: str = Field(description="Script source code") + + +class ScriptResponse(BaseModel): + """Response for getting a script""" + name: str = Field(description="Script name") + content: str = Field(description="Script source code") + + +# Script configuration operations +class ScriptConfig(BaseModel): + """Script configuration content""" + config_name: str = Field(description="Configuration name") + script_file_name: str = Field(description="Script file name") + controllers_config: List[str] = Field(default=[], description="List of controller configurations") + candles_config: List[Dict[str, Any]] = Field(default=[], description="Candles configuration") + markets: Dict[str, Any] = Field(default={}, description="Markets configuration") + + +class ScriptConfigResponse(BaseModel): + """Response for script configuration with metadata""" + config_name: str = Field(description="Configuration name") + script_file_name: str = Field(description="Script file name") + controllers_config: List[str] = Field(default=[], description="List of controller configurations") + candles_config: List[Dict[str, Any]] = Field(default=[], description="Candles configuration") + markets: Dict[str, Any] = Field(default={}, description="Markets configuration") + error: Optional[str] = Field(None, description="Error message if config is malformed") \ No newline at end of file diff --git a/models/trading.py b/models/trading.py new file mode 100644 index 00000000..a3449e3d --- /dev/null +++ b/models/trading.py @@ -0,0 +1,209 @@ +from typing import Dict, List, Optional, Any, Literal +from pydantic import BaseModel, Field, field_validator +from decimal import Decimal +from datetime import datetime +from hummingbot.core.data_type.common import OrderType, TradeType, PositionAction +from .pagination import PaginationParams, TimeRangePaginationParams + + +class TradeRequest(BaseModel): + """Request model for placing trades""" + account_name: str = Field(description="Name of the account to trade with") + connector_name: str = Field(description="Name of the connector/exchange") + trading_pair: str = Field(description="Trading pair (e.g., BTC-USDT)") + trade_type: 
Literal["BUY", "SELL"] = Field(description="Whether to buy or sell") + amount: Decimal = Field(description="Amount to trade", gt=0) + order_type: Literal["LIMIT", "MARKET", "LIMIT_MAKER"] = Field(default="LIMIT", description="Type of order") + price: Optional[Decimal] = Field(default=None, description="Price for limit orders") + position_action: Literal["OPEN", "CLOSE"] = Field(default="OPEN", description="Position action for perpetual contracts (OPEN/CLOSE)") + + @field_validator('trade_type') + @classmethod + def validate_trade_type(cls, v): + """Validate that trade_type is a valid TradeType enum name.""" + try: + return TradeType[v].name + except KeyError: + valid_types = [t.name for t in TradeType] + raise ValueError(f"Invalid trade_type '{v}'. Must be one of: {valid_types}") + + @field_validator('order_type') + @classmethod + def validate_order_type(cls, v): + """Validate that order_type is a valid OrderType enum name.""" + try: + return OrderType[v].name + except KeyError: + valid_types = [t.name for t in OrderType] + raise ValueError(f"Invalid order_type '{v}'. Must be one of: {valid_types}") + + @field_validator('position_action') + @classmethod + def validate_position_action(cls, v): + """Validate that position_action is a valid PositionAction enum name.""" + try: + return PositionAction[v].name + except KeyError: + valid_actions = [a.name for a in PositionAction] + raise ValueError(f"Invalid position_action '{v}'. Must be one of: {valid_actions}") + + +class TradeResponse(BaseModel): + """Response model for trade execution""" + order_id: str = Field(description="Client order ID assigned by the connector") + account_name: str = Field(description="Account used for the trade") + connector_name: str = Field(description="Connector used for the trade") + trading_pair: str = Field(description="Trading pair") + trade_type: str = Field(description="Trade type") + amount: Decimal = Field(description="Trade amount") + order_type: str = Field(description="Order type") + price: Optional[Decimal] = Field(description="Order price") + status: str = Field(default="submitted", description="Order status") + + +class TokenInfo(BaseModel): + """Information about a token balance""" + token: str = Field(description="Token symbol") + balance: Decimal = Field(description="Token balance") + value_usd: Optional[Decimal] = Field(None, description="USD value of the balance") + + +class ConnectorBalance(BaseModel): + """Balance information for a connector""" + connector_name: str = Field(description="Name of the connector") + tokens: List[TokenInfo] = Field(description="List of token balances") + + +class AccountBalance(BaseModel): + """Balance information for an account""" + account_name: str = Field(description="Name of the account") + connectors: List[ConnectorBalance] = Field(description="List of connector balances") + + +class PortfolioState(BaseModel): + """Complete portfolio state across all accounts""" + accounts: List[AccountBalance] = Field(description="List of account balances") + timestamp: datetime = Field(description="Timestamp of the portfolio state") + + +class OrderInfo(BaseModel): + """Information about an order""" + order_id: str = Field(description="Order ID") + client_order_id: str = Field(description="Client order ID") + account_name: str = Field(description="Account name") + connector_name: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + order_type: str = Field(description="Order type") + trade_type: str = Field(description="Trade type 
(BUY/SELL)") + amount: Decimal = Field(description="Order amount") + price: Optional[Decimal] = Field(description="Order price") + filled_amount: Decimal = Field(description="Filled amount") + status: str = Field(description="Order status") + creation_timestamp: datetime = Field(description="Order creation time") + last_update_timestamp: datetime = Field(description="Last update time") + + +class ActiveOrdersResponse(BaseModel): + """Response for active orders""" + orders: Dict[str, OrderInfo] = Field(description="Dictionary of active orders") + + +class OrderSummary(BaseModel): + """Summary statistics for orders""" + total_orders: int = Field(description="Total number of orders") + filled_orders: int = Field(description="Number of filled orders") + cancelled_orders: int = Field(description="Number of cancelled orders") + fill_rate: float = Field(description="Order fill rate percentage") + total_volume_base: Decimal = Field(description="Total volume in base currency") + total_volume_quote: Decimal = Field(description="Total volume in quote currency") + avg_fill_time: Optional[float] = Field(description="Average fill time in seconds") + + +class TradeInfo(BaseModel): + """Information about a trade fill""" + trade_id: str = Field(description="Trade ID") + order_id: str = Field(description="Associated order ID") + account_name: str = Field(description="Account name") + connector_name: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + trade_type: str = Field(description="Trade type (BUY/SELL)") + amount: Decimal = Field(description="Trade amount") + price: Decimal = Field(description="Trade price") + fee: Decimal = Field(description="Trade fee") + timestamp: datetime = Field(description="Trade timestamp") + + +class TradingRulesInfo(BaseModel): + """Trading rules for a trading pair""" + trading_pair: str = Field(description="Trading pair") + min_order_size: Decimal = Field(description="Minimum order size") + max_order_size: Optional[Decimal] = Field(description="Maximum order size") + min_price_increment: Decimal = Field(description="Minimum price increment") + min_base_amount_increment: Decimal = Field(description="Minimum base amount increment") + min_quote_amount_increment: Decimal = Field(description="Minimum quote amount increment") + + +class OrderTypesResponse(BaseModel): + """Response for supported order types""" + connector: str = Field(description="Connector name") + supported_order_types: List[str] = Field(description="List of supported order types") + + +class OrderFilterRequest(TimeRangePaginationParams): + """Request model for filtering orders with multiple criteria""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + trading_pairs: Optional[List[str]] = Field(default=None, description="List of trading pairs to filter by") + status: Optional[str] = Field(default=None, description="Order status filter") + + +class ActiveOrderFilterRequest(PaginationParams): + """Request model for filtering active orders""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + trading_pairs: Optional[List[str]] = Field(default=None, description="List of trading pairs to filter by") + + +class 
PositionFilterRequest(PaginationParams): + """Request model for filtering positions""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + + +class FundingPaymentFilterRequest(TimeRangePaginationParams): + """Request model for filtering funding payments""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + trading_pair: Optional[str] = Field(default=None, description="Filter by trading pair") + + +class TradeFilterRequest(TimeRangePaginationParams): + """Request model for filtering trades""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + trading_pairs: Optional[List[str]] = Field(default=None, description="List of trading pairs to filter by") + trade_types: Optional[List[str]] = Field(default=None, description="List of trade types to filter by (BUY/SELL)") + + +class PortfolioStateFilterRequest(BaseModel): + """Request model for filtering portfolio state""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + + +class PortfolioHistoryFilterRequest(TimeRangePaginationParams): + """Request model for filtering portfolio history""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + + +class PortfolioDistributionFilterRequest(BaseModel): + """Request model for filtering portfolio distribution""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + + +class AccountsDistributionFilterRequest(BaseModel): + """Request model for filtering accounts distribution""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") \ No newline at end of file diff --git a/routers/accounts.py b/routers/accounts.py new file mode 100644 index 00000000..200a2930 --- /dev/null +++ b/routers/accounts.py @@ -0,0 +1,137 @@ +from typing import Dict, List, Optional +from datetime import datetime + +from fastapi import APIRouter, HTTPException, Depends, Query +from starlette import status + +from services.accounts_service import AccountsService +from deps import get_accounts_service +from models import PaginatedResponse + +router = APIRouter(tags=["Accounts"], prefix="/accounts") + + +@router.get("/", response_model=List[str]) +async def list_accounts(accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Get a list of all account names in the system. 
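+
+    Example (an illustrative sketch; assumes the API is reachable at
+    http://localhost:8000 and omits any authentication your deployment requires):
+
+        import requests
+
+        resp = requests.get("http://localhost:8000/accounts/")
+        account_names = resp.json()  # e.g. ["master_account"]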
+ + Returns: + List of account names + """ + return accounts_service.list_accounts() + + +@router.get("/{account_name}/credentials", response_model=List[str]) +async def list_account_credentials(account_name: str, + accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Get a list of all connectors that have credentials configured for a specific account. + + Args: + account_name: Name of the account to list credentials for + + Returns: + List of connector names that have credentials configured + + Raises: + HTTPException: 404 if account not found + """ + try: + credentials = accounts_service.list_credentials(account_name) + # Remove .yml extension from filenames + return [cred.replace('.yml', '') for cred in credentials] + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/add-account", status_code=status.HTTP_201_CREATED) +async def add_account(account_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Create a new account with default configuration files. + + Args: + account_name: Name of the new account to create + + Returns: + Success message when account is created + + Raises: + HTTPException: 400 if account already exists + """ + try: + accounts_service.add_account(account_name) + return {"message": "Account added successfully."} + except FileExistsError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.post("/delete-account") +async def delete_account(account_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Delete an account and all its associated credentials. + + Args: + account_name: Name of the account to delete + + Returns: + Success message when account is deleted + + Raises: + HTTPException: 400 if trying to delete master account, 404 if account not found + """ + try: + if account_name == "master_account": + raise HTTPException(status_code=400, detail="Cannot delete master account.") + await accounts_service.delete_account(account_name) + return {"message": "Account deleted successfully."} + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.post("/delete-credential/{account_name}/{connector_name}") +async def delete_credential(account_name: str, connector_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Delete a specific connector credential for an account. + + Args: + account_name: Name of the account + connector_name: Name of the connector to delete credentials for + + Returns: + Success message when credential is deleted + + Raises: + HTTPException: 404 if credential not found + """ + try: + await accounts_service.delete_credentials(account_name, connector_name) + return {"message": "Credential deleted successfully."} + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.post("/add-credential/{account_name}/{connector_name}", status_code=status.HTTP_201_CREATED) +async def add_credential(account_name: str, connector_name: str, credentials: Dict, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Add or update connector credentials (API keys) for a specific account and connector. 
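+
+    Example (illustrative; the exact credential keys required vary per
+    connector, and the field names below are assumptions):
+
+        import requests
+
+        requests.post(
+            "http://localhost:8000/accounts/add-credential/master_account/binance",
+            json={"binance_api_key": "...", "binance_api_secret": "..."},
+        )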
+ + Args: + account_name: Name of the account + connector_name: Name of the connector + credentials: Dictionary containing the connector credentials + + Returns: + Success message when credentials are added + + Raises: + HTTPException: 400 if there's an error adding the credentials + """ + try: + await accounts_service.add_credentials(account_name, connector_name, credentials) + return {"message": "Connector credentials added successfully."} + except Exception as e: + await accounts_service.delete_credentials(account_name, connector_name) + raise HTTPException(status_code=400, detail=str(e)) diff --git a/routers/archived_bots.py b/routers/archived_bots.py new file mode 100644 index 00000000..38b42d3b --- /dev/null +++ b/routers/archived_bots.py @@ -0,0 +1,297 @@ +from typing import List, Optional +from fastapi import APIRouter, HTTPException, Query + +from utils.file_system import fs_util +from utils.hummingbot_database_reader import HummingbotDatabase + +router = APIRouter(tags=["Archived Bots"], prefix="/archived-bots") + + +@router.get("/", response_model=List[str]) +async def list_databases(): + """ + List all available database files in the system. + + Returns: + List of database file paths + """ + return fs_util.list_databases() + + +@router.get("/{db_path:path}/status") +async def get_database_status(db_path: str): + """ + Get status information for a specific database. + + Args: + db_path: Path to the database file + + Returns: + Database status including table health + """ + try: + db = HummingbotDatabase(db_path) + return { + "db_path": db_path, + "status": db.status, + "healthy": db.status["general_status"] + } + except Exception as e: + raise HTTPException(status_code=404, detail=f"Database not found or error: {str(e)}") + + +@router.get("/{db_path:path}/summary") +async def get_database_summary(db_path: str): + """ + Get a summary of database contents including basic statistics. + + Args: + db_path: Full path to the database file + + Returns: + Summary statistics of the database contents + """ + try: + db = HummingbotDatabase(db_path) + + # Get basic counts + orders = db.get_orders() + trades = db.get_trade_fills() + executors = db.get_executors_data() + positions = db.get_positions() + controllers = db.get_controllers_data() + + return { + "db_path": db_path, + "total_orders": len(orders), + "total_trades": len(trades), + "total_executors": len(executors), + "total_positions": len(positions), + "total_controllers": len(controllers), + "trading_pairs": orders["trading_pair"].unique().tolist() if len(orders) > 0 else [], + "exchanges": orders["connector_name"].unique().tolist() if len(orders) > 0 else [], + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error analyzing database: {str(e)}") + + +@router.get("/{db_path:path}/performance") +async def get_database_performance(db_path: str): + """ + Get trade-based performance analysis for a bot database. 
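+
+    Example (illustrative; the db_path segment is a hypothetical placeholder
+    and must match an entry returned by GET /archived-bots/):
+
+        import requests
+
+        resp = requests.get(
+            "http://localhost:8000/archived-bots/bots/archived/my_bot.sqlite/performance"
+        )
+        summary = resp.json()["summary"]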
+ + Args: + db_path: Full path to the database file + + Returns: + Trade-based performance metrics with rolling calculations + """ + try: + db = HummingbotDatabase(db_path) + + # Use new trade-based performance calculation + performance_data = db.calculate_trade_based_performance() + + if len(performance_data) == 0: + return { + "db_path": db_path, + "error": "No trades found in database", + "performance_data": [] + } + + # Convert to records for JSON response + performance_records = performance_data.fillna(0).to_dict('records') + + # Calculate summary statistics + final_row = performance_data.iloc[-1] if len(performance_data) > 0 else {} + summary = { + "total_trades": len(performance_data), + "final_net_pnl_quote": float(final_row.get('net_pnl_quote', 0)), + "final_realized_pnl_quote": float(final_row.get('realized_trade_pnl_quote', 0)), + "final_unrealized_pnl_quote": float(final_row.get('unrealized_trade_pnl_quote', 0)), + "total_fees_quote": float(performance_data['fees_quote'].sum()), + "total_volume_quote": float(performance_data['cum_volume_quote'].iloc[-1] if len(performance_data) > 0 else 0), + "final_net_position": float(final_row.get('net_position', 0)), + "trading_pairs": performance_data['trading_pair'].unique().tolist(), + "connector_names": performance_data['connector_name'].unique().tolist() + } + + return { + "db_path": db_path, + "summary": summary, + "performance_data": performance_records + } + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error calculating performance: {str(e)}") + + +@router.get("/{db_path:path}/trades") +async def get_database_trades( + db_path: str, + limit: int = Query(default=100, description="Limit number of trades returned"), + offset: int = Query(default=0, description="Offset for pagination") +): + """ + Get trade history from a database. + + Args: + db_path: Full path to the database file + limit: Maximum number of trades to return + offset: Offset for pagination + + Returns: + List of trades with pagination info + """ + try: + db = HummingbotDatabase(db_path) + trades = db.get_trade_fills() + + # Apply pagination + total_trades = len(trades) + trades_page = trades.iloc[offset:offset + limit] + + return { + "db_path": db_path, + "trades": trades_page.fillna(0).to_dict('records'), + "pagination": { + "total": total_trades, + "limit": limit, + "offset": offset, + "has_more": offset + limit < total_trades + } + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching trades: {str(e)}") + + +@router.get("/{db_path:path}/orders") +async def get_database_orders( + db_path: str, + limit: int = Query(default=100, description="Limit number of orders returned"), + offset: int = Query(default=0, description="Offset for pagination"), + status: Optional[str] = Query(default=None, description="Filter by order status") +): + """ + Get order history from a database. 
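+
+    Example (illustrative; db_path is a hypothetical placeholder, and the
+    status value follows the OrderStatus enum, e.g. "FILLED"):
+
+        import requests
+
+        resp = requests.get(
+            "http://localhost:8000/archived-bots/bots/archived/my_bot.sqlite/orders",
+            params={"limit": 50, "offset": 0, "status": "FILLED"},
+        )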
+ + Args: + db_path: Full path to the database file + limit: Maximum number of orders to return + offset: Offset for pagination + status: Optional status filter + + Returns: + List of orders with pagination info + """ + try: + db = HummingbotDatabase(db_path) + orders = db.get_orders() + + # Apply status filter if provided + if status: + orders = orders[orders["last_status"] == status] + + # Apply pagination + total_orders = len(orders) + orders_page = orders.iloc[offset:offset + limit] + + return { + "db_path": db_path, + "orders": orders_page.fillna(0).to_dict('records'), + "pagination": { + "total": total_orders, + "limit": limit, + "offset": offset, + "has_more": offset + limit < total_orders + } + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching orders: {str(e)}") + + +@router.get("/{db_path:path}/executors") +async def get_database_executors(db_path: str): + """ + Get executor data from a database. + + Args: + db_path: Full path to the database file + + Returns: + List of executors with their configurations and results + """ + try: + db = HummingbotDatabase(db_path) + executors = db.get_executors_data() + + return { + "db_path": db_path, + "executors": executors.fillna(0).to_dict('records'), + "total": len(executors) + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching executors: {str(e)}") + + +@router.get("/{db_path:path}/positions") +async def get_database_positions( + db_path: str, + limit: int = Query(default=100, description="Limit number of positions returned"), + offset: int = Query(default=0, description="Offset for pagination") +): + """ + Get position data from a database. + + Args: + db_path: Full path to the database file + limit: Maximum number of positions to return + offset: Offset for pagination + + Returns: + List of positions with pagination info + """ + try: + db = HummingbotDatabase(db_path) + positions = db.get_positions() + + # Apply pagination + total_positions = len(positions) + positions_page = positions.iloc[offset:offset + limit] + + return { + "db_path": db_path, + "positions": positions_page.fillna(0).to_dict('records'), + "pagination": { + "total": total_positions, + "limit": limit, + "offset": offset, + "has_more": offset + limit < total_positions + } + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching positions: {str(e)}") + + +@router.get("/{db_path:path}/controllers") +async def get_database_controllers(db_path: str): + """ + Get controller data from a database. 
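+
+    Example response shape (mirroring the handler body below; values are
+    illustrative):
+
+        {"db_path": "...", "controllers": [...], "total": 2}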
+ + Args: + db_path: Full path to the database file + + Returns: + List of controllers that were running with their configurations + """ + try: + db = HummingbotDatabase(db_path) + controllers = db.get_controllers_data() + + return { + "db_path": db_path, + "controllers": controllers.fillna(0).to_dict('records'), + "total": len(controllers) + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching controllers: {str(e)}") diff --git a/routers/manage_backtesting.py b/routers/backtesting.py similarity index 68% rename from routers/manage_backtesting.py rename to routers/backtesting.py index 812f1fa4..3d68ee9b 100644 --- a/routers/manage_backtesting.py +++ b/routers/backtesting.py @@ -1,38 +1,41 @@ -from typing import Dict, Union - from fastapi import APIRouter from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory from hummingbot.strategy_v2.backtesting.backtesting_engine_base import BacktestingEngineBase -from pydantic import BaseModel -from config import CONTROLLERS_MODULE, CONTROLLERS_PATH +from config import settings +from models.backtesting import BacktestingConfig -router = APIRouter(tags=["Market Backtesting"]) +router = APIRouter(tags=["Backtesting"], prefix="/backtesting") candles_factory = CandlesFactory() backtesting_engine = BacktestingEngineBase() -class BacktestingConfig(BaseModel): - start_time: int = 1672542000 # 2023-01-01 00:00:00 - end_time: int = 1672628400 # 2023-01-01 23:59:00 - backtesting_resolution: str = "1m" - trade_cost: float = 0.0006 - config: Union[Dict, str] - - @router.post("/run-backtesting") async def run_backtesting(backtesting_config: BacktestingConfig): + """ + Run a backtesting simulation with the provided configuration. + + Args: + backtesting_config: Configuration for the backtesting including start/end time, + resolution, trade cost, and controller config + + Returns: + Dictionary containing executors, processed data, and results from the backtest + + Raises: + Returns error dictionary if backtesting fails + """ try: if isinstance(backtesting_config.config, str): controller_config = backtesting_engine.get_controller_config_instance_from_yml( config_path=backtesting_config.config, - controllers_conf_dir_path=CONTROLLERS_PATH, - controllers_module=CONTROLLERS_MODULE + controllers_conf_dir_path=settings.app.controllers_path, + controllers_module=settings.app.controllers_module ) else: controller_config = backtesting_engine.get_controller_config_instance_from_dict( config_data=backtesting_config.config, - controllers_module=CONTROLLERS_MODULE + controllers_module=settings.app.controllers_module ) backtesting_results = await backtesting_engine.run_backtesting( controller_config=controller_config, trade_cost=backtesting_config.trade_cost, diff --git a/routers/bot_orchestration.py b/routers/bot_orchestration.py new file mode 100644 index 00000000..ef6ab761 --- /dev/null +++ b/routers/bot_orchestration.py @@ -0,0 +1,716 @@ +import logging +import os +import asyncio +from datetime import datetime +from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks + +# Create module-specific logger +logger = logging.getLogger(__name__) + +from models import StartBotAction, StopBotAction, V2ScriptDeployment, V2ControllerDeployment +from services.bots_orchestrator import BotsOrchestrator +from services.docker_service import DockerService +from deps import get_bots_orchestrator, get_docker_service, get_bot_archiver, get_database_manager +from utils.file_system import fs_util +from utils.bot_archiver 
import BotArchiver +from database import AsyncDatabaseManager, BotRunRepository + +router = APIRouter(tags=["Bot Orchestration"], prefix="/bot-orchestration") + + +@router.get("/status") +def get_active_bots_status(bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator)): + """ + Get the status of all active bots. + + Args: + bots_manager: Bot orchestrator service dependency + + Returns: + Dictionary with status and data containing all active bot statuses + """ + return {"status": "success", "data": bots_manager.get_all_bots_status()} + + +@router.get("/mqtt") +def get_mqtt_status(bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator)): + """ + Get MQTT connection status and discovered bots. + + Args: + bots_manager: Bot orchestrator service dependency + + Returns: + Dictionary with MQTT connection status, discovered bots, and broker information + """ + mqtt_connected = bots_manager.mqtt_manager.is_connected + discovered_bots = bots_manager.mqtt_manager.get_discovered_bots() + active_bots = list(bots_manager.active_bots.keys()) + + # Check client state + client_state = "connected" if bots_manager.mqtt_manager.is_connected else "disconnected" + + return { + "status": "success", + "data": { + "mqtt_connected": mqtt_connected, + "discovered_bots": discovered_bots, + "active_bots": active_bots, + "broker_host": bots_manager.broker_host, + "broker_port": bots_manager.broker_port, + "broker_username": bots_manager.broker_username, + "client_state": client_state + } + } + + +@router.get("/{bot_name}/status") +def get_bot_status(bot_name: str, bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator)): + """ + Get the status of a specific bot. + + Args: + bot_name: Name of the bot to get status for + bots_manager: Bot orchestrator service dependency + + Returns: + Dictionary with bot status information + + Raises: + HTTPException: 404 if bot not found + """ + response = bots_manager.get_bot_status(bot_name) + if not response: + raise HTTPException(status_code=404, detail="Bot not found") + return { + "status": "success", + "data": response + } + + +@router.get("/{bot_name}/history") +async def get_bot_history( + bot_name: str, + days: int = 0, + verbose: bool = False, + precision: int = None, + timeout: float = 30.0, + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) +): + """ + Get trading history for a bot with optional parameters. + + Args: + bot_name: Name of the bot to get history for + days: Number of days of history to retrieve (0 for all) + verbose: Whether to include verbose output + precision: Decimal precision for numerical values + timeout: Timeout in seconds for the operation + bots_manager: Bot orchestrator service dependency + + Returns: + Dictionary with bot trading history + """ + response = await bots_manager.get_bot_history( + bot_name, + days=days, + verbose=verbose, + precision=precision, + timeout=timeout + ) + return {"status": "success", "response": response} + + +@router.post("/start-bot") +async def start_bot( + action: StartBotAction, + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Start a bot with the specified configuration. 
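+
+    Example request body (values are illustrative):
+        {"bot_name": "my-bot", "log_level": "INFO",
+         "script": "v2_with_controllers.py", "conf": "my_config.yml",
+         "async_backend": true}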
+ + Args: + action: StartBotAction containing bot configuration parameters + bots_manager: Bot orchestrator service dependency + db_manager: Database manager dependency + + Returns: + Dictionary with status and response from bot start operation + """ + response = await bots_manager.start_bot(action.bot_name, log_level=action.log_level, script=action.script, + conf=action.conf, async_backend=action.async_backend) + + # Bot run tracking simplified - only track deployment and stop times + + return {"status": "success", "response": response} + + +@router.post("/stop-bot") +async def stop_bot( + action: StopBotAction, + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Stop a bot with the specified configuration. + + Args: + action: StopBotAction containing bot stop parameters + bots_manager: Bot orchestrator service dependency + db_manager: Database manager dependency + + Returns: + Dictionary with status and response from bot stop operation + """ + response = await bots_manager.stop_bot(action.bot_name, skip_order_cancellation=action.skip_order_cancellation, + async_backend=action.async_backend) + + # Update bot run status to STOPPED if stop was successful + if response.get("success"): + try: + # Try to get bot status for final status data + final_status = bots_manager.get_bot_status(action.bot_name) + + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped( + action.bot_name, + final_status=final_status + ) + logger.info(f"Updated bot run status to STOPPED for {action.bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run status: {e}") + # Don't fail the stop operation if bot run update fails + + return {"status": "success", "response": response} + + +@router.get("/bot-runs") +async def get_bot_runs( + bot_name: str = None, + account_name: str = None, + strategy_type: str = None, + strategy_name: str = None, + run_status: str = None, + deployment_status: str = None, + limit: int = 100, + offset: int = 0, + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get bot runs with optional filtering. 
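+
+    For example, `GET /bot-orchestration/bot-runs?run_status=RUNNING&limit=20`
+    would return up to 20 currently running bot runs (filter values are
+    illustrative).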
+ + Args: + bot_name: Filter by bot name + account_name: Filter by account name + strategy_type: Filter by strategy type (script or controller) + strategy_name: Filter by strategy name + run_status: Filter by run status (CREATED, RUNNING, STOPPED, ERROR) + deployment_status: Filter by deployment status (DEPLOYED, FAILED, ARCHIVED) + limit: Maximum number of results to return + offset: Number of results to skip + db_manager: Database manager dependency + + Returns: + List of bot runs with their details + """ + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + bot_runs = await bot_run_repo.get_bot_runs( + bot_name=bot_name, + account_name=account_name, + strategy_type=strategy_type, + strategy_name=strategy_name, + run_status=run_status, + deployment_status=deployment_status, + limit=limit, + offset=offset + ) + + # Convert bot runs to dictionaries for JSON serialization + runs_data = [] + for run in bot_runs: + run_dict = { + "id": run.id, + "bot_name": run.bot_name, + "instance_name": run.instance_name, + "deployed_at": run.deployed_at.isoformat() if run.deployed_at else None, + "stopped_at": run.stopped_at.isoformat() if run.stopped_at else None, + "strategy_type": run.strategy_type, + "strategy_name": run.strategy_name, + "config_name": run.config_name, + "account_name": run.account_name, + "image_version": run.image_version, + "deployment_status": run.deployment_status, + "run_status": run.run_status, + "deployment_config": run.deployment_config, + "final_status": run.final_status, + "error_message": run.error_message + } + runs_data.append(run_dict) + + return { + "status": "success", + "data": runs_data, + "total": len(runs_data), + "limit": limit, + "offset": offset + } + except Exception as e: + logger.error(f"Failed to get bot runs: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/bot-runs/{bot_run_id}") +async def get_bot_run_by_id( + bot_run_id: int, + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get a specific bot run by ID. 
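+
+    Note: because FastAPI matches routes in declaration order, this
+    parameterized path will also capture requests to `/bot-runs/stats` (the
+    literal segment then fails integer validation); declaring the stats route
+    before this one would avoid that collision.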
+ + Args: + bot_run_id: ID of the bot run + db_manager: Database manager dependency + + Returns: + Bot run details + + Raises: + HTTPException: 404 if bot run not found + """ + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + bot_run = await bot_run_repo.get_bot_run_by_id(bot_run_id) + + if not bot_run: + raise HTTPException(status_code=404, detail=f"Bot run {bot_run_id} not found") + + run_dict = { + "id": bot_run.id, + "bot_name": bot_run.bot_name, + "instance_name": bot_run.instance_name, + "deployed_at": bot_run.deployed_at.isoformat() if bot_run.deployed_at else None, + "stopped_at": bot_run.stopped_at.isoformat() if bot_run.stopped_at else None, + "strategy_type": bot_run.strategy_type, + "strategy_name": bot_run.strategy_name, + "config_name": bot_run.config_name, + "account_name": bot_run.account_name, + "image_version": bot_run.image_version, + "deployment_status": bot_run.deployment_status, + "run_status": bot_run.run_status, + "deployment_config": bot_run.deployment_config, + "final_status": bot_run.final_status, + "error_message": bot_run.error_message + } + + return {"status": "success", "data": run_dict} + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get bot run {bot_run_id}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/bot-runs/stats") +async def get_bot_run_stats( + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get statistics about bot runs. + + Args: + db_manager: Database manager dependency + + Returns: + Bot run statistics + """ + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + stats = await bot_run_repo.get_bot_run_stats() + + return {"status": "success", "data": stats} + except Exception as e: + logger.error(f"Failed to get bot run stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +async def _background_stop_and_archive( + bot_name: str, + container_name: str, + bot_name_for_orchestrator: str, + skip_order_cancellation: bool, + archive_locally: bool, + s3_bucket: str, + bots_manager: BotsOrchestrator, + docker_manager: DockerService, + bot_archiver: BotArchiver, + db_manager: AsyncDatabaseManager +): + """Background task to handle the stop and archive process""" + try: + logger.info(f"Starting background stop-and-archive for {bot_name}") + + # Step 1: Capture bot final status before stopping (while bot is still running) + logger.info(f"Capturing final status for {bot_name_for_orchestrator}") + final_status = None + try: + final_status = bots_manager.get_bot_status(bot_name_for_orchestrator) + logger.info(f"Captured final status for {bot_name_for_orchestrator}: {final_status}") + except Exception as e: + logger.warning(f"Failed to capture final status for {bot_name_for_orchestrator}: {e}") + + # Step 2: Update bot run with stopped_at timestamp and final status before stopping + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped( + bot_name, + final_status=final_status + ) + logger.info(f"Updated bot run with stopped_at timestamp and final status for {bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run with stopped status: {e}") + # Continue with stop process even if database update fails + + # Step 3: Mark the bot as stopping, and stop the bot trading process + 
bots_manager.set_bot_stopping(bot_name_for_orchestrator)
+        logger.info(f"Stopping bot trading process for {bot_name_for_orchestrator}")
+        stop_response = await bots_manager.stop_bot(
+            bot_name_for_orchestrator,
+            skip_order_cancellation=skip_order_cancellation,
+            async_backend=True  # Always use async for background tasks
+        )
+
+        if not stop_response or not stop_response.get("success", False):
+            error_msg = stop_response.get('error', 'Unknown error') if stop_response else 'No response from bot orchestrator'
+            logger.error(f"Failed to stop bot process: {error_msg}")
+            return
+
+        # Step 4: Wait for graceful shutdown (15 seconds)
+        logger.info(f"Waiting 15 seconds for bot {bot_name} to gracefully shut down")
+        await asyncio.sleep(15)
+
+        # Step 5: Stop the container with monitoring
+        max_retries = 10
+        retry_interval = 2
+        container_stopped = False
+
+        for i in range(max_retries):
+            logger.info(f"Attempting to stop container {container_name} (attempt {i+1}/{max_retries})")
+            docker_manager.stop_container(container_name)
+
+            # Check whether the container has stopped
+            container_status = docker_manager.get_container_status(container_name)
+            if container_status.get("state", {}).get("status") == "exited":
+                container_stopped = True
+                logger.info(f"Container {container_name} has stopped")
+                break
+
+            await asyncio.sleep(retry_interval)
+
+        if not container_stopped:
+            logger.error(f"Failed to stop container {container_name} after {max_retries} attempts")
+            return
+
+        # Step 6: Archive the bot data
+        instance_dir = os.path.join('bots', 'instances', container_name)
+        logger.info(f"Archiving bot data from {instance_dir}")
+
+        try:
+            if archive_locally:
+                bot_archiver.archive_locally(container_name, instance_dir)
+            else:
+                bot_archiver.archive_and_upload(container_name, instance_dir, bucket_name=s3_bucket)
+            logger.info(f"Successfully archived bot data for {container_name}")
+        except Exception as e:
+            logger.error(f"Archive failed: {str(e)}")
+            # Continue with removal even if archive fails
+
+        # Step 7: Remove the container
+        logger.info(f"Removing container {container_name}")
+        remove_response = docker_manager.remove_container(container_name, force=False)
+
+        if not remove_response.get("success"):
+            # If graceful removal fails, try force removal
+            logger.warning("Graceful container removal failed, attempting force removal")
+            remove_response = docker_manager.remove_container(container_name, force=True)
+
+        if remove_response.get("success"):
+            logger.info(f"Successfully completed stop-and-archive for bot {bot_name}")
+
+            # Step 8: Update bot run deployment status to ARCHIVED
+            try:
+                async with db_manager.get_session_context() as session:
+                    bot_run_repo = BotRunRepository(session)
+                    await bot_run_repo.update_bot_run_archived(bot_name)
+                logger.info(f"Updated bot run deployment status to ARCHIVED for {bot_name}")
+            except Exception as e:
+                logger.error(f"Failed to update bot run to archived: {e}")
+        else:
+            logger.error(f"Failed to remove container {container_name}")
+
+            # Update bot run with error status (but keep stopped_at timestamp from earlier)
+            try:
+                async with db_manager.get_session_context() as session:
+                    bot_run_repo = BotRunRepository(session)
+                    await bot_run_repo.update_bot_run_stopped(
+                        bot_name,
+                        error_message="Failed to remove container during archive process"
+                    )
+                logger.info(f"Updated bot run with error status for {bot_name}")
+            except Exception as e:
+                logger.error(f"Failed to update bot run with error: {e}")
+
+    except Exception as e:
+        logger.error(f"Error in background stop-and-archive for {bot_name}: {str(e)}")
+
+        # Update bot run with error status
+        try:
+            async with db_manager.get_session_context() as session:
+                bot_run_repo = BotRunRepository(session)
+                await bot_run_repo.update_bot_run_stopped(
+                    bot_name,
+                    error_message=str(e)
+                )
+            logger.info(f"Updated bot run with error status for {bot_name}")
+        except Exception as db_error:
+            logger.error(f"Failed to update bot run with error: {db_error}")
+    finally:
+        # Always clear the stopping status when the background task completes
+        bots_manager.clear_bot_stopping(bot_name_for_orchestrator)
+        logger.info(f"Cleared stopping status for bot {bot_name}")
+
+        # Remove bot from active_bots and clear all MQTT data
+        if bot_name_for_orchestrator in bots_manager.active_bots:
+            bots_manager.mqtt_manager.clear_bot_data(bot_name_for_orchestrator)
+            del bots_manager.active_bots[bot_name_for_orchestrator]
+            logger.info(f"Removed bot {bot_name_for_orchestrator} from active_bots and cleared MQTT data")
+
+
+@router.post("/stop-and-archive-bot/{bot_name}")
+async def stop_and_archive_bot(
+    bot_name: str,
+    background_tasks: BackgroundTasks,
+    skip_order_cancellation: bool = True,
+    archive_locally: bool = True,
+    s3_bucket: str = None,
+    bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator),
+    docker_manager: DockerService = Depends(get_docker_service),
+    bot_archiver: BotArchiver = Depends(get_bot_archiver),
+    db_manager: AsyncDatabaseManager = Depends(get_database_manager)
+):
+    """
+    Gracefully stop a bot and archive its data in the background.
+
+    This initiates a background task that will:
+    1. Stop the bot trading process via MQTT
+    2. Wait 15 seconds for graceful shutdown
+    3. Monitor and stop the Docker container
+    4. Archive the bot data (locally or to S3)
+    5. Remove the container
+
+    Returns immediately with a success message while the process continues in the background.
+    """
+    try:
+        # Step 1: Normalize bot name and container name
+        # Container name is now the same as bot name (no prefix added)
+        actual_bot_name = bot_name
+        container_name = bot_name
+
+        logger.info(f"Normalized bot_name: {actual_bot_name}, container_name: {container_name}")
+
+        # Step 2: Validate bot exists in active bots
+        active_bots = list(bots_manager.active_bots.keys())
+
+        # Check if bot exists in active bots (could be stored as either format)
+        bot_found = (actual_bot_name in active_bots) or (container_name in active_bots)
+
+        if not bot_found:
+            return {
+                "status": "error",
+                "message": f"Bot '{actual_bot_name}' not found in active bots. Active bots: {active_bots}. Cannot perform graceful shutdown.",
+                "details": {
+                    "input_name": bot_name,
+                    "actual_bot_name": actual_bot_name,
+                    "container_name": container_name,
+                    "active_bots": active_bots,
+                    "reason": "Bot must be actively managed via MQTT for graceful shutdown"
+                }
+            }
+
+        # Use the format that's actually stored in active bots
+        bot_name_for_orchestrator = container_name if container_name in active_bots else actual_bot_name
+
+        # Add the background task
+        background_tasks.add_task(
+            _background_stop_and_archive,
+            bot_name=actual_bot_name,
+            container_name=container_name,
+            bot_name_for_orchestrator=bot_name_for_orchestrator,
+            skip_order_cancellation=skip_order_cancellation,
+            archive_locally=archive_locally,
+            s3_bucket=s3_bucket,
+            bots_manager=bots_manager,
+            docker_manager=docker_manager,
+            bot_archiver=bot_archiver,
+            db_manager=db_manager
+        )
+
+        return {
+            "status": "success",
+            "message": f"Stop and archive process started for bot {actual_bot_name}",
+            "details": {
+                "input_name": bot_name,
+                "actual_bot_name": actual_bot_name,
+                "container_name": container_name,
+                "process": "The bot will be gracefully stopped, archived, and removed in the background. This process typically takes 20-30 seconds."
+            }
+        }
+
+    except Exception as e:
+        logger.error(f"Error initiating stop_and_archive_bot for {bot_name}: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/deploy-v2-script")
+async def deploy_v2_script(
+    config: V2ScriptDeployment,
+    docker_manager: DockerService = Depends(get_docker_service),
+    db_manager: AsyncDatabaseManager = Depends(get_database_manager)
+):
+    """
+    Create and autostart a V2 script instance, applying its script configuration if one is provided.
+
+    Args:
+        config: Configuration for the new Hummingbot instance
+        docker_manager: Docker service dependency
+        db_manager: Database manager dependency
+
+    Returns:
+        Dictionary with creation response and instance details
+    """
+    logger.info(f"Creating hummingbot instance with config: {config}")
+    response = docker_manager.create_hummingbot_instance(config)
+
+    # Track bot run if deployment was successful
+    if response.get("success"):
+        try:
+            async with db_manager.get_session_context() as session:
+                bot_run_repo = BotRunRepository(session)
+                await bot_run_repo.create_bot_run(
+                    bot_name=config.instance_name,
+                    instance_name=config.instance_name,
+                    strategy_type="script",
+                    strategy_name=config.script or "unknown",
+                    account_name=config.credentials_profile,
+                    config_name=config.script_config,
+                    image_version=config.image,
+                    deployment_config=config.dict()
+                )
+            logger.info(f"Created bot run record for {config.instance_name}")
+        except Exception as e:
+            logger.error(f"Failed to create bot run record: {e}")
+            # Don't fail the deployment if bot run creation fails
+
+    return response
+
+
+@router.post("/deploy-v2-controllers")
+async def deploy_v2_controllers(
+    deployment: V2ControllerDeployment,
+    docker_manager: DockerService = Depends(get_docker_service),
+    db_manager: AsyncDatabaseManager = Depends(get_database_manager)
+):
+    """
+    Deploy a V2 strategy with controllers by generating the script config and creating the instance.
+    This endpoint simplifies the deployment process for V2 controller strategies.
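+
+    Example request body (values are illustrative):
+        {"instance_name": "bot1", "credentials_profile": "master_account",
+         "image": "hummingbot/hummingbot:latest",
+         "controllers_config": ["my_controller_config"],
+         "max_global_drawdown_quote": 100.0}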
+
+    Args:
+        deployment: V2ControllerDeployment configuration
+        docker_manager: Docker service dependency
+        db_manager: Database manager dependency
+
+    Returns:
+        Dictionary with deployment response and generated configuration details
+
+    Raises:
+        HTTPException: 500 if deployment fails
+    """
+    try:
+        # Generate unique script config filename with timestamp
+        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+        script_config_filename = f"{deployment.instance_name}-{timestamp}.yml"
+
+        # Ensure controller config names have .yml extension
+        controllers_with_extension = []
+        for controller in deployment.controllers_config:
+            if not controller.endswith('.yml'):
+                controllers_with_extension.append(f"{controller}.yml")
+            else:
+                controllers_with_extension.append(controller)
+
+        # Create the script config content
+        script_config_content = {
+            "script_file_name": "v2_with_controllers.py",
+            "candles_config": [],
+            "markets": {},
+            "controllers_config": controllers_with_extension,
+        }
+
+        # Add optional drawdown parameters if provided
+        if deployment.max_global_drawdown_quote is not None:
+            script_config_content["max_global_drawdown_quote"] = deployment.max_global_drawdown_quote
+        if deployment.max_controller_drawdown_quote is not None:
+            script_config_content["max_controller_drawdown_quote"] = deployment.max_controller_drawdown_quote
+
+        # Save the script config to the scripts directory
+        scripts_dir = os.path.join("conf", "scripts")
+        script_config_path = os.path.join(scripts_dir, script_config_filename)
+        fs_util.dump_dict_to_yaml(script_config_path, script_config_content)
+
+        logger.info(f"Generated script config: {script_config_filename} with content: {script_config_content}")
+
+        # Create the V2ScriptDeployment with the generated script config
+        instance_config = V2ScriptDeployment(
+            instance_name=deployment.instance_name,
+            credentials_profile=deployment.credentials_profile,
+            image=deployment.image,
+            script="v2_with_controllers.py",
+            script_config=script_config_filename
+        )
+
+        # Deploy the instance using the existing method
+        response = docker_manager.create_hummingbot_instance(instance_config)
+
+        if response.get("success"):
+            response["script_config_generated"] = script_config_filename
+            response["controllers_deployed"] = deployment.controllers_config
+
+            # Track bot run if deployment was successful
+            try:
+                async with db_manager.get_session_context() as session:
+                    bot_run_repo = BotRunRepository(session)
+                    await bot_run_repo.create_bot_run(
+                        bot_name=deployment.instance_name,
+                        instance_name=deployment.instance_name,
+                        strategy_type="controller",
+                        strategy_name="v2_with_controllers",
+                        account_name=deployment.credentials_profile,
+                        config_name=script_config_filename,
+                        image_version=deployment.image,
+                        deployment_config=deployment.dict()
+                    )
+                logger.info(f"Created bot run record for controller deployment {deployment.instance_name}")
+            except Exception as e:
+                logger.error(f"Failed to create bot run record: {e}")
+                # Don't fail the deployment if bot run creation fails
+
+        return response
+
+    except Exception as e:
+        logger.error(f"Error deploying V2 controllers: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
\ No newline at end of file
diff --git a/routers/connectors.py b/routers/connectors.py
new file mode 100644
index 00000000..fbc0d40a
--- /dev/null
+++ b/routers/connectors.py
@@ -0,0 +1,115 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Request, HTTPException, Query
+from hummingbot.client.settings import AllConnectorSettings
+
+from services.accounts_service
import AccountsService +from services.market_data_feed_manager import MarketDataFeedManager +from deps import get_accounts_service + +router = APIRouter(tags=["Connectors"], prefix="/connectors") + + +@router.get("/", response_model=List[str]) +async def available_connectors(): + """ + Get a list of all available connectors. + + Returns: + List of connector names supported by the system + """ + return list(AllConnectorSettings.get_connector_settings().keys()) + + +@router.get("/{connector_name}/config-map", response_model=List[str]) +async def get_connector_config_map(connector_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Get configuration fields required for a specific connector. + + Args: + connector_name: Name of the connector to get config map for + + Returns: + List of configuration field names required for the connector + """ + return accounts_service.get_connector_config_map(connector_name) + + +@router.get("/{connector_name}/trading-rules") +async def get_trading_rules( + request: Request, + connector_name: str, + trading_pairs: Optional[List[str]] = Query(default=None, description="Filter by specific trading pairs") +): + """ + Get trading rules for a connector, optionally filtered by trading pairs. + + This endpoint uses the MarketDataFeedManager to access non-trading connector instances, + which means no authentication or account setup is required. + + Args: + request: FastAPI request object + connector_name: Name of the connector (e.g., 'binance', 'binance_perpetual') + trading_pairs: Optional list of trading pairs to filter by (e.g., ['BTC-USDT', 'ETH-USDT']) + + Returns: + Dictionary mapping trading pairs to their trading rules + + Raises: + HTTPException: 404 if connector not found, 500 for other errors + """ + try: + market_data_feed_manager: MarketDataFeedManager = request.app.state.market_data_feed_manager + + # Get trading rules (filtered by trading pairs if provided) + rules = await market_data_feed_manager.get_trading_rules(connector_name, trading_pairs) + + if "error" in rules: + raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' not found or error: {rules['error']}") + + return rules + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error retrieving trading rules: {str(e)}") + + +@router.get("/{connector_name}/order-types") +async def get_supported_order_types(request: Request, connector_name: str): + """ + Get order types supported by a specific connector. + + This endpoint uses the MarketDataFeedManager to access non-trading connector instances, + which means no authentication or account setup is required. 
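+
+    Example response (connector and values are illustrative):
+        {"connector": "binance", "supported_order_types": ["LIMIT", "MARKET", "LIMIT_MAKER"]}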
+ + Args: + request: FastAPI request object + connector_name: Name of the connector (e.g., 'binance', 'binance_perpetual') + + Returns: + List of supported order types (LIMIT, MARKET, LIMIT_MAKER) + + Raises: + HTTPException: 404 if connector not found, 500 for other errors + """ + try: + market_data_feed_manager: MarketDataFeedManager = request.app.state.market_data_feed_manager + + # Access connector through MarketDataProvider's _rate_sources + connector_instance = market_data_feed_manager.market_data_provider._rate_sources.get(connector_name) + + if not connector_instance: + raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' not found") + + # Get supported order types + if hasattr(connector_instance, 'supported_order_types'): + order_types = [order_type.name for order_type in connector_instance.supported_order_types()] + return {"connector": connector_name, "supported_order_types": order_types} + else: + raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' does not support order types query") + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error retrieving order types: {str(e)}") \ No newline at end of file diff --git a/routers/controllers.py b/routers/controllers.py new file mode 100644 index 00000000..c94e57a9 --- /dev/null +++ b/routers/controllers.py @@ -0,0 +1,316 @@ +import json +import yaml +from typing import Dict, List + +from fastapi import APIRouter, HTTPException +from starlette import status + +from models import Controller, ControllerType +from utils.file_system import fs_util + +router = APIRouter(tags=["Controllers"], prefix="/controllers") + + +@router.get("/", response_model=Dict[str, List[str]]) +async def list_controllers(): + """ + List all controllers organized by type. + + Returns: + Dictionary mapping controller types to lists of controller names + """ + result = {} + for controller_type in ControllerType: + try: + files = fs_util.list_files(f'controllers/{controller_type.value}') + result[controller_type.value] = [ + f.replace('.py', '') for f in files + if f.endswith('.py') and f != "__init__.py" + ] + except FileNotFoundError: + result[controller_type.value] = [] + return result + + +# Controller Configuration endpoints (must come before controller type routes) +@router.get("/configs/", response_model=List[Dict]) +async def list_controller_configs(): + """ + List all controller configurations with metadata. 
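+
+    Each entry summarizes one YAML file from `conf/controllers`, e.g.
+    (illustrative values):
+        {"config_name": "my_config", "controller_name": "pmm_simple",
+         "controller_type": "market_making", "connector_name": "binance",
+         "trading_pair": "BTC-USDT", "total_amount_quote": 1000}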
+ + Returns: + List of controller configuration objects with name, controller_name, controller_type, and other metadata + """ + try: + config_files = [f for f in fs_util.list_files('conf/controllers') if f.endswith('.yml')] + configs = [] + + for config_file in config_files: + config_name = config_file.replace('.yml', '') + try: + config = fs_util.read_yaml_file(f"conf/controllers/{config_file}") + configs.append({ + "config_name": config_name, + "controller_name": config.get("controller_name", "unknown"), + "controller_type": config.get("controller_type", "unknown"), + "connector_name": config.get("connector_name", "unknown"), + "trading_pair": config.get("trading_pair", "unknown"), + "total_amount_quote": config.get("total_amount_quote", 0) + }) + except Exception as e: + # If config is malformed, still include it with basic info + configs.append({ + "config_name": config_name, + "controller_name": "error", + "controller_type": "error", + "error": str(e) + }) + + return configs + except FileNotFoundError: + return [] + + +@router.get("/configs/{config_name}", response_model=Dict) +async def get_controller_config(config_name: str): + """ + Get controller configuration by config name. + + Args: + config_name: Name of the configuration file to retrieve + + Returns: + Dictionary with controller configuration + + Raises: + HTTPException: 404 if configuration not found + """ + try: + config = fs_util.read_yaml_file(f"conf/controllers/{config_name}.yml") + return config + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Configuration '{config_name}' not found") + + +@router.post("/configs/{config_name}", status_code=status.HTTP_201_CREATED) +async def create_or_update_controller_config(config_name: str, config: Dict): + """ + Create or update controller configuration. + + Args: + config_name: Name of the configuration file + config: Configuration dictionary to save + + Returns: + Success message when configuration is saved + + Raises: + HTTPException: 400 if save error occurs + """ + try: + yaml_content = yaml.dump(config, default_flow_style=False) + fs_util.add_file('conf/controllers', f"{config_name}.yml", yaml_content, override=True) + return {"message": f"Configuration '{config_name}' saved successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/configs/{config_name}") +async def delete_controller_config(config_name: str): + """ + Delete controller configuration. + + Args: + config_name: Name of the configuration file to delete + + Returns: + Success message when configuration is deleted + + Raises: + HTTPException: 404 if configuration not found + """ + try: + fs_util.delete_file('conf/controllers', f"{config_name}.yml") + return {"message": f"Configuration '{config_name}' deleted successfully"} + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Configuration '{config_name}' not found") + + +@router.get("/{controller_type}/{controller_name}", response_model=Dict[str, str]) +async def get_controller(controller_type: ControllerType, controller_name: str): + """ + Get controller content by type and name. 
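+
+    For example, `GET /controllers/market_making/my_controller` would return
+    the source of `controllers/market_making/my_controller.py` (controller name
+    is illustrative; valid types are the ControllerType enum values).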
+ + Args: + controller_type: Type of the controller + controller_name: Name of the controller + + Returns: + Dictionary with controller name, type, and content + + Raises: + HTTPException: 404 if controller not found + """ + try: + content = fs_util.read_file(f"controllers/{controller_type.value}/{controller_name}.py") + return { + "name": controller_name, + "type": controller_type.value, + "content": content + } + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Controller '{controller_name}' not found in '{controller_type.value}'" + ) + + +@router.post("/{controller_type}/{controller_name}", status_code=status.HTTP_201_CREATED) +async def create_or_update_controller(controller_type: ControllerType, controller_name: str, controller: Controller): + """ + Create or update a controller. + + Args: + controller_type: Type of controller to create/update + controller_name: Name of the controller (from URL path) + controller: Controller object with content (and optional type for validation) + + Returns: + Success message when controller is saved + + Raises: + HTTPException: 400 if controller type mismatch or save error + """ + # If type is provided in body, validate it matches URL + if controller.type is not None and controller.type != controller_type: + raise HTTPException( + status_code=400, + detail=f"Controller type mismatch: URL has '{controller_type}', body has '{controller.type}'" + ) + + try: + fs_util.add_file( + f'controllers/{controller_type.value}', + f"{controller_name}.py", + controller.content, + override=True + ) + return {"message": f"Controller '{controller_name}' saved successfully in '{controller_type.value}'"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/{controller_type}/{controller_name}") +async def delete_controller(controller_type: ControllerType, controller_name: str): + """ + Delete a controller. + + Args: + controller_type: Type of the controller + controller_name: Name of the controller to delete + + Returns: + Success message when controller is deleted + + Raises: + HTTPException: 404 if controller not found + """ + try: + fs_util.delete_file(f'controllers/{controller_type.value}', f"{controller_name}.py") + return {"message": f"Controller '{controller_name}' deleted successfully from '{controller_type.value}'"} + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Controller '{controller_name}' not found in '{controller_type.value}'" + ) + + +@router.get("/{controller_type}/{controller_name}/config/template", response_model=Dict) +async def get_controller_config_template(controller_type: ControllerType, controller_name: str): + """ + Get controller configuration template with default values. 
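+
+    The template is built by mapping each field in the config class's
+    `model_fields` to its declared default value; non-JSON-serializable
+    defaults (e.g. Decimal) are stringified.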
+ + Args: + controller_type: Type of the controller + controller_name: Name of the controller + + Returns: + Dictionary with configuration template and default values + + Raises: + HTTPException: 404 if controller configuration class not found + """ + config_class = fs_util.load_controller_config_class(controller_type.value, controller_name) + if config_class is None: + raise HTTPException( + status_code=404, + detail=f"Controller configuration class for '{controller_name}' not found" + ) + + # Extract fields and default values + config_fields = {name: field.default for name, field in config_class.model_fields.items()} + return json.loads(json.dumps(config_fields, default=str)) + + + + +# Bot-specific controller config endpoints +@router.get("/bots/{bot_name}/configs", response_model=List[Dict]) +async def get_bot_controller_configs(bot_name: str): + """ + Get all controller configurations for a specific bot. + + Args: + bot_name: Name of the bot to get configurations for + + Returns: + List of controller configurations for the bot + + Raises: + HTTPException: 404 if bot not found + """ + bots_config_path = f"instances/{bot_name}/conf/controllers" + if not fs_util.path_exists(bots_config_path): + raise HTTPException(status_code=404, detail=f"Bot '{bot_name}' not found") + + configs = [] + for controller_file in fs_util.list_files(bots_config_path): + if controller_file.endswith('.yml'): + config = fs_util.read_yaml_file(f"{bots_config_path}/{controller_file}") + config['_config_name'] = controller_file.replace('.yml', '') + configs.append(config) + return configs + + +@router.post("/bots/{bot_name}/{controller_name}/config") +async def update_bot_controller_config(bot_name: str, controller_name: str, config: Dict): + """ + Update controller configuration for a specific bot. + + Args: + bot_name: Name of the bot + controller_name: Name of the controller to update + config: Configuration dictionary to update with + + Returns: + Success message when configuration is updated + + Raises: + HTTPException: 404 if bot or controller not found, 400 if update error + """ + bots_config_path = f"instances/{bot_name}/conf/controllers" + if not fs_util.path_exists(bots_config_path): + raise HTTPException(status_code=404, detail=f"Bot '{bot_name}' not found") + + try: + current_config = fs_util.read_yaml_file(f"{bots_config_path}/{controller_name}.yml") + current_config.update(config) + fs_util.dump_dict_to_yaml(f"{bots_config_path}/{controller_name}.yml", current_config) + return {"message": f"Controller configuration for bot '{bot_name}' updated successfully"} + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Controller configuration '{controller_name}' not found for bot '{bot_name}'" + ) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file diff --git a/routers/docker.py b/routers/docker.py new file mode 100644 index 00000000..7b0f8287 --- /dev/null +++ b/routers/docker.py @@ -0,0 +1,192 @@ +import os + +from fastapi import APIRouter, HTTPException, Depends + +from models import DockerImage +from utils.bot_archiver import BotArchiver +from services.docker_service import DockerService +from deps import get_docker_service, get_bot_archiver + +router = APIRouter(tags=["Docker"], prefix="/docker") + + +@router.get("/running") +async def is_docker_running(docker_service: DockerService = Depends(get_docker_service)): + """ + Check if Docker daemon is running. 
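+
+    For example, `GET /docker/running` can serve as a lightweight health check
+    before creating or starting instances (the response shape is whatever
+    DockerService.is_docker_running returns).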
+
+    Args:
+        docker_service: Docker service dependency
+
+    Returns:
+        Dictionary indicating if Docker is running
+    """
+    return docker_service.is_docker_running()
+
+
+@router.get("/available-images/")
+async def available_images(image_name: str = None, docker_service: DockerService = Depends(get_docker_service)):
+    """
+    Get available Docker images, optionally filtered by name.
+
+    Args:
+        image_name: Name pattern to search for in image tags
+        docker_service: Docker service dependency
+
+    Returns:
+        List of image tags matching the pattern, or all available images when
+        no pattern is given
+    """
+    available_images = docker_service.get_available_images()
+    if image_name:
+        return [tag for image in available_images["images"] for tag in image.tags if image_name in tag]
+    return available_images["images"]
+
+
+@router.get("/active-containers")
+async def active_containers(name_filter: str = None, docker_service: DockerService = Depends(get_docker_service)):
+    """
+    Get all currently active (running) Docker containers.
+
+    Args:
+        name_filter: Optional filter to match container names (case-insensitive)
+        docker_service: Docker service dependency
+
+    Returns:
+        List of active container information
+    """
+    return docker_service.get_active_containers(name_filter)
+
+
+@router.get("/exited-containers")
+async def exited_containers(name_filter: str = None, docker_service: DockerService = Depends(get_docker_service)):
+    """
+    Get all exited (stopped) Docker containers.
+
+    Args:
+        name_filter: Optional filter to match container names (case-insensitive)
+        docker_service: Docker service dependency
+
+    Returns:
+        List of exited container information
+    """
+    return docker_service.get_exited_containers(name_filter)
+
+
+@router.post("/clean-exited-containers")
+async def clean_exited_containers(docker_service: DockerService = Depends(get_docker_service)):
+    """
+    Remove all exited Docker containers to free up space.
+
+    Args:
+        docker_service: Docker service dependency
+
+    Returns:
+        Response from cleanup operation
+    """
+    return docker_service.clean_exited_containers()
+
+
+@router.post("/remove-container/{container_name}")
+async def remove_container(container_name: str, archive_locally: bool = True, s3_bucket: str = None, docker_service: DockerService = Depends(get_docker_service), bot_archiver: BotArchiver = Depends(get_bot_archiver)):
+    """
+    Remove a Hummingbot container and optionally archive its bot data.
+
+    NOTE: This endpoint only works with Hummingbot containers (names starting with 'hummingbot-')
+    as it archives bot-specific data from the bots/instances directory.
+
+    Args:
+        container_name: Name of the Hummingbot container to remove
+        archive_locally: Whether to archive data locally (default: True)
+        s3_bucket: S3 bucket name for cloud archiving (optional)
+        docker_service: Docker service dependency
+        bot_archiver: Bot archiver service dependency
+
+    Returns:
+        Response from container removal operation
+
+    Raises:
+        HTTPException: 400 if container is not a Hummingbot container
+        HTTPException: 500 if archiving fails
+    """
+    # Validate that this is a Hummingbot container
+    if not container_name.startswith("hummingbot-"):
+        raise HTTPException(
+            status_code=400,
+            detail=f"This endpoint only removes Hummingbot containers. Container '{container_name}' is not a Hummingbot container."
+ ) + + # Remove the container + response = docker_service.remove_container(container_name) + # Form the instance directory path correctly + instance_dir = os.path.join('bots', 'instances', container_name) + try: + # Archive the data + if archive_locally: + bot_archiver.archive_locally(container_name, instance_dir) + else: + bot_archiver.archive_and_upload(container_name, instance_dir, bucket_name=s3_bucket) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + return response + + +@router.post("/stop-container/{container_name}") +async def stop_container(container_name: str, docker_service: DockerService = Depends(get_docker_service)): + """ + Stop a running Docker container. + + Args: + container_name: Name of the container to stop + docker_service: Docker service dependency + + Returns: + Response from container stop operation + """ + return docker_service.stop_container(container_name) + + +@router.post("/start-container/{container_name}") +async def start_container(container_name: str, docker_service: DockerService = Depends(get_docker_service)): + """ + Start a stopped Docker container. + + Args: + container_name: Name of the container to start + docker_service: Docker service dependency + + Returns: + Response from container start operation + """ + return docker_service.start_container(container_name) + + +@router.post("/pull-image/") +async def pull_image(image: DockerImage, docker_service: DockerService = Depends(get_docker_service)): + """ + Initiate Docker image pull as background task. + Returns immediately with task status for monitoring. + + Args: + image: DockerImage object containing the image name to pull + docker_service: Docker service dependency + + Returns: + Status of the pull operation initiation + """ + result = docker_service.pull_image_async(image.image_name) + return result + + +@router.get("/pull-status/") +async def get_pull_status(docker_service: DockerService = Depends(get_docker_service)): + """ + Get status of all pull operations. + + Args: + docker_service: Docker service dependency + + Returns: + Dictionary with all pull operations and their statuses + """ + return docker_service.get_all_pull_status() diff --git a/routers/manage_accounts.py b/routers/manage_accounts.py deleted file mode 100644 index 7751f7ce..00000000 --- a/routers/manage_accounts.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Dict, List - -from fastapi import APIRouter, HTTPException -from hummingbot.client.settings import AllConnectorSettings -from starlette import status - -from services.accounts_service import AccountsService -from utils.file_system import FileSystemUtil - -router = APIRouter(tags=["Manage Credentials"]) -file_system = FileSystemUtil(base_path="bots/credentials") -accounts_service = AccountsService() - - -@router.on_event("startup") -async def startup_event(): - accounts_service.start_update_account_state_loop() - - -@router.on_event("shutdown") -async def shutdown_event(): - accounts_service.stop_update_account_state_loop() - - -@router.get("/accounts-state", response_model=Dict[str, Dict[str, List[Dict]]]) -async def get_all_accounts_state(): - return accounts_service.get_accounts_state() - - -@router.get("/account-state-history", response_model=List[Dict]) -async def get_account_state_history(): - """ - Get the historical state of all accounts. 
- """ - try: - history = accounts_service.load_account_state_history() - return history - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/available-connectors", response_model=List[str]) -async def available_connectors(): - return list(AllConnectorSettings.get_connector_settings().keys()) - - -@router.get("/connector-config-map/{connector_name}", response_model=List[str]) -async def get_connector_config_map(connector_name: str): - return accounts_service.get_connector_config_map(connector_name) - - -@router.get("/all-connectors-config-map", response_model=Dict[str, List[str]]) -async def get_all_connectors_config_map(): - all_config_maps = {} - for connector in list(AllConnectorSettings.get_connector_settings().keys()): - all_config_maps[connector] = accounts_service.get_connector_config_map(connector) - return all_config_maps - - -@router.get("/list-accounts", response_model=List[str]) -async def list_accounts(): - return accounts_service.list_accounts() - - -@router.get("/list-credentials/{account_name}", response_model=List[str]) -async def list_credentials(account_name: str): - try: - return accounts_service.list_credentials(account_name) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/add-account", status_code=status.HTTP_201_CREATED) -async def add_account(account_name: str): - try: - accounts_service.add_account(account_name) - return {"message": "Credential added successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/delete-account") -async def delete_account(account_name: str): - try: - if account_name == "master_account": - raise HTTPException(status_code=400, detail="Cannot delete master account.") - accounts_service.delete_account(account_name) - return {"message": "Credential deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/delete-credential/{account_name}/{connector_name}") -async def delete_credential(account_name: str, connector_name: str): - try: - accounts_service.delete_credentials(account_name, connector_name) - return {"message": "Credential deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/add-connector-keys/{account_name}/{connector_name}", status_code=status.HTTP_201_CREATED) -async def add_connector_keys(account_name: str, connector_name: str, keys: Dict): - try: - await accounts_service.add_connector_keys(account_name, connector_name, keys) - return {"message": "Connector keys added successfully."} - except Exception as e: - accounts_service.delete_credentials(account_name, connector_name) - raise HTTPException(status_code=400, detail=str(e)) diff --git a/routers/manage_broker_messages.py b/routers/manage_broker_messages.py deleted file mode 100644 index e31d1f3f..00000000 --- a/routers/manage_broker_messages.py +++ /dev/null @@ -1,64 +0,0 @@ -from fastapi import APIRouter, HTTPException - -from config import BROKER_HOST, BROKER_PASSWORD, BROKER_PORT, BROKER_USERNAME -from models import ImportStrategyAction, StartBotAction, StopBotAction -from services.bots_orchestrator import BotsManager - -# Initialize the scheduler -router = APIRouter(tags=["Manage Broker Messages"]) -bots_manager = BotsManager(broker_host=BROKER_HOST, broker_port=BROKER_PORT, broker_username=BROKER_USERNAME, - broker_password=BROKER_PASSWORD) - - 
-@router.on_event("startup") -async def startup_event(): - bots_manager.start_update_active_bots_loop() - - -@router.on_event("shutdown") -async def shutdown_event(): - # Shutdown the scheduler on application exit - bots_manager.stop_update_active_bots_loop() - - -@router.get("/get-active-bots-status") -def get_active_bots_status(): - """Returns the cached status of all active bots.""" - return {"status": "success", "data": bots_manager.get_all_bots_status()} - - -@router.get("/get-bot-status/{bot_name}") -def get_bot_status(bot_name: str): - response = bots_manager.get_bot_status(bot_name) - if not response: - raise HTTPException(status_code=404, detail="Bot not found") - return { - "status": "success", - "data": response - } - - -@router.get("/get-bot-history/{bot_name}") -def get_bot_history(bot_name: str): - response = bots_manager.get_bot_history(bot_name) - return {"status": "success", "response": response} - - -@router.post("/start-bot") -def start_bot(action: StartBotAction): - response = bots_manager.start_bot(action.bot_name, log_level=action.log_level, script=action.script, - conf=action.conf, async_backend=action.async_backend) - return {"status": "success", "response": response} - - -@router.post("/stop-bot") -def stop_bot(action: StopBotAction): - response = bots_manager.stop_bot(action.bot_name, skip_order_cancellation=action.skip_order_cancellation, - async_backend=action.async_backend) - return {"status": "success", "response": response} - - -@router.post("/import-strategy") -def import_strategy(action: ImportStrategyAction): - response = bots_manager.import_strategy_for_bot(action.bot_name, action.strategy) - return {"status": "success", "response": response} diff --git a/routers/manage_databases.py b/routers/manage_databases.py deleted file mode 100644 index ae90dce8..00000000 --- a/routers/manage_databases.py +++ /dev/null @@ -1,100 +0,0 @@ -import json -import time - -from typing import List, Dict, Any - -import pandas as pd - -from utils.etl_databases import HummingbotDatabase, ETLPerformance -from fastapi import APIRouter - -from utils.file_system import FileSystemUtil - -router = APIRouter(tags=["Database Management"]) -file_system = FileSystemUtil() - - -@router.post("/list-databases", response_model=List[str]) -async def list_databases(): - return file_system.list_databases() - - -@router.post("/read-databases", response_model=List[Dict[str, Any]]) -async def read_databases(db_paths: List[str] = None): - dbs = [] - for db_path in db_paths: - db = HummingbotDatabase(db_path) - try: - db_content = { - "db_name": db.db_name, - "db_path": db.db_path, - "healthy": db.status["general_status"], - "status": db.status, - "tables": { - "orders": json.dumps(db.get_orders().to_dict()), - "trade_fill": json.dumps(db.get_trade_fills().to_dict()), - "executors": json.dumps(db.get_executors_data().to_dict()), - "order_status": json.dumps(db.get_order_status().to_dict()), - "controllers": json.dumps(db.get_controllers_data().to_dict()) - } - } - except Exception as e: - print(f"Error reading database {db_path}: {str(e)}") - db_content = { - "db_name": "", - "db_path": db_path, - "healthy": False, - "status": db.status, - "tables": {} - } - dbs.append(db_content) - return dbs - - -@router.post("/create-checkpoint", response_model=Dict[str, Any]) -async def create_checkpoint(db_paths: List[str]): - try: - dbs = await read_databases(db_paths) - - healthy_dbs = [db for db in dbs if db["healthy"]] - - table_names = ["trade_fill", "orders", "order_status", "executors", "controllers"] - 
tables_dict = {name: pd.DataFrame() for name in table_names} - - for db in healthy_dbs: - for table_name in table_names: - new_data = pd.DataFrame(json.loads(db["tables"][table_name])) - new_data["db_path"] = db["db_path"] - new_data["db_name"] = db["db_name"] - tables_dict[table_name] = pd.concat([tables_dict[table_name], new_data]) - - etl = ETLPerformance(db_path=f"bots/data/checkpoint_{str(int(time.time()))}.sqlite") - etl.create_tables() - etl.insert_data(tables_dict) - return {"message": "Checkpoint created successfully."} - except Exception as e: - return {"message": f"Error: {str(e)}"} - - -@router.post("/list-checkpoints", response_model=List[str]) -async def list_checkpoints(full_path: bool): - return file_system.list_checkpoints(full_path) - - -@router.post("/load-checkpoint") -async def load_checkpoint(checkpoint_path: str): - try: - etl = ETLPerformance(checkpoint_path) - executor = etl.load_executors() - order = etl.load_orders() - trade_fill = etl.load_trade_fill() - controllers = etl.load_controllers() - checkpoint_data = { - "executors": json.dumps(executor.to_dict()), - "orders": json.dumps(order.to_dict()), - "trade_fill": json.dumps(trade_fill.to_dict()), - "controllers": json.dumps(controllers.to_dict()) - } - return checkpoint_data - except Exception as e: - return {"message": f"Error: {str(e)}"} \ No newline at end of file diff --git a/routers/manage_docker.py b/routers/manage_docker.py deleted file mode 100644 index 9769cff3..00000000 --- a/routers/manage_docker.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging -import os - -from fastapi import APIRouter, HTTPException - -from models import HummingbotInstanceConfig, ImageName -from services.bot_archiver import BotArchiver -from services.docker_service import DockerManager - -router = APIRouter(tags=["Docker Management"]) -docker_manager = DockerManager() -bot_archiver = BotArchiver(os.environ.get("AWS_API_KEY"), os.environ.get("AWS_SECRET_KEY"), - os.environ.get("S3_DEFAULT_BUCKET_NAME")) - - -@router.get("/is-docker-running") -async def is_docker_running(): - return {"is_docker_running": docker_manager.is_docker_running()} - - -@router.get("/available-images/{image_name}") -async def available_images(image_name: str): - available_images = docker_manager.get_available_images() - image_tags = [tag for image in available_images["images"] for tag in image.tags if image_name in tag] - return {"available_images": image_tags} - - -@router.get("/active-containers") -async def active_containers(): - return docker_manager.get_active_containers() - - -@router.get("/exited-containers") -async def exited_containers(): - return docker_manager.get_exited_containers() - - -@router.post("/clean-exited-containers") -async def clean_exited_containers(): - return docker_manager.clean_exited_containers() - - -@router.post("/remove-container/{container_name}") -async def remove_container(container_name: str, archive_locally: bool = True, s3_bucket: str = None): - # Remove the container - response = docker_manager.remove_container(container_name) - # Form the instance directory path correctly - instance_dir = os.path.join('bots', 'instances', container_name) - try: - # Archive the data - if archive_locally: - bot_archiver.archive_locally(container_name, instance_dir) - else: - bot_archiver.archive_and_upload(container_name, instance_dir, bucket_name=s3_bucket) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - return response - - -@router.post("/stop-container/{container_name}") -async def 
stop_container(container_name: str): - return docker_manager.stop_container(container_name) - - -@router.post("/start-container/{container_name}") -async def start_container(container_name: str): - return docker_manager.start_container(container_name) - - -@router.post("/create-hummingbot-instance") -async def create_hummingbot_instance(config: HummingbotInstanceConfig): - logging.info(f"Creating hummingbot instance with config: {config}") - response = docker_manager.create_hummingbot_instance(config) - return response - - -@router.post("/pull-image/") -async def pull_image(image: ImageName): - try: - result = docker_manager.pull_image(image.image_name) - return result - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) diff --git a/routers/manage_files.py b/routers/manage_files.py deleted file mode 100644 index 058b72c7..00000000 --- a/routers/manage_files.py +++ /dev/null @@ -1,209 +0,0 @@ -import json -from typing import Dict, List - -import yaml -from fastapi import APIRouter, File, HTTPException, UploadFile -from starlette import status - -from models import Script, ScriptConfig -from utils.file_system import FileSystemUtil - -router = APIRouter(tags=["Files Management"]) - -file_system = FileSystemUtil() - - -@router.get("/list-scripts", response_model=List[str]) -async def list_scripts(): - return file_system.list_files('scripts') - - -@router.get("/list-scripts-configs", response_model=List[str]) -async def list_scripts_configs(): - return file_system.list_files('conf/scripts') - - -@router.get("/script-config/{script_name}", response_model=dict) -async def get_script_config(script_name: str): - """ - Retrieves the configuration parameters for a given script. - :param script_name: The name of the script. - :return: JSON containing the configuration parameters. - """ - config_class = file_system.load_script_config_class(script_name) - if config_class is None: - raise HTTPException(status_code=404, detail="Script configuration class not found") - - # Extracting fields and default values - config_fields = {field.name: field.default for field in config_class.__fields__.values()} - return json.loads(json.dumps(config_fields, default=str)) # Handling non-serializable types like Decimal - - -@router.get("/list-controllers", response_model=dict) -async def list_controllers(): - directional_trading_controllers = [file for file in file_system.list_files('controllers/directional_trading') if - file != "__init__.py"] - market_making_controllers = [file for file in file_system.list_files('controllers/market_making') if - file != "__init__.py"] - generic_controllers = [file for file in file_system.list_files('controllers/generic') if file != "__init__.py"] - - return {"directional_trading": directional_trading_controllers, - "market_making": market_making_controllers, - "generic": generic_controllers} - -@router.get("/controller-config-pydantic/{controller_type}/{controller_name}", response_model=dict) -async def get_controller_config_pydantic(controller_type: str, controller_name: str): - """ - Retrieves the configuration parameters for a given controller. - :param controller_name: The name of the controller. - :return: JSON containing the configuration parameters. 
- """ - config_class = file_system.load_controller_config_class(controller_type, controller_name) - if config_class is None: - raise HTTPException(status_code=404, detail="Controller configuration class not found") - - # Extracting fields and default values - config_fields = {name: field.default for name, field in config_class.model_fields.items()} - return json.loads(json.dumps(config_fields, default=str)) - - -@router.get("/list-controllers-configs", response_model=List[str]) -async def list_controllers_configs(): - return file_system.list_files('conf/controllers') - - -@router.get("/controller-config/{controller_name}", response_model=dict) -async def get_controller_config(controller_name: str): - config = file_system.read_yaml_file(f"bots/conf/controllers/{controller_name}.yml") - return config - - -@router.get("/all-controller-configs", response_model=List[dict]) -async def get_all_controller_configs(): - configs = [] - for controller in file_system.list_files('conf/controllers'): - config = file_system.read_yaml_file(f"bots/conf/controllers/{controller}") - configs.append(config) - return configs - - -@router.get("/all-controller-configs/bot/{bot_name}", response_model=List[dict]) -async def get_all_controller_configs_for_bot(bot_name: str): - configs = [] - bots_config_path = f"instances/{bot_name}/conf/controllers" - if not file_system.path_exists(bots_config_path): - raise HTTPException(status_code=400, detail="Bot not found.") - for controller in file_system.list_files(bots_config_path): - config = file_system.read_yaml_file(f"bots/{bots_config_path}/{controller}") - configs.append(config) - return configs - - -@router.post("/update-controller-config/bot/{bot_name}/{controller_id}") -async def update_controller_config(bot_name: str, controller_id: str, config: Dict): - bots_config_path = f"instances/{bot_name}/conf/controllers" - if not file_system.path_exists(bots_config_path): - raise HTTPException(status_code=400, detail="Bot not found.") - current_config = file_system.read_yaml_file(f"bots/{bots_config_path}/{controller_id}.yml") - current_config.update(config) - file_system.dump_dict_to_yaml(f"bots/{bots_config_path}/{controller_id}.yml", current_config) - return {"message": "Controller configuration updated successfully."} - - -@router.post("/add-script", status_code=status.HTTP_201_CREATED) -async def add_script(script: Script, override: bool = False): - try: - file_system.add_file('scripts', script.name + '.py', script.content, override) - return {"message": "Script added successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/upload-script") -async def upload_script(config_file: UploadFile = File(...), override: bool = False): - try: - contents = await config_file.read() - file_system.add_file('scripts', config_file.filename, contents.decode(), override) - return {"message": "Script uploaded successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/add-script-config", status_code=status.HTTP_201_CREATED) -async def add_script_config(config: ScriptConfig): - try: - yaml_content = yaml.dump(config.content) - - file_system.add_file('conf/scripts', config.name + '.yml', yaml_content, override=True) - return {"message": "Script configuration uploaded successfully."} - except Exception as e: # Consider more specific exception handling - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/upload-script-config") -async def 
upload_script_config(config_file: UploadFile = File(...), override: bool = False): - try: - contents = await config_file.read() - file_system.add_file('conf/scripts', config_file.filename, contents.decode(), override) - return {"message": "Script configuration uploaded successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/add-controller-config", status_code=status.HTTP_201_CREATED) -async def add_controller_config(config: ScriptConfig): - try: - yaml_content = yaml.dump(config.content) - - file_system.add_file('conf/controllers', config.name + '.yml', yaml_content, override=True) - return {"message": "Controller configuration uploaded successfully."} - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/upload-controller-config") -async def upload_controller_config(config_file: UploadFile = File(...), override: bool = False): - try: - contents = await config_file.read() - file_system.add_file('conf/controllers', config_file.filename, contents.decode(), override) - return {"message": "Controller configuration uploaded successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/delete-controller-config", status_code=status.HTTP_200_OK) -async def delete_controller_config(config_name: str): - try: - file_system.delete_file('conf/controllers', config_name) - return {"message": f"Controller configuration {config_name} deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/delete-script-config", status_code=status.HTTP_200_OK) -async def delete_script_config(config_name: str): - try: - file_system.delete_file('conf/scripts', config_name) - return {"message": f"Script configuration {config_name} deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/delete-all-controller-configs", status_code=status.HTTP_200_OK) -async def delete_all_controller_configs(): - try: - for file in file_system.list_files('conf/controllers'): - file_system.delete_file('conf/controllers', file) - return {"message": "All controller configurations deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/delete-all-script-configs", status_code=status.HTTP_200_OK) -async def delete_all_script_configs(): - try: - for file in file_system.list_files('conf/scripts'): - file_system.delete_file('conf/scripts', file) - return {"message": "All script configurations deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) diff --git a/routers/manage_market_data.py b/routers/manage_market_data.py deleted file mode 100644 index a5e2bcd7..00000000 --- a/routers/manage_market_data.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio - -from fastapi import APIRouter -from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory -from hummingbot.data_feed.candles_feed.data_types import CandlesConfig, HistoricalCandlesConfig - -router = APIRouter(tags=["Market Data"]) -candles_factory = CandlesFactory() - - -@router.post("/real-time-candles") -async def get_candles(candles_config: CandlesConfig): - try: - candles = candles_factory.get_candle(candles_config) - candles.start() - while not candles.ready: - await asyncio.sleep(1) - df = candles.candles_df - candles.stop() - 
df.drop_duplicates(subset=["timestamp"], inplace=True) - return df - except Exception as e: - return {"error": str(e)} - - -@router.post("/historical-candles") -async def get_historical_candles(config: HistoricalCandlesConfig): - try: - candles_config = CandlesConfig( - connector=config.connector_name, - trading_pair=config.trading_pair, - interval=config.interval - ) - candles = candles_factory.get_candle(candles_config) - return await candles.get_historical_candles(config=config) - except Exception as e: - return {"error": str(e)} diff --git a/routers/manage_performance.py b/routers/manage_performance.py deleted file mode 100644 index 01bc316b..00000000 --- a/routers/manage_performance.py +++ /dev/null @@ -1,28 +0,0 @@ -from fastapi import APIRouter -from typing import Any, Dict - -from hummingbot.strategy_v2.backtesting.backtesting_engine_base import BacktestingEngineBase - -from utils.etl_databases import PerformanceDataSource - -router = APIRouter(tags=["Market Performance"]) - - -@router.post("/get-performance-results") -async def get_performance_results(payload: Dict[str, Any]): - executors = payload.get("executors") - data_source = PerformanceDataSource(executors) - performance_results = {} - try: - backtesting_engine = BacktestingEngineBase() - executor_info_list = data_source.executor_info_list - performance_results["results"] = backtesting_engine.summarize_results(executor_info_list ) - results = performance_results["results"] - results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 - return { - "executors": executors, - "results": performance_results["results"], - } - - except Exception as e: - return {"error": str(e)} \ No newline at end of file diff --git a/routers/market_data.py b/routers/market_data.py new file mode 100644 index 00000000..214cb636 --- /dev/null +++ b/routers/market_data.py @@ -0,0 +1,437 @@ +import asyncio +import time + +from fastapi import APIRouter, Request, HTTPException, Depends +from hummingbot.data_feed.candles_feed.data_types import HistoricalCandlesConfig, CandlesConfig +from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory + +from models.market_data import CandlesConfigRequest +from services.market_data_feed_manager import MarketDataFeedManager +from models import ( + PriceRequest, PricesResponse, FundingInfoRequest, FundingInfoResponse, + OrderBookRequest, OrderBookResponse, OrderBookLevel, + VolumeForPriceRequest, PriceForVolumeRequest, QuoteVolumeForPriceRequest, + PriceForQuoteVolumeRequest, VWAPForVolumeRequest, OrderBookQueryResult +) +from deps import get_market_data_feed_manager + +router = APIRouter(tags=["Market Data"], prefix="/market-data") + + +@router.post("/candles") +async def get_candles(request: Request, candles_config: CandlesConfigRequest): + """ + Get real-time candles data for a specific trading pair. + + This endpoint uses the MarketDataProvider to get or create a candles feed that will + automatically start and maintain real-time updates. Subsequent requests with the same + configuration will reuse the existing feed for up-to-date data. 
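+
+    Example (a minimal client sketch; host, connector, and pair values are
+    assumptions, field names follow CandlesConfigRequest):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/market-data/candles",
+        ...     json={"connector_name": "binance", "trading_pair": "BTC-USDT",
+        ...           "interval": "1m", "max_records": 100},
+        ... ).json()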
+ + Args: + request: FastAPI request object + candles_config: Configuration for the candles including connector, trading_pair, interval, and max_records + + Returns: + Real-time candles data or error message + """ + try: + market_data_feed_manager: MarketDataFeedManager = request.app.state.market_data_feed_manager + + # Get or create the candles feed (this will start it automatically and track access time) + candles_cfg = CandlesConfig( + connector=candles_config.connector_name, trading_pair=candles_config.trading_pair, + interval=candles_config.interval, max_records=candles_config.max_records) + candles_feed = market_data_feed_manager.get_candles_feed(candles_cfg) + + # Wait for the candles feed to be ready + while not candles_feed.ready: + await asyncio.sleep(0.1) + + # Get the candles dataframe + df = candles_feed.candles_df + + if df is not None and not df.empty: + # Limit to requested max_records and remove duplicates + df = df.tail(candles_config.max_records) + df = df.drop_duplicates(subset=["timestamp"], keep="last") + # Convert to dict for JSON serialization + return df.to_dict(orient="records") + else: + return {"error": "No candles data available"} + + except Exception as e: + return {"error": str(e)} + + +@router.post("/historical-candles") +async def get_historical_candles(request: Request, config: HistoricalCandlesConfig): + """ + Get historical candles data for a specific trading pair. + + Args: + config: Configuration for historical candles including connector, trading pair, interval, start and end time + + Returns: + Historical candles data or error message + """ + try: + market_data_feed_manager: MarketDataFeedManager = request.app.state.market_data_feed_manager + + # Create candles config from historical config + candles_config = CandlesConfig( + connector=config.connector_name, + trading_pair=config.trading_pair, + interval=config.interval + ) + + # Get or create the candles feed (this will track access time) + candles = market_data_feed_manager.get_candles_feed(candles_config) + + # Fetch historical candles + historical_data = await candles.get_historical_candles(config=config) + + if historical_data is not None and not historical_data.empty: + # Convert to dict for JSON serialization + return historical_data.to_dict(orient="records") + else: + return {"error": "No historical data available"} + + except Exception as e: + return {"error": str(e)} + + +@router.get("/active-feeds") +async def get_active_feeds(request: Request): + """ + Get information about currently active market data feeds. + + Args: + request: FastAPI request object to access application state + + Returns: + Dictionary with active feeds information including last access times and expiration + """ + try: + market_data_feed_manager: MarketDataFeedManager = request.app.state.market_data_feed_manager + return market_data_feed_manager.get_active_feeds_info() + except Exception as e: + return {"error": str(e)} + + +@router.get("/settings") +async def get_market_data_settings(): + """ + Get current market data settings for debugging. 
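+
+    Example (illustrative; host and the returned values are assumptions):
+        >>> import requests
+        >>> requests.get("http://localhost:8000/market-data/settings").json()
+        {'cleanup_interval': 300, 'feed_timeout': 600, 'description': '...'}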
+ + Returns: + Dictionary with current market data configuration including cleanup and timeout settings + """ + from config import settings + return { + "cleanup_interval": settings.market_data.cleanup_interval, + "feed_timeout": settings.market_data.feed_timeout, + "description": "cleanup_interval: seconds between cleanup runs, feed_timeout: seconds before unused feeds expire" + } + + +@router.get("/available-candle-connectors") +async def get_available_candle_connectors(): + """ + Get list of available connectors that support candle data feeds. + + Returns: + List of connector names that can be used for fetching candle data + """ + return list(CandlesFactory._candles_map.keys()) + + +# Enhanced Market Data Endpoints + +@router.post("/prices", response_model=PricesResponse) +async def get_prices( + request: PriceRequest, + market_data_manager: MarketDataFeedManager = Depends(get_market_data_feed_manager) +): + """ + Get current prices for specified trading pairs from a connector. + + Args: + request: Price request with connector name and trading pairs + market_data_manager: Injected market data feed manager + + Returns: + Current prices for the specified trading pairs + + Raises: + HTTPException: 500 if there's an error fetching prices + """ + try: + prices = await market_data_manager.get_prices( + request.connector_name, + request.trading_pairs + ) + + if "error" in prices: + raise HTTPException(status_code=500, detail=prices["error"]) + + return PricesResponse( + connector=request.connector_name, + prices=prices, + timestamp=time.time() + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching prices: {str(e)}") + + +@router.post("/funding-info", response_model=FundingInfoResponse) +async def get_funding_info( + request: FundingInfoRequest, + market_data_manager: MarketDataFeedManager = Depends(get_market_data_feed_manager) +): + """ + Get funding information for a perpetual trading pair. + + Args: + request: Funding info request with connector name and trading pair + market_data_manager: Injected market data feed manager + + Returns: + Funding information including rates, timestamps, and prices + + Raises: + HTTPException: 400 for non-perpetual connectors, 500 for other errors + """ + try: + if "_perpetual" not in request.connector_name.lower(): + raise HTTPException(status_code=400, detail="Funding info is only available for perpetual trading pairs.") + funding_info = await market_data_manager.get_funding_info( + request.connector_name, + request.trading_pair + ) + + if "error" in funding_info: + if "not supported" in funding_info["error"]: + raise HTTPException(status_code=400, detail=funding_info["error"]) + else: + raise HTTPException(status_code=500, detail=funding_info["error"]) + + return FundingInfoResponse(**funding_info) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching funding info: {str(e)}") + + +@router.post("/order-book", response_model=OrderBookResponse) +async def get_order_book( + request: OrderBookRequest, + market_data_manager: MarketDataFeedManager = Depends(get_market_data_feed_manager) +): + """ + Get order book snapshot with specified depth. 
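+
+    Example (a minimal client sketch; host and values are assumptions):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/market-data/order-book",
+        ...     json={"connector_name": "binance", "trading_pair": "BTC-USDT",
+        ...           "depth": 10},
+        ... ).json()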
+ + Args: + request: Order book request with connector, trading pair, and depth + market_data_manager: Injected market data feed manager + + Returns: + Order book snapshot with bids and asks + + Raises: + HTTPException: 500 if there's an error fetching order book + """ + try: + order_book_data = await market_data_manager.get_order_book_data( + request.connector_name, + request.trading_pair, + request.depth + ) + + if "error" in order_book_data: + raise HTTPException(status_code=500, detail=order_book_data["error"]) + + # Convert to response format - data comes as [price, amount] lists + bids = [OrderBookLevel(price=bid[0], amount=bid[1]) for bid in order_book_data["bids"]] + asks = [OrderBookLevel(price=ask[0], amount=ask[1]) for ask in order_book_data["asks"]] + + return OrderBookResponse( + trading_pair=order_book_data["trading_pair"], + bids=bids, + asks=asks, + timestamp=order_book_data["timestamp"] + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching order book: {str(e)}") + + +# Order Book Query Endpoints + +@router.post("/order-book/price-for-volume", response_model=OrderBookQueryResult) +async def get_price_for_volume( + request: PriceForVolumeRequest, + market_data_manager: MarketDataFeedManager = Depends(get_market_data_feed_manager) +): + """ + Get the price required to fill a specific volume on the order book. + + Args: + request: Request with connector, trading pair, volume, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with price and volume information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + volume=request.volume + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +@router.post("/order-book/volume-for-price", response_model=OrderBookQueryResult) +async def get_volume_for_price( + request: VolumeForPriceRequest, + market_data_manager: MarketDataFeedManager = Depends(get_market_data_feed_manager) +): + """ + Get the volume available at a specific price level on the order book. + + Args: + request: Request with connector, trading pair, price, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with volume information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + price=request.price + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +@router.post("/order-book/price-for-quote-volume", response_model=OrderBookQueryResult) +async def get_price_for_quote_volume( + request: PriceForQuoteVolumeRequest, + market_data_manager: MarketDataFeedManager = Depends(get_market_data_feed_manager) +): + """ + Get the price required to fill a specific quote volume on the order book. 
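+
+    Example (a minimal client sketch; host and values are assumptions):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/market-data/order-book/price-for-quote-volume",
+        ...     json={"connector_name": "binance", "trading_pair": "BTC-USDT",
+        ...           "quote_volume": 10000, "is_buy": True},
+        ... ).json()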
+ + Args: + request: Request with connector, trading pair, quote volume, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with price and volume information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + quote_volume=request.quote_volume + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +@router.post("/order-book/quote-volume-for-price", response_model=OrderBookQueryResult) +async def get_quote_volume_for_price( + request: QuoteVolumeForPriceRequest, + market_data_manager: MarketDataFeedManager = Depends(get_market_data_feed_manager) +): + """ + Get the quote volume available at a specific price level on the order book. + + Args: + request: Request with connector, trading pair, price, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with quote volume information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + quote_price=request.price + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +@router.post("/order-book/vwap-for-volume", response_model=OrderBookQueryResult) +async def get_vwap_for_volume( + request: VWAPForVolumeRequest, + market_data_manager: MarketDataFeedManager = Depends(get_market_data_feed_manager) +): + """ + Get the VWAP (Volume Weighted Average Price) for a specific volume on the order book. 
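+
+    Example (a minimal client sketch; host and values are assumptions):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/market-data/order-book/vwap-for-volume",
+        ...     json={"connector_name": "binance", "trading_pair": "ETH-USDT",
+        ...           "volume": 5.0, "is_buy": False},
+        ... ).json()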
+ + Args: + request: Request with connector, trading pair, volume, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with VWAP information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + vwap_volume=request.volume + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + diff --git a/routers/portfolio.py b/routers/portfolio.py new file mode 100644 index 00000000..4a5b69bd --- /dev/null +++ b/routers/portfolio.py @@ -0,0 +1,331 @@ +from typing import Dict, List, Optional +from datetime import datetime + +from fastapi import APIRouter, HTTPException, Depends + +from models.trading import ( + PortfolioStateFilterRequest, + PortfolioHistoryFilterRequest, + PortfolioDistributionFilterRequest, + AccountsDistributionFilterRequest +) +from services.accounts_service import AccountsService +from deps import get_accounts_service +from models import PaginatedResponse + +router = APIRouter(tags=["Portfolio"], prefix="/portfolio") + + +@router.post("/state", response_model=Dict[str, Dict[str, List[Dict]]]) +async def get_portfolio_state( + filter_request: PortfolioStateFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get the current state of all or filtered accounts portfolio. + + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Dict containing account states with connector balances and token information + """ + await accounts_service.update_account_state() + all_states = accounts_service.get_accounts_state() + + # Apply account name filter first + if filter_request.account_names: + filtered_states = {} + for account_name in filter_request.account_names: + if account_name in all_states: + filtered_states[account_name] = all_states[account_name] + all_states = filtered_states + + # Apply connector filter if specified + if filter_request.connector_names: + for account_name, account_data in all_states.items(): + # Filter connectors directly (they are at the top level of account_data) + filtered_connectors = {} + for connector_name in filter_request.connector_names: + if connector_name in account_data: + filtered_connectors[connector_name] = account_data[connector_name] + # Replace account_data with only filtered connectors + all_states[account_name] = filtered_connectors + + return all_states + + +@router.post("/history", response_model=PaginatedResponse) +async def get_portfolio_history( + filter_request: PortfolioHistoryFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get the historical state of all or filtered accounts portfolio with pagination. 
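+
+    Example (a minimal client sketch; host, account name, and timestamps are
+    assumptions; start_time/end_time are epoch milliseconds, as converted below):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/portfolio/history",
+        ...     json={"account_names": ["master_account"],
+        ...           "start_time": 1700000000000, "end_time": 1700086400000,
+        ...           "limit": 50},
+        ... ).json()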
+ + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Paginated response with historical portfolio data + """ + try: + # Convert integer timestamps to datetime objects + start_time_dt = datetime.fromtimestamp(filter_request.start_time / 1000) if filter_request.start_time else None + end_time_dt = datetime.fromtimestamp(filter_request.end_time / 1000) if filter_request.end_time else None + + if not filter_request.account_names: + # Get history for all accounts + data, next_cursor, has_more = await accounts_service.load_account_state_history( + limit=filter_request.limit, + cursor=filter_request.cursor, + start_time=start_time_dt, + end_time=end_time_dt + ) + else: + # Get history for specific accounts - need to aggregate + all_data = [] + for account_name in filter_request.account_names: + acc_data, _, _ = await accounts_service.get_account_state_history( + account_name=account_name, + limit=filter_request.limit, + cursor=filter_request.cursor, + start_time=start_time_dt, + end_time=end_time_dt + ) + all_data.extend(acc_data) + + # Sort by timestamp and apply pagination + all_data.sort(key=lambda x: x.get("timestamp", ""), reverse=True) + + # Apply limit + data = all_data[:filter_request.limit] + has_more = len(all_data) > filter_request.limit + next_cursor = data[-1]["timestamp"] if data and has_more else None + + # Apply connector filter to the data if specified + if filter_request.connector_names: + for item in data: + for account_name, account_data in item.items(): + if isinstance(account_data, dict) and "connectors" in account_data: + filtered_connectors = {} + for connector_name in filter_request.connector_names: + if connector_name in account_data["connectors"]: + filtered_connectors[connector_name] = account_data["connectors"][connector_name] + account_data["connectors"] = filtered_connectors + + return PaginatedResponse( + data=data, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "current_cursor": filter_request.cursor, + "filters": { + "account_names": filter_request.account_names, + "connector_names": filter_request.connector_names, + "start_time": filter_request.start_time, + "end_time": filter_request.end_time + } + } + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/distribution") +async def get_portfolio_distribution( + filter_request: PortfolioDistributionFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get portfolio distribution by tokens with percentages across all or filtered accounts. 
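+
+    Example (a minimal client sketch; host, account, and connector names are
+    assumptions):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/portfolio/distribution",
+        ...     json={"account_names": ["master_account"],
+        ...           "connector_names": ["binance", "binance_perpetual"]},
+        ... ).json()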
+ + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Dictionary with token distribution including percentages, values, and breakdown by accounts/connectors + """ + if not filter_request.account_names: + # Get distribution for all accounts + distribution = accounts_service.get_portfolio_distribution() + elif len(filter_request.account_names) == 1: + # Single account - use existing method + distribution = accounts_service.get_portfolio_distribution(filter_request.account_names[0]) + else: + # Multiple accounts - need to aggregate + aggregated_distribution = { + "tokens": {}, + "total_value": 0, + "token_count": 0, + "accounts": {} + } + + for account_name in filter_request.account_names: + account_dist = accounts_service.get_portfolio_distribution(account_name) + + # Skip if account doesn't exist or has error + if account_dist.get("error") or account_dist.get("token_count", 0) == 0: + continue + + # Aggregate token data + for token, token_data in account_dist.get("tokens", {}).items(): + if token not in aggregated_distribution["tokens"]: + aggregated_distribution["tokens"][token] = { + "token": token, + "value": 0, + "percentage": 0, + "accounts": {} + } + + aggregated_distribution["tokens"][token]["value"] += token_data.get("value", 0) + + # Copy account-specific data + for acc_name, acc_data in token_data.get("accounts", {}).items(): + aggregated_distribution["tokens"][token]["accounts"][acc_name] = acc_data + + aggregated_distribution["total_value"] += account_dist.get("total_value", 0) + aggregated_distribution["accounts"][account_name] = account_dist.get("accounts", {}).get(account_name, {}) + + # Recalculate percentages + total_value = aggregated_distribution["total_value"] + if total_value > 0: + for token_data in aggregated_distribution["tokens"].values(): + token_data["percentage"] = (token_data["value"] / total_value) * 100 + + aggregated_distribution["token_count"] = len(aggregated_distribution["tokens"]) + + distribution = aggregated_distribution + + # Apply connector filter if specified + if filter_request.connector_names: + filtered_distribution = [] + filtered_total_value = 0 + + for token_data in distribution.get("distribution", []): + filtered_token = { + "token": token_data["token"], + "total_value": 0, + "total_units": 0, + "percentage": 0, + "accounts": {} + } + + # Filter each account's connectors + for account_name, account_data in token_data.get("accounts", {}).items(): + if "connectors" in account_data: + filtered_connectors = {} + account_value = 0 + account_units = 0 + + # Only include specified connectors + for connector_name in filter_request.connector_names: + if connector_name in account_data["connectors"]: + filtered_connectors[connector_name] = account_data["connectors"][connector_name] + account_value += account_data["connectors"][connector_name].get("value", 0) + account_units += account_data["connectors"][connector_name].get("units", 0) + + # Only include account if it has matching connectors + if filtered_connectors: + filtered_token["accounts"][account_name] = { + "value": round(account_value, 6), + "units": account_units, + "percentage": 0, # Will be recalculated later + "connectors": filtered_connectors + } + + filtered_token["total_value"] += account_value + filtered_token["total_units"] += account_units + + # Only include token if it has values after filtering + if filtered_token["total_value"] > 0: + filtered_distribution.append(filtered_token) + filtered_total_value += filtered_token["total_value"] + + # Recalculate 
percentages after filtering + if filtered_total_value > 0: + for token_data in filtered_distribution: + token_data["percentage"] = round((token_data["total_value"] / filtered_total_value) * 100, 4) + # Update account percentages + for account_data in token_data["accounts"].values(): + account_data["percentage"] = round((account_data["value"] / filtered_total_value) * 100, 4) + + # Sort by value (descending) + filtered_distribution.sort(key=lambda x: x["total_value"], reverse=True) + + # Update the distribution + distribution = { + "total_portfolio_value": round(filtered_total_value, 6), + "token_count": len(filtered_distribution), + "distribution": filtered_distribution, + "account_filter": distribution.get("account_filter", "filtered") + } + + return distribution + + +@router.post("/accounts-distribution") +async def get_accounts_distribution( + filter_request: AccountsDistributionFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get portfolio distribution by accounts with percentages. + + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Dictionary with account distribution including percentages, values, and breakdown by connectors + """ + all_distribution = accounts_service.get_account_distribution() + + # If no filter, return all accounts + if not filter_request.account_names: + return all_distribution + + # Filter the distribution by requested accounts + filtered_distribution = { + "accounts": {}, + "total_value": 0, + "account_count": 0 + } + + for account_name in filter_request.account_names: + if account_name in all_distribution.get("accounts", {}): + filtered_distribution["accounts"][account_name] = all_distribution["accounts"][account_name] + filtered_distribution["total_value"] += all_distribution["accounts"][account_name].get("total_value", 0) + + # Apply connector filter if specified + if filter_request.connector_names: + for account_name, account_data in filtered_distribution["accounts"].items(): + if "connectors" in account_data: + filtered_connectors = {} + for connector_name in filter_request.connector_names: + if connector_name in account_data["connectors"]: + filtered_connectors[connector_name] = account_data["connectors"][connector_name] + account_data["connectors"] = filtered_connectors + + # Recalculate account total after connector filtering + new_total = sum( + conn_data.get("total_balance_in_usd", 0) + for conn_data in filtered_connectors.values() + ) + account_data["total_value"] = new_total + + # Recalculate total_value after connector filtering + filtered_distribution["total_value"] = sum( + acc_data.get("total_value", 0) + for acc_data in filtered_distribution["accounts"].values() + ) + + # Recalculate percentages + total_value = filtered_distribution["total_value"] + if total_value > 0: + for account_data in filtered_distribution["accounts"].values(): + account_data["percentage"] = (account_data.get("total_value", 0) / total_value) * 100 + + filtered_distribution["account_count"] = len(filtered_distribution["accounts"]) + + return filtered_distribution \ No newline at end of file diff --git a/routers/scripts.py b/routers/scripts.py new file mode 100644 index 00000000..c4c18409 --- /dev/null +++ b/routers/scripts.py @@ -0,0 +1,214 @@ +import json +import yaml +from typing import Dict, List + +from fastapi import APIRouter, HTTPException +from starlette import status + +from models import Script, ScriptConfig +from utils.file_system import fs_util + +router = APIRouter(tags=["Scripts"], 
prefix="/scripts") + + +@router.get("/", response_model=List[str]) +async def list_scripts(): + """ + List all available scripts. + + Returns: + List of script names (without .py extension) + """ + return [f.replace('.py', '') for f in fs_util.list_files('scripts') if f.endswith('.py')] + + +# Script Configuration endpoints (must come before script name routes) +@router.get("/configs/", response_model=List[Dict]) +async def list_script_configs(): + """ + List all script configurations with metadata. + + Returns: + List of script configuration objects with name, script_file_name, and other metadata + """ + try: + config_files = [f for f in fs_util.list_files('conf/scripts') if f.endswith('.yml')] + configs = [] + + for config_file in config_files: + config_name = config_file.replace('.yml', '') + try: + config = fs_util.read_yaml_file(f"conf/scripts/{config_file}") + configs.append({ + "config_name": config_name, + "script_file_name": config.get("script_file_name", "unknown"), + "controllers_config": config.get("controllers_config", []), + "candles_config": config.get("candles_config", []), + "markets": config.get("markets", {}) + }) + except Exception as e: + # If config is malformed, still include it with basic info + configs.append({ + "config_name": config_name, + "script_file_name": "error", + "error": str(e) + }) + + return configs + except FileNotFoundError: + return [] + + +@router.get("/configs/{config_name}", response_model=Dict) +async def get_script_config(config_name: str): + """ + Get script configuration by config name. + + Args: + config_name: Name of the configuration file to retrieve + + Returns: + Dictionary with script configuration + + Raises: + HTTPException: 404 if configuration not found + """ + try: + config = fs_util.read_yaml_file(f"conf/scripts/{config_name}.yml") + return config + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Configuration '{config_name}' not found") + + +@router.post("/configs/{config_name}", status_code=status.HTTP_201_CREATED) +async def create_or_update_script_config(config_name: str, config: Dict): + """ + Create or update script configuration. + + Args: + config_name: Name of the configuration file + config: Configuration dictionary to save + + Returns: + Success message when configuration is saved + + Raises: + HTTPException: 400 if save error occurs + """ + try: + yaml_content = yaml.dump(config, default_flow_style=False) + fs_util.add_file('conf/scripts', f"{config_name}.yml", yaml_content, override=True) + return {"message": f"Configuration '{config_name}' saved successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/configs/{config_name}") +async def delete_script_config(config_name: str): + """ + Delete script configuration. + + Args: + config_name: Name of the configuration file to delete + + Returns: + Success message when configuration is deleted + + Raises: + HTTPException: 404 if configuration not found + """ + try: + fs_util.delete_file('conf/scripts', f"{config_name}.yml") + return {"message": f"Configuration '{config_name}' deleted successfully"} + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Configuration '{config_name}' not found") + + +@router.get("/{script_name}", response_model=Dict[str, str]) +async def get_script(script_name: str): + """ + Get script content by name. 
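+
+    Example (illustrative; host and script name are assumptions):
+        >>> import requests
+        >>> requests.get("http://localhost:8000/scripts/simple_pmm").json()
+        {'name': 'simple_pmm', 'content': '...'}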
+ + Args: + script_name: Name of the script to retrieve + + Returns: + Dictionary with script name and content + + Raises: + HTTPException: 404 if script not found + """ + try: + content = fs_util.read_file(f"scripts/{script_name}.py") + return { + "name": script_name, + "content": content + } + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Script '{script_name}' not found") + + +@router.post("/{script_name}", status_code=status.HTTP_201_CREATED) +async def create_or_update_script(script_name: str, script: Script): + """ + Create or update a script. + + Args: + script_name: Name of the script (from URL path) + script: Script object with content + + Returns: + Success message when script is saved + + Raises: + HTTPException: 400 if save error occurs + """ + try: + fs_util.add_file('scripts', f"{script_name}.py", script.content, override=True) + return {"message": f"Script '{script_name}' saved successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/{script_name}") +async def delete_script(script_name: str): + """ + Delete a script. + + Args: + script_name: Name of the script to delete + + Returns: + Success message when script is deleted + + Raises: + HTTPException: 404 if script not found + """ + try: + fs_util.delete_file('scripts', f"{script_name}.py") + return {"message": f"Script '{script_name}' deleted successfully"} + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Script '{script_name}' not found") + + +@router.get("/{script_name}/config/template", response_model=Dict) +async def get_script_config_template(script_name: str): + """ + Get script configuration template with default values. + + Args: + script_name: Name of the script to get template for + + Returns: + Dictionary with configuration template and default values + + Raises: + HTTPException: 404 if script configuration class not found + """ + config_class = fs_util.load_script_config_class(script_name) + if config_class is None: + raise HTTPException(status_code=404, detail=f"Script configuration class for '{script_name}' not found") + + # Extract fields and default values + config_fields = {name: field.default for name, field in config_class.model_fields.items()} + return json.loads(json.dumps(config_fields, default=str)) \ No newline at end of file diff --git a/routers/trading.py b/routers/trading.py new file mode 100644 index 00000000..adf8e700 --- /dev/null +++ b/routers/trading.py @@ -0,0 +1,752 @@ +import logging +import math + +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException + +# Create module-specific logger +logger = logging.getLogger(__name__) +from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, TradeType +from pydantic import BaseModel +from starlette import status + +from deps import get_accounts_service, get_market_data_feed_manager +from models import ( + ActiveOrderFilterRequest, + FundingPaymentFilterRequest, + OrderFilterRequest, + PaginatedResponse, + PositionFilterRequest, + TradeFilterRequest, + TradeRequest, + TradeResponse, +) +from models.accounts import LeverageRequest, PositionModeRequest +from services.accounts_service import AccountsService + +router = APIRouter(tags=["Trading"], prefix="/trading") + + +# Trade Execution +@router.post("/orders", response_model=TradeResponse, status_code=status.HTTP_201_CREATED) +async def place_trade( + trade_request: TradeRequest, + accounts_service: AccountsService = 
Depends(get_accounts_service),
+    market_data_manager=Depends(get_market_data_feed_manager),
+):
+    """
+    Place a buy or sell order using a specific account and connector.
+
+    Args:
+        trade_request: Trading request with account, connector, trading pair, type, amount, etc.
+        accounts_service: Injected accounts service
+        market_data_manager: Market data manager for price fetching
+
+    Returns:
+        TradeResponse with order ID and trading details
+
+    Raises:
+        HTTPException: 400 for invalid parameters, 404 for account/connector not found, 500 for trade execution errors
+    """
+    try:
+        # Convert string names to enum instances
+        trade_type_enum = TradeType[trade_request.trade_type]
+        order_type_enum = OrderType[trade_request.order_type]
+        position_action_enum = PositionAction[trade_request.position_action]
+
+        order_id = await accounts_service.place_trade(
+            account_name=trade_request.account_name,
+            connector_name=trade_request.connector_name,
+            trading_pair=trade_request.trading_pair,
+            trade_type=trade_type_enum,
+            amount=trade_request.amount,
+            order_type=order_type_enum,
+            price=trade_request.price,
+            position_action=position_action_enum,
+            market_data_manager=market_data_manager,
+        )
+
+        return TradeResponse(
+            order_id=order_id,
+            account_name=trade_request.account_name,
+            connector_name=trade_request.connector_name,
+            trading_pair=trade_request.trading_pair,
+            trade_type=trade_request.trade_type,
+            amount=trade_request.amount,
+            order_type=trade_request.order_type,
+            price=trade_request.price,
+            status="submitted",
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Unexpected error placing trade: {str(e)}")
+
+
+@router.post("/{account_name}/{connector_name}/orders/{client_order_id}/cancel")
+async def cancel_order(
+    account_name: str,
+    connector_name: str,
+    client_order_id: str,
+    accounts_service: AccountsService = Depends(get_accounts_service),
+):
+    """
+    Cancel a specific order by its client order ID.
+
+    Args:
+        account_name: Name of the account
+        connector_name: Name of the connector
+        client_order_id: Client order ID to cancel
+        accounts_service: Injected accounts service
+
+    Returns:
+        Success message with cancelled order ID
+
+    Raises:
+        HTTPException: 404 if account/connector not found, 500 for cancellation errors
+    """
+    try:
+        cancelled_order_id = await accounts_service.cancel_order(
+            account_name=account_name, connector_name=connector_name, client_order_id=client_order_id
+        )
+        return {"message": f"Order cancellation initiated for {cancelled_order_id}"}
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error cancelling order: {str(e)}")
+
+
+@router.post("/positions", response_model=PaginatedResponse)
+async def get_positions(filter_request: PositionFilterRequest, accounts_service: AccountsService = Depends(get_accounts_service)):
+    """
+    Get current positions across all or filtered perpetual connectors.
+
+    This endpoint fetches real-time position data directly from the connectors,
+    including unrealized PnL, leverage, funding fees, and margin information.
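+
+    Example (a minimal client sketch; host, account, and connector names are
+    assumptions):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/trading/positions",
+        ...     json={"account_names": ["master_account"],
+        ...           "connector_names": ["binance_perpetual"], "limit": 20},
+        ... ).json()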
+ + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Paginated response with position data and pagination metadata + + Raises: + HTTPException: 500 if there's an error fetching positions + """ + try: + all_positions = [] + all_connectors = accounts_service.connector_manager.get_all_connectors() + + # Filter accounts + accounts_to_check = filter_request.account_names if filter_request.account_names else list(all_connectors.keys()) + + for account_name in accounts_to_check: + if account_name in all_connectors: + # Filter connectors + connectors_to_check = ( + filter_request.connector_names + if filter_request.connector_names + else list(all_connectors[account_name].keys()) + ) + + for connector_name in connectors_to_check: + # Only fetch positions from perpetual connectors + if connector_name in all_connectors[account_name] and "_perpetual" in connector_name: + try: + positions = await accounts_service.get_account_positions(account_name, connector_name) + # Add cursor-friendly identifier to each position + for position in positions: + position["_cursor_id"] = f"{account_name}:{connector_name}:{position.get('trading_pair', '')}" + all_positions.extend(positions) + except Exception as e: + # Log error but continue with other connectors + import logging + + logger.warning(f"Failed to get positions for {account_name}/{connector_name}: {e}") + + # Sort by cursor_id for consistent pagination + all_positions.sort(key=lambda x: x.get("_cursor_id", "")) + + # Apply cursor-based pagination + start_index = 0 + if filter_request.cursor: + # Find the position after the cursor + for i, position in enumerate(all_positions): + if position.get("_cursor_id") == filter_request.cursor: + start_index = i + 1 + break + + # Get page of results + end_index = start_index + filter_request.limit + page_positions = all_positions[start_index:end_index] + + # Determine next cursor and has_more + has_more = end_index < len(all_positions) + next_cursor = page_positions[-1].get("_cursor_id") if page_positions and has_more else None + + # Clean up cursor_id from response data + for position in page_positions: + position.pop("_cursor_id", None) + + return PaginatedResponse( + data=page_positions, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "total_count": len(all_positions), + }, + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching positions: {str(e)}") + + +# Active Orders Management - Real-time from connectors +@router.post("/orders/active", response_model=PaginatedResponse) +async def get_active_orders( + filter_request: ActiveOrderFilterRequest, accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get active (in-flight) orders across all or filtered accounts and connectors. + + This endpoint fetches real-time active orders directly from the connectors' in_flight_orders property, + providing current order status, fill amounts, and other live order data. 
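+
+    Example (a minimal client sketch; host, account name, and pair are assumptions):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/trading/orders/active",
+        ...     json={"account_names": ["master_account"],
+        ...           "trading_pairs": ["BTC-USDT"], "limit": 50},
+        ... ).json()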
+ + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Paginated response with active order data and pagination metadata + + Raises: + HTTPException: 500 if there's an error fetching orders + """ + try: + all_active_orders = [] + all_connectors = accounts_service.connector_manager.get_all_connectors() + + # Use filter request values + accounts_to_check = filter_request.account_names if filter_request.account_names else list(all_connectors.keys()) + + for account_name in accounts_to_check: + if account_name in all_connectors: + # Filter connectors + connectors_to_check = ( + filter_request.connector_names + if filter_request.connector_names + else list(all_connectors[account_name].keys()) + ) + + for connector_name in connectors_to_check: + if connector_name in all_connectors[account_name]: + try: + connector = all_connectors[account_name][connector_name] + # Get in-flight orders directly from connector + in_flight_orders = connector.in_flight_orders + + for client_order_id, order in in_flight_orders.items(): + # Apply trading pair filter if specified + if filter_request.trading_pairs and order.trading_pair not in filter_request.trading_pairs: + continue + + # Convert to standardized format to match orders search response + standardized_order = _standardize_in_flight_order_response(order, account_name, connector_name) + standardized_order["_cursor_id"] = client_order_id # Use client_order_id as cursor + all_active_orders.append(standardized_order) + + except Exception as e: + # Log error but continue with other connectors + import logging + + logger.warning(f"Failed to get active orders for {account_name}/{connector_name}: {e}") + + # Sort by cursor_id for consistent pagination + all_active_orders.sort(key=lambda x: x.get("_cursor_id", "")) + + # Apply cursor-based pagination + start_index = 0 + if filter_request.cursor: + # Find the order after the cursor + for i, order in enumerate(all_active_orders): + if order.get("_cursor_id") == filter_request.cursor: + start_index = i + 1 + break + + # Get page of results + end_index = start_index + filter_request.limit + page_orders = all_active_orders[start_index:end_index] + + # Determine next cursor and has_more + has_more = end_index < len(all_active_orders) + next_cursor = page_orders[-1].get("_cursor_id") if page_orders and has_more else None + + # Clean up cursor_id from response data + for order in page_orders: + order.pop("_cursor_id", None) + + return PaginatedResponse( + data=page_orders, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "total_count": len(all_active_orders), + }, + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching active orders: {str(e)}") + + +# Historical Order Management - From registry/database +@router.post("/orders/search", response_model=PaginatedResponse) +async def get_orders(filter_request: OrderFilterRequest, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Get historical order data across all or filtered accounts from the database/registry. 
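+
+    Example (a minimal client sketch; host, account name, and the status value
+    are assumptions):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/trading/orders/search",
+        ...     json={"account_names": ["master_account"],
+        ...           "connector_names": ["binance"], "status": "FILLED",
+        ...           "limit": 100},
+        ... ).json()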
+ + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Paginated response with historical order data and pagination metadata + """ + try: + all_orders = [] + + # Determine which accounts to query + if filter_request.account_names: + accounts_to_check = filter_request.account_names + else: + # Get all accounts + all_connectors = accounts_service.connector_manager.get_all_connectors() + accounts_to_check = list(all_connectors.keys()) + + # Collect orders from all specified accounts + for account_name in accounts_to_check: + try: + orders = await accounts_service.get_orders( + account_name=account_name, + connector_name=( + filter_request.connector_names[0] + if filter_request.connector_names and len(filter_request.connector_names) == 1 + else None + ), + trading_pair=( + filter_request.trading_pairs[0] + if filter_request.trading_pairs and len(filter_request.trading_pairs) == 1 + else None + ), + status=filter_request.status, + start_time=filter_request.start_time, + end_time=filter_request.end_time, + limit=filter_request.limit * 2, # Get more for filtering + offset=0, + ) + # Add cursor-friendly identifier to each order + for order in orders: + order["_cursor_id"] = f"{order.get('timestamp', 0)}:{order.get('client_order_id', '')}" + all_orders.extend(orders) + except Exception as e: + # Log error but continue with other accounts + import logging + + logger.warning(f"Failed to get orders for {account_name}: {e}") + + # Apply filters for multiple values + if filter_request.connector_names and len(filter_request.connector_names) > 1: + all_orders = [order for order in all_orders if order.get("connector_name") in filter_request.connector_names] + if filter_request.trading_pairs and len(filter_request.trading_pairs) > 1: + all_orders = [order for order in all_orders if order.get("trading_pair") in filter_request.trading_pairs] + + # Sort by timestamp (most recent first) and then by cursor_id for consistency + all_orders.sort(key=lambda x: (x.get("timestamp", 0), x.get("_cursor_id", "")), reverse=True) + + # Apply cursor-based pagination + start_index = 0 + if filter_request.cursor: + # Find the order after the cursor + for i, order in enumerate(all_orders): + if order.get("_cursor_id") == filter_request.cursor: + start_index = i + 1 + break + + # Get page of results + end_index = start_index + filter_request.limit + page_orders = all_orders[start_index:end_index] + + # Determine next cursor and has_more + has_more = end_index < len(all_orders) + next_cursor = page_orders[-1].get("_cursor_id") if page_orders and has_more else None + + # Clean up cursor_id from response data + for order in page_orders: + order.pop("_cursor_id", None) + + return PaginatedResponse( + data=page_orders, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "total_count": len(all_orders), + }, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching orders: {str(e)}") + + +# Trade History +@router.post("/trades", response_model=PaginatedResponse) +async def get_trades(filter_request: TradeFilterRequest, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Get trade history across all or filtered accounts with complex filtering. 
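+
+    Example (a minimal client sketch; host and account name are assumptions;
+    trade_types follow the TradeType names used in this router, e.g. "BUY"/"SELL"):
+        >>> import requests
+        >>> requests.post(
+        ...     "http://localhost:8000/trading/trades",
+        ...     json={"account_names": ["master_account"],
+        ...           "trading_pairs": ["BTC-USDT"], "trade_types": ["BUY"],
+        ...           "limit": 100},
+        ... ).json()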
+
+    Args:
+        filter_request: JSON payload with filtering criteria
+
+    Returns:
+        Paginated response with trade data and pagination metadata
+    """
+    try:
+        all_trades = []
+
+        # Determine which accounts to query
+        if filter_request.account_names:
+            accounts_to_check = filter_request.account_names
+        else:
+            # Get all accounts
+            all_connectors = accounts_service.connector_manager.get_all_connectors()
+            accounts_to_check = list(all_connectors.keys())
+
+        # Collect trades from all specified accounts
+        for account_name in accounts_to_check:
+            try:
+                trades = await accounts_service.get_trades(
+                    account_name=account_name,
+                    connector_name=(
+                        filter_request.connector_names[0]
+                        if filter_request.connector_names and len(filter_request.connector_names) == 1
+                        else None
+                    ),
+                    trading_pair=(
+                        filter_request.trading_pairs[0]
+                        if filter_request.trading_pairs and len(filter_request.trading_pairs) == 1
+                        else None
+                    ),
+                    trade_type=(
+                        filter_request.trade_types[0]
+                        if filter_request.trade_types and len(filter_request.trade_types) == 1
+                        else None
+                    ),
+                    start_time=filter_request.start_time,
+                    end_time=filter_request.end_time,
+                    limit=filter_request.limit * 2,  # Get more for filtering
+                    offset=0,
+                )
+                # Add cursor-friendly identifier to each trade
+                for trade in trades:
+                    trade["_cursor_id"] = f"{trade.get('timestamp', 0)}:{trade.get('trade_id', '')}"
+                all_trades.extend(trades)
+            except Exception as e:
+                # Log error but continue with other accounts
+                import logging
+
+                logging.warning(f"Failed to get trades for {account_name}: {e}")
+
+        # Apply filters for multiple values
+        if filter_request.connector_names and len(filter_request.connector_names) > 1:
+            all_trades = [trade for trade in all_trades if trade.get("connector_name") in filter_request.connector_names]
+        if filter_request.trading_pairs and len(filter_request.trading_pairs) > 1:
+            all_trades = [trade for trade in all_trades if trade.get("trading_pair") in filter_request.trading_pairs]
+        if filter_request.trade_types and len(filter_request.trade_types) > 1:
+            all_trades = [trade for trade in all_trades if trade.get("trade_type") in filter_request.trade_types]
+
+        # Sort by timestamp (most recent first) and then by cursor_id for consistency
+        all_trades.sort(key=lambda x: (x.get("timestamp", 0), x.get("_cursor_id", "")), reverse=True)
+
+        # Apply cursor-based pagination
+        start_index = 0
+        if filter_request.cursor:
+            # Find the trade after the cursor
+            for i, trade in enumerate(all_trades):
+                if trade.get("_cursor_id") == filter_request.cursor:
+                    start_index = i + 1
+                    break
+
+        # Get page of results
+        end_index = start_index + filter_request.limit
+        page_trades = all_trades[start_index:end_index]
+
+        # Determine next cursor and has_more
+        has_more = end_index < len(all_trades)
+        next_cursor = page_trades[-1].get("_cursor_id") if page_trades and has_more else None
+
+        # Clean up cursor_id from response data
+        for trade in page_trades:
+            trade.pop("_cursor_id", None)
+
+        return PaginatedResponse(
+            data=page_trades,
+            pagination={
+                "limit": filter_request.limit,
+                "has_more": has_more,
+                "next_cursor": next_cursor,
+                "total_count": len(all_trades),
+            },
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error fetching trades: {str(e)}")
+
+
+@router.post("/{account_name}/{connector_name}/position-mode")
+async def set_position_mode(
+    account_name: str,
+    connector_name: str,
+    request: PositionModeRequest,
+    accounts_service: AccountsService = Depends(get_accounts_service),
+):
+    """
+    Set position mode for a perpetual connector.
+
+    Args:
+        account_name: Name of the account
+        connector_name: Name of the perpetual connector
+        request: Position mode request with the mode to set (HEDGE or ONEWAY)
+
+    Returns:
+        Success message with status
+
+    Raises:
+        HTTPException: 400 if not a perpetual connector or invalid position mode
+    """
+    try:
+        # Convert string to PositionMode enum
+        mode = PositionMode[request.position_mode.upper()]
+        result = await accounts_service.set_position_mode(account_name, connector_name, mode)
+        return result
+    except KeyError:
+        raise HTTPException(
+            status_code=400, detail=f"Invalid position mode '{request.position_mode}'. Must be 'HEDGE' or 'ONEWAY'"
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get("/{account_name}/{connector_name}/position-mode")
+async def get_position_mode(
+    account_name: str, connector_name: str, accounts_service: AccountsService = Depends(get_accounts_service)
+):
+    """
+    Get current position mode for a perpetual connector.
+
+    Args:
+        account_name: Name of the account
+        connector_name: Name of the perpetual connector
+
+    Returns:
+        Dictionary with current position mode, connector name, and account name
+
+    Raises:
+        HTTPException: 400 if not a perpetual connector
+    """
+    try:
+        result = await accounts_service.get_position_mode(account_name, connector_name)
+        return result
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/{account_name}/{connector_name}/leverage")
+async def set_leverage(
+    account_name: str,
+    connector_name: str,
+    request: LeverageRequest,
+    accounts_service: AccountsService = Depends(get_accounts_service),
+):
+    """
+    Set leverage for a specific trading pair on a perpetual connector.
+
+    Args:
+        account_name: Name of the account
+        connector_name: Name of the perpetual connector
+        request: Leverage request with trading pair and leverage value
+        accounts_service: Injected accounts service
+
+    Returns:
+        Dictionary with success status and message
+
+    Raises:
+        HTTPException: 400 for invalid parameters or non-perpetual connector, 404 for account/connector not found, 500 for execution errors
+    """
+    try:
+        result = await accounts_service.set_leverage(
+            account_name=account_name, connector_name=connector_name, trading_pair=request.trading_pair, leverage=request.leverage
+        )
+        return result
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Unexpected error setting leverage: {str(e)}")
+
+
+@router.post("/funding-payments", response_model=PaginatedResponse)
+async def get_funding_payments(
+    filter_request: FundingPaymentFilterRequest, accounts_service: AccountsService = Depends(get_accounts_service)
+):
+    """
+    Get funding payment history across all or filtered perpetual connectors.
+
+    This endpoint retrieves historical funding payment records including
+    funding rates, payment amounts, and position data at time of payment.
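+
+    Example (illustrative request; fields follow the FundingPaymentFilterRequest model):
+        POST /funding-payments
+        {"account_names": ["master_account"], "connector_names": ["binance_perpetual"], "trading_pair": "BTC-USDT"}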
+
+    Args:
+        filter_request: JSON payload with filtering criteria
+
+    Returns:
+        Paginated response with funding payment data and pagination metadata
+
+    Raises:
+        HTTPException: 500 if there's an error fetching funding payments
+    """
+    try:
+        all_funding_payments = []
+        all_connectors = accounts_service.connector_manager.get_all_connectors()
+
+        # Filter accounts
+        accounts_to_check = filter_request.account_names if filter_request.account_names else list(all_connectors.keys())
+
+        for account_name in accounts_to_check:
+            if account_name in all_connectors:
+                # Filter connectors
+                connectors_to_check = (
+                    filter_request.connector_names
+                    if filter_request.connector_names
+                    else list(all_connectors[account_name].keys())
+                )
+
+                for connector_name in connectors_to_check:
+                    # Only fetch funding payments from perpetual connectors
+                    if connector_name in all_connectors[account_name] and "_perpetual" in connector_name:
+                        try:
+                            payments = await accounts_service.get_funding_payments(
+                                account_name=account_name,
+                                connector_name=connector_name,
+                                trading_pair=filter_request.trading_pair,
+                                limit=filter_request.limit * 2,  # Get more for pagination
+                            )
+                            # Add cursor-friendly identifier to each payment
+                            for payment in payments:
+                                payment["_cursor_id"] = (
+                                    f"{account_name}:{connector_name}:{payment.get('timestamp', '')}:{payment.get('trading_pair', '')}"
+                                )
+                            all_funding_payments.extend(payments)
+                        except Exception as e:
+                            # Log error but continue with other connectors
+                            import logging
+
+                            logging.warning(f"Failed to get funding payments for {account_name}/{connector_name}: {e}")
+
+        # Sort by timestamp (most recent first) and then by cursor_id for consistency
+        all_funding_payments.sort(key=lambda x: (x.get("timestamp", ""), x.get("_cursor_id", "")), reverse=True)
+
+        # Apply cursor-based pagination
+        start_index = 0
+        if filter_request.cursor:
+            # Find the payment after the cursor
+            for i, payment in enumerate(all_funding_payments):
+                if payment.get("_cursor_id") == filter_request.cursor:
+                    start_index = i + 1
+                    break
+
+        # Get page of results
+        end_index = start_index + filter_request.limit
+        page_payments = all_funding_payments[start_index:end_index]
+
+        # Determine next cursor and has_more
+        has_more = end_index < len(all_funding_payments)
+        next_cursor = page_payments[-1].get("_cursor_id") if page_payments and has_more else None
+
+        # Clean up cursor_id from response data
+        for payment in page_payments:
+            payment.pop("_cursor_id", None)
+
+        return PaginatedResponse(
+            data=page_payments,
+            pagination={
+                "limit": filter_request.limit,
+                "has_more": has_more,
+                "next_cursor": next_cursor,
+                "total_count": len(all_funding_payments),
+            },
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error fetching funding payments: {str(e)}")
+
+
+def _standardize_in_flight_order_response(order, account_name: str, connector_name: str) -> dict:
+    """
+    Convert a Hummingbot InFlightOrder to standardized format matching the orders search response.
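+
+    Example (illustrative output; values are hypothetical):
+        {"order_id": "buy-BTC-USDT-1700000000", "status": "OPEN", "trading_pair": "BTC-USDT",
+         "amount": 0.001, "price": 50000.0, "filled_amount": 0.0, ...}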
+ + Args: + order: Hummingbot InFlightOrder instance + account_name: Name of the account + connector_name: Name of the connector + + Returns: + Dictionary with standardized order format + """ + # Map OrderState to status strings + from hummingbot.core.data_type.in_flight_order import OrderState + + status_mapping = { + OrderState.PENDING_CREATE: "SUBMITTED", + OrderState.OPEN: "OPEN", + OrderState.PENDING_CANCEL: "OPEN", # Still open until cancelled + OrderState.CANCELED: "CANCELLED", + OrderState.PARTIALLY_FILLED: "PARTIALLY_FILLED", + OrderState.FILLED: "FILLED", + OrderState.FAILED: "FAILED", + OrderState.PENDING_APPROVAL: "SUBMITTED", + OrderState.APPROVED: "SUBMITTED", + OrderState.CREATED: "SUBMITTED", + OrderState.COMPLETED: "FILLED", + } + + # Get status string + status = status_mapping.get(order.current_state, "SUBMITTED") + + # Convert timestamps to ISO format + from datetime import datetime, timezone + + created_at = datetime.fromtimestamp(order.creation_timestamp, tz=timezone.utc).isoformat() + updated_at = datetime.fromtimestamp( + getattr(order, "last_update_timestamp", order.creation_timestamp), tz=timezone.utc + ).isoformat() + + return { + "order_id": order.client_order_id, + "account_name": account_name, + "connector_name": connector_name, + "trading_pair": order.trading_pair, + "trade_type": order.trade_type.name, + "order_type": order.order_type.name, + "amount": float(order.amount) if order.amount and not math.isnan(float(order.amount)) else 0, + "price": float(order.price) if order.price and not math.isnan(float(order.price)) else None, + "status": status, + "filled_amount": float(getattr(order, "executed_amount_base", 0) or 0) if not math.isnan(float(getattr(order, "executed_amount_base", 0) or 0)) else 0, + "average_fill_price": float(getattr(order, "last_executed_price", 0)) if getattr(order, "last_executed_price", None) and not math.isnan(float(getattr(order, "last_executed_price", 0))) else None, + "fee_paid": float(getattr(order, "cumulative_fee_paid_quote", 0)) if getattr(order, "cumulative_fee_paid_quote", None) and not math.isnan(float(getattr(order, "cumulative_fee_paid_quote", 0))) else None, + "fee_currency": None, # InFlightOrder doesn't store fee currency directly + "created_at": created_at, + "updated_at": updated_at, + "exchange_order_id": order.exchange_order_id, + "error_message": None, # InFlightOrder doesn't store error messages + } diff --git a/run.sh b/run.sh new file mode 100755 index 00000000..c6b36427 --- /dev/null +++ b/run.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# Run script for Backend API +# Usage: ./run.sh [--dev] +# --dev: Run API from source using uvicorn +# Without --dev: Run using docker compose + +if [[ "$1" == "--dev" ]]; then + echo "Running API from source..." + # Activate conda environment and run with uvicorn + docker compose up emqx postgres -d + source "$(conda info --base)/etc/profile.d/conda.sh" + conda activate hummingbot-api + uvicorn main:app --reload +else + echo "Running with Docker Compose..." 
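+    # Starts the full stack defined in docker-compose.yml (the API together
+    # with its dependencies, e.g. the EMQX broker and Postgres) in detached mode.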
+ docker compose up -d +fi \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py index e69de29b..2bd037e2 100644 --- a/services/__init__.py +++ b/services/__init__.py @@ -0,0 +1,9 @@ +from .accounts_service import AccountsService +from .bots_orchestrator import BotsOrchestrator +from .docker_service import DockerService + +__all__ = [ + "AccountsService", + "BotsOrchestrator", + "DockerService", +] \ No newline at end of file diff --git a/services/accounts_service.py b/services/accounts_service.py index 9c824fa9..b847d442 100644 --- a/services/accounts_service.py +++ b/services/accounts_service.py @@ -1,22 +1,22 @@ import asyncio -import json import logging -from datetime import datetime, timedelta +from datetime import datetime, timezone from decimal import Decimal -from typing import Optional +from typing import Dict, List, Optional from fastapi import HTTPException -from hummingbot.client.config.client_config_map import ClientConfigMap from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger -from hummingbot.client.config.config_helpers import ClientConfigAdapter, ReadOnlyClientConfigAdapter, get_connector_class -from hummingbot.client.settings import AllConnectorSettings +from hummingbot.core.data_type.common import OrderType, TradeType, PositionAction, PositionMode +from hummingbot.strategy_v2.executors.data_types import ConnectorPair -from config import BANNED_TOKENS, CONFIG_PASSWORD -from utils.file_system import FileSystemUtil -from utils.models import BackendAPIConfigAdapter -from utils.security import BackendAPISecurity +from config import settings +from database import AsyncDatabaseManager, AccountRepository, OrderRepository, TradeRepository, FundingRepository +from services.market_data_feed_manager import MarketDataFeedManager +from utils.connector_manager import ConnectorManager +from utils.file_system import fs_util -file_system = FileSystemUtil() +# Create module-specific logger +logger = logging.getLogger(__name__) class AccountsService: @@ -25,261 +25,354 @@ class AccountsService: to initialize all the connectors that are connected to each account, keep track of the balances of each account and update the balances of each account. """ + default_quotes = { + "hyperliquid": "USD", + "hyperliquid_perpetual": "USDC", + "xrpl": "RLUSD", + "kraken": "USD", + } + + # Cache for storing last successful prices by trading pair with timestamps + _last_known_prices = {} + _price_update_interval = 60 # Update prices every 60 seconds def __init__(self, - update_account_state_interval_minutes: int = 5, + account_update_interval: int = 5, default_quote: str = "USDT", - account_history_file: str = "account_state_history.json"): - # TODO: Add database to store the balances of each account each time it is updated. - self.secrets_manager = ETHKeyFileSecretManger(CONFIG_PASSWORD) - self.accounts = {} + market_data_feed_manager: Optional[MarketDataFeedManager] = None): + """ + Initialize the AccountsService. 
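+
+        Example (illustrative; a minimal construction without a market data feed manager):
+            service = AccountsService(account_update_interval=5, default_quote="USDT")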
+ + Args: + account_update_interval: How often to update account states in minutes (default: 5) + default_quote: Default quote currency for trading pairs (default: "USDT") + market_data_feed_manager: Market data feed manager for price caching (optional) + """ + self.secrets_manager = ETHKeyFileSecretManger(settings.security.config_password) self.accounts_state = {} - self.account_state_update_event = asyncio.Event() - self.initialize_accounts() - self.update_account_state_interval = update_account_state_interval_minutes * 60 + self.update_account_state_interval = account_update_interval * 60 self.default_quote = default_quote - self.history_file = account_history_file + self.market_data_feed_manager = market_data_feed_manager self._update_account_state_task: Optional[asyncio.Task] = None + + # Database setup for account states and orders + self.db_manager = AsyncDatabaseManager(settings.database.url) + self._db_initialized = False + + # Initialize connector manager with db_manager + self.connector_manager = ConnectorManager(self.secrets_manager, self.db_manager) + async def ensure_db_initialized(self): + """Ensure database is initialized before using it.""" + if not self._db_initialized: + await self.db_manager.create_tables() + self._db_initialized = True + def get_accounts_state(self): return self.accounts_state - def get_default_market(self, token: str): + def get_default_market(self, token: str, connector_name: str) -> str: if token.startswith("LD") and token != "LDO": # These tokens are staked in binance earn token = token[2:] - return f"{token}-{self.default_quote}" + quote = self.default_quotes.get(connector_name, self.default_quote) + return f"{token}-{quote}" - def start_update_account_state_loop(self): + def start(self): """ - Start the loop that updates the balances of all the accounts at a fixed interval. + Start the loop that updates the account state at a fixed interval. + Note: Balance updates are now handled by manual connector state updates. :return: """ + # Start the update loop which will call check_all_connectors self._update_account_state_task = asyncio.create_task(self.update_account_state_loop()) - def stop_update_account_state_loop(self): + async def stop(self): """ - Stop the loop that updates the balances of all the accounts at a fixed interval. - :return: + Stop all accounts service tasks and cleanup resources. + This is the main cleanup method that should be called during application shutdown. """ + logger.info("Stopping AccountsService...") + + # Stop the account state update loop if self._update_account_state_task: self._update_account_state_task.cancel() - self._update_account_state_task = None + self._update_account_state_task = None + logger.info("Stopped account state update loop") + + # Stop all connectors through the ConnectorManager + await self.connector_manager.stop_all_connectors() + + logger.info("AccountsService stopped successfully") async def update_account_state_loop(self): """ - The loop that updates the balances of all the accounts at a fixed interval. + The loop that updates the account state at a fixed interval. + This now includes manual connector state updates. 
:return: """ while True: try: await self.check_all_connectors() - await self.update_balances() - await self.update_trading_rules() + # Update all connector states (balances, orders, positions, trading rules) + await self.connector_manager.update_all_connector_states() await self.update_account_state() await self.dump_account_state() except Exception as e: - logging.error(f"Error updating account state: {e}") + logger.error(f"Error updating account state: {e}") finally: await asyncio.sleep(self.update_account_state_interval) async def dump_account_state(self): """ - Dump the current account state to a JSON file. Create it if the file not exists. + Save the current account state to the database. + All account/connector combinations from the same snapshot will use the same timestamp. :return: """ - timestamp = datetime.now().isoformat() - state_to_dump = {"timestamp": timestamp, "state": self.accounts_state} - if not file_system.path_exists(path=f"data/{self.history_file}"): - file_system.add_file(directory="data", file_name=self.history_file, content=json.dumps(state_to_dump) + "\n") - else: - file_system.append_to_file(directory="data", file_name=self.history_file, content=json.dumps(state_to_dump) + "\n") + await self.ensure_db_initialized() + + try: + # Generate a single timestamp for this entire snapshot + snapshot_timestamp = datetime.now(timezone.utc) + + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + + # Save each account-connector combination with the same timestamp + for account_name, connectors in self.accounts_state.items(): + for connector_name, tokens_info in connectors.items(): + if tokens_info: # Only save if there's token data + await repository.save_account_state(account_name, connector_name, tokens_info, snapshot_timestamp) + + except Exception as e: + logger.error(f"Error saving account state to database: {e}") + # Re-raise the exception since we no longer have a fallback + raise - def load_account_state_history(self): + async def load_account_state_history(self, + limit: Optional[int] = None, + cursor: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None): """ - Load the account state history from the JSON file. - :return: List of account states with timestamps. + Load the account state history from the database with pagination. + :return: Tuple of (data, next_cursor, has_more). """ - history = [] + await self.ensure_db_initialized() + try: - with open("bots/data/" + self.history_file, "r") as file: - for line in file: - if line.strip(): # Check if the line is not empty - history.append(json.loads(line)) - except FileNotFoundError: - logging.warning("No account state history file found.") - return history + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_account_state_history( + limit=limit, + cursor=cursor, + start_time=start_time, + end_time=end_time + ) + except Exception as e: + logger.error(f"Error loading account state history from database: {e}") + # Return empty result since we no longer have a fallback + return [], None, False async def check_all_connectors(self): """ - Check all avaialble credentials for all accounts and see if the connectors are created. - :return: + Check all available credentials for all accounts and ensure connectors are initialized. + This method is idempotent - it only initializes missing connectors. 
""" for account_name in self.list_accounts(): - for connector_name in self.list_credentials(account_name): - try: - connector_name = connector_name.split(".")[0] - if account_name not in self.accounts or connector_name not in self.accounts[account_name]: - self.initialize_connector(account_name, connector_name) - except Exception as e: - logging.error(f"Error initializing connector {connector_name}: {e}") - - def initialize_accounts(self): - """ - Initialize all the connectors that are connected to each account. - :return: - """ - for account_name in self.list_accounts(): - self.accounts[account_name] = {} - for connector_name in self.list_credentials(account_name): - try: - connector_name = connector_name.split(".")[0] - connector = self.get_connector(account_name, connector_name) - self.accounts[account_name][connector_name] = connector - except Exception as e: - logging.error(f"Error initializing connector {connector_name}: {e}") + await self._ensure_account_connectors_initialized(account_name) - def initialize_account(self, account_name: str): + async def _ensure_account_connectors_initialized(self, account_name: str): """ - Initialize all the connectors that are connected to the specified account. - :param account_name: The name of the account. - :return: + Ensure all connectors for a specific account are initialized. + This delegates to ConnectorManager for actual initialization. + + :param account_name: The name of the account to initialize connectors for. """ - for connector_name in self.list_credentials(account_name): + # Initialize missing connectors + for connector_name in self.connector_manager.list_available_credentials(account_name): try: - connector_name = connector_name.split(".")[0] - self.initialize_connector(account_name, connector_name) + # Only initialize if connector doesn't exist + if not self.connector_manager.is_connector_initialized(account_name, connector_name): + # Get connector will now handle all initialization + await self.connector_manager.get_connector(account_name, connector_name) except Exception as e: - logging.error(f"Error initializing connector {connector_name}: {e}") + logger.error(f"Error initializing connector {connector_name} for account {account_name}: {e}") - def initialize_connector(self, account_name: str, connector_name: str): + def _initialize_rate_sources_for_pairs(self, connector_name: str, trading_pairs: List[str]): """ - Initialize the specified connector for the specified account. - :param account_name: The name of the account. + Helper method to initialize rate sources for trading pairs. + :param connector_name: The name of the connector. - :return: + :param trading_pairs: List of trading pairs to initialize. 
""" - if account_name not in self.accounts: - self.accounts[account_name] = {} + if not trading_pairs or not self.market_data_feed_manager: + return + try: - connector = self.get_connector(account_name, connector_name) - self.accounts[account_name][connector_name] = connector + connector_pairs = [ConnectorPair(connector_name=connector_name, trading_pair=trading_pair) + for trading_pair in trading_pairs] + self.market_data_feed_manager.market_data_provider.initialize_rate_sources(connector_pairs) + logger.info(f"Initialized rate sources for {len(trading_pairs)} trading pairs in {connector_name}") except Exception as e: - logging.error(f"Error initializing connector {connector_name}: {e}") - - async def update_balances(self): - tasks = [] - for account_name, connectors in self.accounts.items(): - for connector_instance in connectors.values(): - tasks.append(self._safe_update_balances(connector_instance)) - await asyncio.gather(*tasks) - - async def _safe_update_balances(self, connector_instance): - try: - await connector_instance._update_balances() - except Exception as e: - logging.error(f"Error updating balances for connector {connector_instance}: {e}") - - async def update_trading_rules(self): - tasks = [] - for account_name, connectors in self.accounts.items(): - for connector_instance in connectors.values(): - tasks.append(self._safe_update_trading_rules(connector_instance)) - await asyncio.gather(*tasks) + logger.error(f"Error initializing rate sources for {connector_name}: {e}") - async def _safe_update_trading_rules(self, connector_instance): + async def _initialize_price_tracking(self, account_name: str, connector_name: str, connector): + """ + Initialize price tracking for a connector's tokens using MarketDataProvider. + + :param account_name: The name of the account. + :param connector_name: The name of the connector. + :param connector: The connector instance. 
+ """ try: - await connector_instance._update_trading_rules() + # Get current balances to determine which tokens need price tracking + balances = connector.get_all_balances() + unique_tokens = [token for token, value in balances.items() if + value != Decimal("0") and token not in settings.banned_tokens and "USD" not in token] + + if unique_tokens: + # Create trading pairs for price tracking + trading_pairs = [self.get_default_market(token, connector_name) for token in unique_tokens] + + # Initialize rate sources using helper method + self._initialize_rate_sources_for_pairs(connector_name, trading_pairs) + + logger.info(f"Initialized price tracking for {len(trading_pairs)} trading pairs in {connector_name} (Account: {account_name})") + except Exception as e: - logging.error(f"Error updating trading rules for connector {connector_instance}: {e}") + logger.error(f"Error initializing price tracking for {connector_name} in account {account_name}: {e}") async def update_account_state(self): - for account_name, connectors in self.accounts.items(): + """Update account state for all connectors.""" + all_connectors = self.connector_manager.get_all_connectors() + + for account_name, connectors in all_connectors.items(): if account_name not in self.accounts_state: self.accounts_state[account_name] = {} for connector_name, connector in connectors.items(): - tokens_info = [] try: - balances = [{"token": key, "units": value} for key, value in connector.get_all_balances().items() if - value != Decimal("0") and key not in BANNED_TOKENS] - unique_tokens = [balance["token"] for balance in balances] - trading_pairs = [self.get_default_market(token) for token in unique_tokens if "USD" not in token] - last_traded_prices = await self._safe_get_last_traded_prices(connector, trading_pairs) - for balance in balances: - token = balance["token"] - if "USD" in token: - price = Decimal("1") - else: - market = self.get_default_market(balance["token"]) - price = Decimal(last_traded_prices.get(market, 0)) - tokens_info.append({ - "token": balance["token"], - "units": float(balance["units"]), - "price": float(price), - "value": float(price * balance["units"]), - "available_units": float(connector.get_available_balance(balance["token"])) - }) - self.account_state_update_event.set() + tokens_info = await self._get_connector_tokens_info(connector, connector_name, self.market_data_feed_manager) + self.accounts_state[account_name][connector_name] = tokens_info except Exception as e: - logging.error( - f"Error updating balances for connector {connector_name} in account {account_name}: {e}") - self.accounts_state[account_name][connector_name] = tokens_info + logger.error(f"Error updating balances for connector {connector_name} in account {account_name}: {e}") + self.accounts_state[account_name][connector_name] = [] - async def _safe_get_last_traded_prices(self, connector, trading_pairs, timeout=5): + async def _get_connector_tokens_info(self, connector, connector_name: str, market_data_manager: Optional[MarketDataFeedManager] = None) -> List[Dict]: + """Get token info from a connector instance using cached prices when available.""" + balances = [{"token": key, "units": value} for key, value in connector.get_all_balances().items() if + value != Decimal("0") and key not in settings.banned_tokens] + unique_tokens = [balance["token"] for balance in balances] + trading_pairs = [self.get_default_market(token, connector_name) for token in unique_tokens if "USD" not in token] + + # Try to get cached prices first, fallback to live prices 
if needed + prices_from_cache = {} + trading_pairs_need_update = [] + + if market_data_manager: + for trading_pair in trading_pairs: + try: + cached_price = market_data_manager.market_data_provider.get_rate(trading_pair) + if cached_price > 0: + prices_from_cache[trading_pair] = cached_price + else: + trading_pairs_need_update.append(trading_pair) + except Exception: + trading_pairs_need_update.append(trading_pair) + else: + trading_pairs_need_update = trading_pairs + + # Add new trading pairs to market data provider if they need updates + if trading_pairs_need_update: + self._initialize_rate_sources_for_pairs(connector_name, trading_pairs_need_update) + logger.info(f"Added {len(trading_pairs_need_update)} new trading pairs to market data provider: {trading_pairs_need_update}") + + # Get fresh prices for pairs not in cache or with stale/zero prices + fresh_prices = {} + if trading_pairs_need_update: + fresh_prices = await self._safe_get_last_traded_prices(connector, trading_pairs_need_update) + + # Combine cached and fresh prices + all_prices = {**prices_from_cache, **fresh_prices} + + tokens_info = [] + for balance in balances: + token = balance["token"] + if "USD" in token: + price = Decimal("1") + else: + market = self.get_default_market(balance["token"], connector_name) + price = Decimal(str(all_prices.get(market, 0))) + + tokens_info.append({ + "token": balance["token"], + "units": float(balance["units"]), + "price": float(price), + "value": float(price * balance["units"]), + "available_units": float(connector.get_available_balance(balance["token"])) + }) + return tokens_info + + async def _safe_get_last_traded_prices(self, connector, trading_pairs, timeout=10): + """Safely get last traded prices with timeout and error handling. Preserves previous prices on failure.""" try: - # TODO: Fix OKX connector to return the markets in Hummingbot format. last_traded = await asyncio.wait_for(connector.get_last_traded_prices(trading_pairs=trading_pairs), timeout=timeout) + + # Update cache with successful prices + for pair, price in last_traded.items(): + if price and price > 0: + self._last_known_prices[pair] = price + return last_traded except asyncio.TimeoutError: - logging.error(f"Timeout getting last traded prices for trading pairs {trading_pairs}") - return {pair: Decimal("0") for pair in trading_pairs} + logger.error(f"Timeout getting last traded prices for trading pairs {trading_pairs}") + return self._get_fallback_prices(trading_pairs) except Exception as e: - logging.error(f"Error getting last traded prices in connector {connector} for trading pairs {trading_pairs}: {e}") - return {pair: Decimal("0") for pair in trading_pairs} + logger.error(f"Error getting last traded prices in connector {connector} for trading pairs {trading_pairs}: {e}") + return self._get_fallback_prices(trading_pairs) + + def _get_fallback_prices(self, trading_pairs): + """Get fallback prices using cached values, only setting to 0 if no previous price exists.""" + fallback_prices = {} + for pair in trading_pairs: + if pair in self._last_known_prices: + fallback_prices[pair] = self._last_known_prices[pair] + logger.info(f"Using cached price {self._last_known_prices[pair]} for {pair}") + else: + fallback_prices[pair] = Decimal("0") + logger.warning(f"No cached price available for {pair}, using 0") + return fallback_prices - @staticmethod - def get_connector_config_map(connector_name: str): + def get_connector_config_map(self, connector_name: str): """ Get the connector config map for the specified connector. 
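+
+        Example (illustrative; the exact keys depend on the connector):
+            self.get_connector_config_map("binance")  # e.g. ["binance_api_key", "binance_api_secret"]
+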
:param connector_name: The name of the connector. :return: The connector config map. """ - connector_config = BackendAPIConfigAdapter(AllConnectorSettings.get_connector_config_keys(connector_name)) - return [key for key in connector_config.hb_config.__fields__.keys() if key != "connector"] - - async def add_connector_keys(self, account_name: str, connector_name: str, keys: dict): - BackendAPISecurity.login_account(account_name=account_name, secrets_manager=self.secrets_manager) - connector_config = BackendAPIConfigAdapter(AllConnectorSettings.get_connector_config_keys(connector_name)) - for key, value in keys.items(): - setattr(connector_config, key, value) - BackendAPISecurity.update_connector_keys(account_name, connector_config) - new_connector = self.get_connector(account_name, connector_name) - await new_connector._update_balances() - self.accounts[account_name][connector_name] = new_connector - await self.update_account_state() - await self.dump_account_state() + return self.connector_manager.get_connector_config_map(connector_name) - def get_connector(self, account_name: str, connector_name: str): + async def add_credentials(self, account_name: str, connector_name: str, credentials: dict): """ - Get the connector object for the specified account and connector. + Add or update connector credentials and initialize the connector with validation. + :param account_name: The name of the account. :param connector_name: The name of the connector. - :return: The connector object. - """ - BackendAPISecurity.login_account(account_name=account_name, secrets_manager=self.secrets_manager) - client_config_map = ClientConfigAdapter(ClientConfigMap()) - conn_setting = AllConnectorSettings.get_connector_settings()[connector_name] - keys = BackendAPISecurity.api_keys(connector_name) - read_only_config = ReadOnlyClientConfigAdapter.lock_config(client_config_map) - init_params = conn_setting.conn_init_parameters( - trading_pairs=[], - trading_required=True, - api_keys=keys, - client_config_map=read_only_config, - ) - connector_class = get_connector_class(connector_name) - connector = connector_class(**init_params) - return connector + :param credentials: Dictionary containing the connector credentials. + :raises Exception: If credentials are invalid or connector cannot be initialized. + """ + try: + # Update the connector keys (this saves the credentials to file and validates them) + connector = await self.connector_manager.update_connector_keys(account_name, connector_name, credentials) + + # Initialize price tracking for this connector's tokens if market data manager is available + if self.market_data_feed_manager: + await self._initialize_price_tracking(account_name, connector_name, connector) + + await self.update_account_state() + except Exception as e: + logger.error(f"Error adding connector credentials for account {account_name}: {e}") + await self.delete_credentials(account_name, connector_name) + raise e @staticmethod def list_accounts(): @@ -287,33 +380,40 @@ def list_accounts(): List all the accounts that are connected to the trading system. :return: List of accounts. """ - return file_system.list_folders('credentials') + return fs_util.list_folders('credentials') - def list_credentials(self, account_name: str): + @staticmethod + def list_credentials(account_name: str): """ List all the credentials that are connected to the specified account. :param account_name: The name of the account. :return: List of credentials. 
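+
+        Example (illustrative return value; file names are hypothetical):
+            ["binance.yml", "kraken.yml"]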
""" try: - return [file for file in file_system.list_files(f'credentials/{account_name}/connectors') if + return [file for file in fs_util.list_files(f'credentials/{account_name}/connectors') if file.endswith('.yml')] except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) - def delete_credentials(self, account_name: str, connector_name: str): + async def delete_credentials(self, account_name: str, connector_name: str): """ Delete the credentials of the specified connector for the specified account. :param account_name: :param connector_name: :return: """ - if file_system.path_exists(f"credentials/{account_name}/connectors/{connector_name}.yml"): - file_system.delete_file(directory=f"credentials/{account_name}/connectors", file_name=f"{connector_name}.yml") - if connector_name in self.accounts[account_name]: - self.accounts[account_name].pop(connector_name) - if connector_name in self.accounts_state[account_name]: + if fs_util.path_exists(f"credentials/{account_name}/connectors/{connector_name}.yml"): + fs_util.delete_file(directory=f"credentials/{account_name}/connectors", file_name=f"{connector_name}.yml") + + # Stop the connector if it's running + await self.connector_manager.stop_connector(account_name, connector_name) + + # Remove from account state + if account_name in self.accounts_state and connector_name in self.accounts_state[account_name]: self.accounts_state[account_name].pop(connector_name) + + # Clear the connector from cache + self.connector_manager.clear_cache(account_name, connector_name) def add_account(self, account_name: str): """ @@ -321,22 +421,871 @@ def add_account(self, account_name: str): :param account_name: :return: """ - if account_name in self.accounts: + # Check if account already exists by looking at folders + if account_name in self.list_accounts(): raise HTTPException(status_code=400, detail="Account already exists.") + files_to_copy = ["conf_client.yml", "conf_fee_overrides.yml", "hummingbot_logs.yml", ".password_verification"] - file_system.create_folder('credentials', account_name) - file_system.create_folder(f'credentials/{account_name}', "connectors") + fs_util.create_folder('credentials', account_name) + fs_util.create_folder(f'credentials/{account_name}', "connectors") for file in files_to_copy: - file_system.copy_file(f"credentials/master_account/{file}", f"credentials/{account_name}/{file}") - self.accounts[account_name] = {} + fs_util.copy_file(f"credentials/master_account/{file}", f"credentials/{account_name}/{file}") + + # Initialize account state self.accounts_state[account_name] = {} - def delete_account(self, account_name: str): + async def delete_account(self, account_name: str): """ Delete the specified account. 
:param account_name: :return: """ - file_system.delete_folder('credentials', account_name) - self.accounts.pop(account_name) - self.accounts_state.pop(account_name) + # Stop all connectors for this account + for connector_name in self.connector_manager.list_account_connectors(account_name): + await self.connector_manager.stop_connector(account_name, connector_name) + + # Delete account folder + fs_util.delete_folder('credentials', account_name) + + # Remove from account state + if account_name in self.accounts_state: + self.accounts_state.pop(account_name) + + # Clear all connectors for this account from cache + self.connector_manager.clear_cache(account_name) + + async def get_account_current_state(self, account_name: str) -> Dict[str, List[Dict]]: + """ + Get current state for a specific account from database. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_account_current_state(account_name) + except Exception as e: + logger.error(f"Error getting account current state: {e}") + # Fallback to in-memory state + return self.accounts_state.get(account_name, {}) + + async def get_account_state_history(self, + account_name: str, + limit: Optional[int] = None, + cursor: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None): + """ + Get historical state for a specific account with pagination. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_account_state_history( + account_name=account_name, + limit=limit, + cursor=cursor, + start_time=start_time, + end_time=end_time + ) + except Exception as e: + logger.error(f"Error getting account state history: {e}") + return [], None, False + + async def get_connector_current_state(self, account_name: str, connector_name: str) -> List[Dict]: + """ + Get current state for a specific connector. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_connector_current_state(account_name, connector_name) + except Exception as e: + logger.error(f"Error getting connector current state: {e}") + # Fallback to in-memory state + return self.accounts_state.get(account_name, {}).get(connector_name, []) + + async def get_connector_state_history(self, + account_name: str, + connector_name: str, + limit: Optional[int] = None, + cursor: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None): + """ + Get historical state for a specific connector with pagination. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_account_state_history( + account_name=account_name, + connector_name=connector_name, + limit=limit, + cursor=cursor, + start_time=start_time, + end_time=end_time + ) + except Exception as e: + logger.error(f"Error getting connector state history: {e}") + return [], None, False + + async def get_all_unique_tokens(self) -> List[str]: + """ + Get all unique tokens across all accounts and connectors. 
+ """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_all_unique_tokens() + except Exception as e: + logger.error(f"Error getting unique tokens: {e}") + # Fallback to in-memory state + tokens = set() + for account_data in self.accounts_state.values(): + for connector_data in account_data.values(): + for token_info in connector_data: + tokens.add(token_info.get("token")) + return sorted(list(tokens)) + + async def get_token_current_state(self, token: str) -> List[Dict]: + """ + Get current state of a specific token across all accounts. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_token_current_state(token) + except Exception as e: + logger.error(f"Error getting token current state: {e}") + return [] + + async def get_portfolio_value(self, account_name: Optional[str] = None) -> Dict[str, any]: + """ + Get total portfolio value, optionally filtered by account. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_portfolio_value(account_name) + except Exception as e: + logger.error(f"Error getting portfolio value: {e}") + # Fallback to in-memory calculation + portfolio = {"accounts": {}, "total_value": 0} + + accounts_to_process = [account_name] if account_name else self.accounts_state.keys() + + for acc_name in accounts_to_process: + account_value = 0 + if acc_name in self.accounts_state: + for connector_data in self.accounts_state[acc_name].values(): + for token_info in connector_data: + account_value += token_info.get("value", 0) + portfolio["accounts"][acc_name] = account_value + portfolio["total_value"] += account_value + + return portfolio + + def get_portfolio_distribution(self, account_name: Optional[str] = None) -> Dict[str, any]: + """ + Get portfolio distribution by tokens with percentages. 
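+
+        Example (illustrative response shape; numbers are hypothetical):
+            {"total_portfolio_value": 1250.5, "token_count": 2,
+             "distribution": [{"token": "BTC", "total_value": 1000.0, "percentage": 80.0, ...}],
+             "account_filter": "all_accounts"}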
+ """ + try: + # Get accounts to process + accounts_to_process = [account_name] if account_name else list(self.accounts_state.keys()) + + # Aggregate all tokens across accounts and connectors + token_values = {} + total_value = 0 + + for acc_name in accounts_to_process: + if acc_name in self.accounts_state: + for connector_name, connector_data in self.accounts_state[acc_name].items(): + for token_info in connector_data: + token = token_info.get("token", "") + value = token_info.get("value", 0) + + if token not in token_values: + token_values[token] = { + "token": token, + "total_value": 0, + "total_units": 0, + "accounts": {} + } + + token_values[token]["total_value"] += value + token_values[token]["total_units"] += token_info.get("units", 0) + total_value += value + + # Track by account + if acc_name not in token_values[token]["accounts"]: + token_values[token]["accounts"][acc_name] = { + "value": 0, + "units": 0, + "connectors": {} + } + + token_values[token]["accounts"][acc_name]["value"] += value + token_values[token]["accounts"][acc_name]["units"] += token_info.get("units", 0) + + # Track by connector within account + if connector_name not in token_values[token]["accounts"][acc_name]["connectors"]: + token_values[token]["accounts"][acc_name]["connectors"][connector_name] = { + "value": 0, + "units": 0 + } + + token_values[token]["accounts"][acc_name]["connectors"][connector_name]["value"] += value + token_values[token]["accounts"][acc_name]["connectors"][connector_name]["units"] += token_info.get("units", 0) + + # Calculate percentages + distribution = [] + for token_data in token_values.values(): + percentage = (token_data["total_value"] / total_value * 100) if total_value > 0 else 0 + + token_dist = { + "token": token_data["token"], + "total_value": round(token_data["total_value"], 6), + "total_units": token_data["total_units"], + "percentage": round(percentage, 4), + "accounts": {} + } + + # Add account-level percentages + for acc_name, acc_data in token_data["accounts"].items(): + acc_percentage = (acc_data["value"] / total_value * 100) if total_value > 0 else 0 + token_dist["accounts"][acc_name] = { + "value": round(acc_data["value"], 6), + "units": acc_data["units"], + "percentage": round(acc_percentage, 4), + "connectors": {} + } + + # Add connector-level data + for conn_name, conn_data in acc_data["connectors"].items(): + token_dist["accounts"][acc_name]["connectors"][conn_name] = { + "value": round(conn_data["value"], 6), + "units": conn_data["units"] + } + + distribution.append(token_dist) + + # Sort by value (descending) + distribution.sort(key=lambda x: x["total_value"], reverse=True) + + return { + "total_portfolio_value": round(total_value, 6), + "token_count": len(distribution), + "distribution": distribution, + "account_filter": account_name if account_name else "all_accounts" + } + + except Exception as e: + logger.error(f"Error calculating portfolio distribution: {e}") + return { + "total_portfolio_value": 0, + "token_count": 0, + "distribution": [], + "account_filter": account_name if account_name else "all_accounts", + "error": str(e) + } + + def get_account_distribution(self) -> Dict[str, any]: + """ + Get portfolio distribution by accounts with percentages. 
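+
+        Example (illustrative response shape; numbers are hypothetical):
+            {"total_portfolio_value": 1250.5, "account_count": 1,
+             "distribution": [{"account": "master_account", "total_value": 1250.5, "percentage": 100.0,
+                               "connectors": {"binance": {"value": 1250.5, "percentage": 100.0}}}]}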
+ """ + try: + account_values = {} + total_value = 0 + + for acc_name, account_data in self.accounts_state.items(): + account_value = 0 + connector_values = {} + + for connector_name, connector_data in account_data.items(): + connector_value = 0 + for token_info in connector_data: + value = token_info.get("value", 0) + connector_value += value + account_value += value + + connector_values[connector_name] = round(connector_value, 6) + + account_values[acc_name] = { + "total_value": round(account_value, 6), + "connectors": connector_values + } + total_value += account_value + + # Calculate percentages + distribution = [] + for acc_name, acc_data in account_values.items(): + percentage = (acc_data["total_value"] / total_value * 100) if total_value > 0 else 0 + + connector_dist = {} + for conn_name, conn_value in acc_data["connectors"].items(): + conn_percentage = (conn_value / total_value * 100) if total_value > 0 else 0 + connector_dist[conn_name] = { + "value": conn_value, + "percentage": round(conn_percentage, 4) + } + + distribution.append({ + "account": acc_name, + "total_value": acc_data["total_value"], + "percentage": round(percentage, 4), + "connectors": connector_dist + }) + + # Sort by value (descending) + distribution.sort(key=lambda x: x["total_value"], reverse=True) + + return { + "total_portfolio_value": round(total_value, 6), + "account_count": len(distribution), + "distribution": distribution + } + + except Exception as e: + logger.error(f"Error calculating account distribution: {e}") + return { + "total_portfolio_value": 0, + "account_count": 0, + "distribution": [], + "error": str(e) + } + + async def place_trade(self, account_name: str, connector_name: str, trading_pair: str, + trade_type: TradeType, amount: Decimal, order_type: OrderType = OrderType.LIMIT, + price: Optional[Decimal] = None, position_action: PositionAction = PositionAction.OPEN, + market_data_manager: Optional[MarketDataFeedManager] = None) -> str: + """ + Place a trade using the specified account and connector. 
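+
+        Example (illustrative call; the account, connector, and pair are hypothetical):
+            order_id = await accounts_service.place_trade(
+                account_name="master_account", connector_name="binance",
+                trading_pair="BTC-USDT", trade_type=TradeType.BUY,
+                amount=Decimal("0.001"), order_type=OrderType.LIMIT, price=Decimal("50000"))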
+ + Args: + account_name: Name of the account to trade with + connector_name: Name of the connector/exchange + trading_pair: Trading pair (e.g., BTC-USDT) + trade_type: "BUY" or "SELL" + amount: Amount to trade + order_type: "LIMIT", "MARKET", or "LIMIT_MAKER" + price: Price for limit orders (required for LIMIT and LIMIT_MAKER) + position_action: Position action for perpetual contracts (OPEN/CLOSE) + market_data_manager: Market data manager for price fetching + + Returns: + Client order ID assigned by the connector + + Raises: + HTTPException: If account, connector not found, or trade fails + """ + # Validate account exists + if account_name not in self.list_accounts(): + raise HTTPException(status_code=404, detail=f"Account '{account_name}' not found") + + # Validate connector exists for account + if not self.connector_manager.is_connector_initialized(account_name, connector_name): + raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' not found for account '{account_name}'") + + # Get the connector instance + connector = await self.connector_manager.get_connector(account_name, connector_name) + + # Validate price for limit orders + if order_type in [OrderType.LIMIT, OrderType.LIMIT_MAKER] and price is None: + raise HTTPException(status_code=400, detail="Price is required for LIMIT and LIMIT_MAKER orders") + + # Check if trading rules are loaded + if not connector.trading_rules: + raise HTTPException( + status_code=503, + detail=f"Trading rules not yet loaded for {connector_name}. Please try again in a moment." + ) + + # Validate trading pair and get trading rule + if trading_pair not in connector.trading_rules: + available_pairs = list(connector.trading_rules.keys())[:10] # Show first 10 + more_text = f" (and {len(connector.trading_rules) - 10} more)" if len(connector.trading_rules) > 10 else "" + raise HTTPException( + status_code=400, + detail=f"Trading pair '{trading_pair}' not supported on {connector_name}. " + f"Available pairs: {available_pairs}{more_text}" + ) + + trading_rule = connector.trading_rules[trading_pair] + + # Validate order type is supported + if order_type not in connector.supported_order_types(): + supported_types = [ot.name for ot in connector.supported_order_types()] + raise HTTPException(status_code=400, detail=f"Order type '{order_type.name}' not supported. 
Supported types: {supported_types}")
+
+        # Quantize amount according to trading rules
+        quantized_amount = connector.quantize_order_amount(trading_pair, amount)
+
+        # Validate minimum order size
+        if quantized_amount < trading_rule.min_order_size:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Order amount {quantized_amount} is below minimum order size {trading_rule.min_order_size} for {trading_pair}"
+            )
+
+        # Calculate and validate notional size
+        if order_type in [OrderType.LIMIT, OrderType.LIMIT_MAKER]:
+            quantized_price = connector.quantize_order_price(trading_pair, price)
+            notional_size = quantized_price * quantized_amount
+        else:
+            # For market orders without a price, get the current market price for validation
+            if market_data_manager:
+                try:
+                    prices = await market_data_manager.get_prices(connector_name, [trading_pair])
+                    if trading_pair in prices and "error" not in prices:
+                        price = Decimal(str(prices[trading_pair]))
+                except Exception as e:
+                    logger.error(f"Error getting market price for {trading_pair}: {e}")
+            if price is None:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Could not determine a market price for {trading_pair} to validate the order notional; please retry or provide a price."
+                )
+            notional_size = price * quantized_amount
+
+        if notional_size < trading_rule.min_notional_size:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Order notional value {notional_size} is below minimum notional size {trading_rule.min_notional_size} for {trading_pair}. "
+                       f"Increase the amount or price to meet the minimum requirement."
+            )
+
+        try:
+            # Place the order using the connector with quantized values
+            # (position_action will be ignored by non-perpetual connectors)
+            if trade_type == TradeType.BUY:
+                order_id = connector.buy(
+                    trading_pair=trading_pair,
+                    amount=quantized_amount,
+                    order_type=order_type,
+                    price=price or Decimal("1"),
+                    position_action=position_action
+                )
+            else:
+                order_id = connector.sell(
+                    trading_pair=trading_pair,
+                    amount=quantized_amount,
+                    order_type=order_type,
+                    price=price or Decimal("1"),
+                    position_action=position_action
+                )
+
+            logger.info(f"Placed {trade_type} order for {amount} {trading_pair} on {connector_name} (Account: {account_name}). Order ID: {order_id}")
+            return order_id
+
+        except HTTPException:
+            # Re-raise HTTP exceptions as-is
+            raise
+        except Exception as e:
+            logger.error(f"Failed to place {trade_type} order: {e}")
+            raise HTTPException(status_code=500, detail=f"Failed to place trade: {str(e)}")
+
+    async def get_connector_instance(self, account_name: str, connector_name: str):
+        """
+        Get a connector instance for direct access.
+
+        Args:
+            account_name: Name of the account
+            connector_name: Name of the connector
+
+        Returns:
+            Connector instance
+
+        Raises:
+            HTTPException: If account or connector not found
+        """
+        if account_name not in self.list_accounts():
+            raise HTTPException(status_code=404, detail=f"Account '{account_name}' not found")
+
+        # Check if connector credentials exist
+        available_credentials = self.connector_manager.list_available_credentials(account_name)
+        if connector_name not in available_credentials:
+            raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' not found for account '{account_name}'")
+
+        return await self.connector_manager.get_connector(account_name, connector_name)
+
+    async def get_active_orders(self, account_name: str, connector_name: str) -> Dict[str, any]:
+        """
+        Get active orders for a specific connector.
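+
+        Example (illustrative return shape; the order ID and fields are hypothetical):
+            {"buy-BTC-USDT-1700000000": {"trading_pair": "BTC-USDT", "order_type": "LIMIT", ...}}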
+ + Args: + account_name: Name of the account + connector_name: Name of the connector + + Returns: + Dictionary of active orders + """ + connector = await self.get_connector_instance(account_name, connector_name) + return {order_id: order.to_json() for order_id, order in connector.in_flight_orders.items()} + + async def cancel_order(self, account_name: str, connector_name: str, client_order_id: str) -> str: + """ + Cancel an active order. + + Args: + account_name: Name of the account + connector_name: Name of the connector + client_order_id: Client order ID to cancel + + Returns: + Client order ID that was cancelled + + Raises: + HTTPException: 404 if order not found, 500 if cancellation fails + """ + connector = await self.get_connector_instance(account_name, connector_name) + + # Check if order exists in in-flight orders + if client_order_id not in connector.in_flight_orders: + raise HTTPException(status_code=404, detail=f"Order '{client_order_id}' not found in active orders") + + try: + result = connector.cancel(trading_pair="NA", client_order_id=client_order_id) + logger.info(f"Initiated cancellation for order {client_order_id} on {connector_name} (Account: {account_name})") + return result + except Exception as e: + logger.error(f"Failed to initiate cancellation for order {client_order_id}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to initiate order cancellation: {str(e)}") + + async def set_leverage(self, account_name: str, connector_name: str, + trading_pair: str, leverage: int) -> Dict[str, str]: + """ + Set leverage for a specific trading pair on a perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + trading_pair: Trading pair to set leverage for + leverage: Leverage value (typically 1-125) + + Returns: + Dictionary with success status and message + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + # Validate this is a perpetual connector + if "_perpetual" not in connector_name: + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") + + connector = await self.get_connector_instance(account_name, connector_name) + + # Check if connector has leverage functionality + if not hasattr(connector, '_execute_set_leverage'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support leverage setting") + + try: + await connector._execute_set_leverage(trading_pair, leverage) + message = f"Leverage for {trading_pair} set to {leverage} on {connector_name}" + logger.info(f"Set leverage for {trading_pair} to {leverage} on {connector_name} (Account: {account_name})") + return {"status": "success", "message": message} + + except Exception as e: + logger.error(f"Failed to set leverage for {trading_pair} to {leverage}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to set leverage: {str(e)}") + + async def set_position_mode(self, account_name: str, connector_name: str, + position_mode: PositionMode) -> Dict[str, str]: + """ + Set position mode for a perpetual connector. 
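+
+        Example (illustrative call):
+            await accounts_service.set_position_mode("master_account", "binance_perpetual", PositionMode.HEDGE)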
+ + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + position_mode: PositionMode.HEDGE or PositionMode.ONEWAY + + Returns: + Dictionary with success status and message + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + # Validate this is a perpetual connector + if "_perpetual" not in connector_name: + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") + + connector = await self.get_connector_instance(account_name, connector_name) + + # Check if the requested position mode is supported + supported_modes = connector.supported_position_modes() + if position_mode not in supported_modes: + supported_values = [mode.value for mode in supported_modes] + raise HTTPException( + status_code=400, + detail=f"Position mode '{position_mode.value}' not supported. Supported modes: {supported_values}" + ) + + try: + # Try to call the method - it might be sync or async + result = connector.set_position_mode(position_mode) + # If it's a coroutine, await it + if asyncio.iscoroutine(result): + await result + + message = f"Position mode set to {position_mode.value} on {connector_name}" + logger.info(f"Set position mode to {position_mode.value} on {connector_name} (Account: {account_name})") + return {"status": "success", "message": message} + + except Exception as e: + logger.error(f"Failed to set position mode to {position_mode.value}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to set position mode: {str(e)}") + + async def get_position_mode(self, account_name: str, connector_name: str) -> Dict[str, str]: + """ + Get current position mode for a perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + Dictionary with current position mode + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + # Validate this is a perpetual connector + if "_perpetual" not in connector_name: + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") + + connector = await self.get_connector_instance(account_name, connector_name) + + # Check if connector has position mode functionality + if not hasattr(connector, 'position_mode'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position mode") + + try: + current_mode = connector.position_mode + return { + "position_mode": current_mode.value if current_mode else "UNKNOWN", + "connector": connector_name, + "account": account_name + } + + except Exception as e: + logger.error(f"Failed to get position mode: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get position mode: {str(e)}") + + async def get_orders(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, status: Optional[str] = None, + start_time: Optional[int] = None, end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Dict]: + """Get order history using OrderRepository.""" + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + orders = await order_repo.get_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + status=status, + start_time=start_time, + end_time=end_time, + limit=limit, 
+ offset=offset + ) + return [order_repo.to_dict(order) for order in orders] + except Exception as e: + logger.error(f"Error getting orders: {e}") + return [] + + async def get_active_orders_history(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None) -> List[Dict]: + """Get active orders from database using OrderRepository.""" + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + orders = await order_repo.get_active_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair + ) + return [order_repo.to_dict(order) for order in orders] + except Exception as e: + logger.error(f"Error getting active orders: {e}") + return [] + + async def get_orders_summary(self, account_name: Optional[str] = None, start_time: Optional[int] = None, + end_time: Optional[int] = None) -> Dict: + """Get order summary statistics using OrderRepository.""" + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + return await order_repo.get_orders_summary( + account_name=account_name, + start_time=start_time, + end_time=end_time + ) + except Exception as e: + logger.error(f"Error getting orders summary: {e}") + return { + "total_orders": 0, + "filled_orders": 0, + "cancelled_orders": 0, + "failed_orders": 0, + "active_orders": 0, + "fill_rate": 0, + } + + async def get_trades(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, trade_type: Optional[str] = None, + start_time: Optional[int] = None, end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Dict]: + """Get trade history using TradeRepository.""" + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + trade_repo = TradeRepository(session) + trade_order_pairs = await trade_repo.get_trades_with_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + trade_type=trade_type, + start_time=start_time, + end_time=end_time, + limit=limit, + offset=offset + ) + return [trade_repo.to_dict(trade, order) for trade, order in trade_order_pairs] + except Exception as e: + logger.error(f"Error getting trades: {e}") + return [] + + async def get_account_positions(self, account_name: str, connector_name: str) -> List[Dict]: + """ + Get current positions for a specific perpetual connector. 
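+
+        Example of one returned entry (values illustrative, keys as built
+        below):
+
+            {"account_name": "master_account", "connector_name": "binance_perpetual",
+             "trading_pair": "ETH-USDT", "side": "LONG", "amount": 0.5,
+             "entry_price": 3050.0, "unrealized_pnl": 12.4, "leverage": 10.0}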
+ + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + List of position dictionaries + + Raises: + HTTPException: If account/connector not found or not perpetual + """ + # Validate this is a perpetual connector + if "_perpetual" not in connector_name: + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") + + connector = await self.get_connector_instance(account_name, connector_name) + + # Check if connector has account_positions property + if not hasattr(connector, 'account_positions'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position tracking") + + try: + # Force position update to ensure current market prices are used + await connector._update_positions() + + positions = [] + raw_positions = connector.account_positions + + for trading_pair, position_info in raw_positions.items(): + # Convert position data to dict format + position_dict = { + "account_name": account_name, + "connector_name": connector_name, + "trading_pair": position_info.trading_pair, + "side": position_info.position_side.name if hasattr(position_info, 'position_side') else "UNKNOWN", + "amount": float(position_info.amount) if hasattr(position_info, 'amount') else 0.0, + "entry_price": float(position_info.entry_price) if hasattr(position_info, 'entry_price') else None, + "unrealized_pnl": float(position_info.unrealized_pnl) if hasattr(position_info, 'unrealized_pnl') else None, + "leverage": float(position_info.leverage) if hasattr(position_info, 'leverage') else None, + } + + # Only include positions with non-zero amounts + if position_dict["amount"] != 0: + positions.append(position_dict) + + return positions + + except Exception as e: + logger.error(f"Failed to get positions for {connector_name}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get positions: {str(e)}") + + async def get_funding_payments(self, account_name: str, connector_name: str = None, + trading_pair: str = None, limit: int = 100) -> List[Dict]: + """ + Get funding payment history for an account. + + Args: + account_name: Name of the account + connector_name: Optional connector name filter + trading_pair: Optional trading pair filter + limit: Maximum number of records to return + + Returns: + List of funding payment dictionaries + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + funding_repo = FundingRepository(session) + funding_payments = await funding_repo.get_funding_payments( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + limit=limit + ) + return [funding_repo.to_dict(payment) for payment in funding_payments] + + except Exception as e: + logger.error(f"Error getting funding payments: {e}") + return [] + + async def get_total_funding_fees(self, account_name: str, connector_name: str, + trading_pair: str) -> Dict: + """ + Get total funding fees for a specific trading pair. 
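+
+        Example of the returned shape (values illustrative; on a repository
+        error the same keys come back zeroed with an added "error" field):
+
+            {"total_funding_fees": -1.25, "payment_count": 3, "fee_currency": "USDT"}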
+ + Args: + account_name: Name of the account + connector_name: Name of the connector + trading_pair: Trading pair to get fees for + + Returns: + Dictionary with total funding fees information + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + funding_repo = FundingRepository(session) + return await funding_repo.get_total_funding_fees( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair + ) + + except Exception as e: + logger.error(f"Error getting total funding fees: {e}") + return { + "total_funding_fees": 0, + "payment_count": 0, + "fee_currency": None, + "error": str(e) + } diff --git a/services/bots_orchestrator.py b/services/bots_orchestrator.py index b65293a5..b4449e95 100644 --- a/services/bots_orchestrator.py +++ b/services/bots_orchestrator.py @@ -1,71 +1,54 @@ import asyncio -from collections import deque +import logging from typing import Optional +import re import docker -from hbotrc import BotCommands -from hbotrc.listener import BotListener -from hbotrc.spec import TopicSpecs - - -class HummingbotPerformanceListener(BotListener): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - topic_prefix = TopicSpecs.PREFIX.format( - namespace=self._ns, - instance_id=self._bot_id - ) - self._performance_topic = f'{topic_prefix}/performance' - self._bot_performance = {} - self._bot_error_logs = deque(maxlen=100) - self._bot_general_logs = deque(maxlen=100) - self.performance_report_sub = None - - def get_bot_performance(self): - return self._bot_performance - - def get_bot_error_logs(self): - return list(self._bot_error_logs) - - def get_bot_general_logs(self): - return list(self._bot_general_logs) - - def _init_endpoints(self): - super()._init_endpoints() - self.performance_report_sub = self.create_subscriber(topic=self._performance_topic, - on_message=self._update_bot_performance) - - def _update_bot_performance(self, msg): - for controller_id, performance_report in msg.items(): - self._bot_performance[controller_id] = performance_report - - def _on_log(self, log): - if log.level_name == "ERROR": - self._bot_error_logs.append(log) - else: - self._bot_general_logs.append(log) - def stop(self): - super().stop() - self._bot_performance = {} +from utils.mqtt_manager import MQTTManager + +logger = logging.getLogger(__name__) + +# HummingbotPerformanceListener class is no longer needed +# All functionality is now handled by MQTTManager + + +class BotsOrchestrator: + """Orchestrates Hummingbot instances using Docker and MQTT communication.""" -class BotsManager: def __init__(self, broker_host, broker_port, broker_username, broker_password): self.broker_host = broker_host self.broker_port = broker_port self.broker_username = broker_username self.broker_password = broker_password + + # Initialize Docker client self.docker_client = docker.from_env() + + # Initialize MQTT manager + self.mqtt_manager = MQTTManager(host=broker_host, port=broker_port, username=broker_username, password=broker_password) + + # Active bots tracking self.active_bots = {} self._update_bots_task: Optional[asyncio.Task] = None + + # Track bots that are currently being stopped and archived + self.stopping_bots = set() + + # MQTT manager will be started asynchronously later @staticmethod def hummingbot_containers_fiter(container): + """Filter for Hummingbot containers based on image name pattern.""" try: - return "hummingbot" in container.name and "broker" not in container.name + # Get the image 
name (first tag if available, otherwise the image ID) + image_name = container.image.tags[0] if container.image.tags else str(container.image) + pattern = r'.+/hummingbot:' + return bool(re.match(pattern, image_name)) except Exception: return False + async def get_active_containers(self): loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self._sync_get_active_containers) @@ -74,64 +57,168 @@ def _sync_get_active_containers(self): return [ container.name for container in self.docker_client.containers.list() - if container.status == 'running' and self.hummingbot_containers_fiter(container) + if container.status == "running" and self.hummingbot_containers_fiter(container) ] - def start_update_active_bots_loop(self): - self._update_bots_task = asyncio.create_task(self.update_active_bots()) + def start(self): + """Start the loop that monitors active bots.""" + # Start MQTT manager and update loop in async context + self._update_bots_task = asyncio.create_task(self._start_async()) + + async def _start_async(self): + """Start MQTT manager and update loop asynchronously.""" + logger.info("Starting MQTT manager...") + await self.mqtt_manager.start() + + # Then start the update loop + await self.update_active_bots() - def stop_update_active_bots_loop(self): + def stop(self): + """Stop the active bots monitoring loop.""" if self._update_bots_task: self._update_bots_task.cancel() self._update_bots_task = None - async def update_active_bots(self, sleep_time=1): + # Stop MQTT manager asynchronously + asyncio.create_task(self.mqtt_manager.stop()) + + async def update_active_bots(self, sleep_time=1.0): + """Monitor and update active bots list using both Docker and MQTT discovery.""" while True: - active_hbot_containers = await self.get_active_containers() - # Remove bots that are no longer active - for bot in list(self.active_bots): - if bot not in active_hbot_containers: - del self.active_bots[bot] - - # Add new bots or update existing ones - for bot in active_hbot_containers: - if bot not in self.active_bots: - hbot_listener = HummingbotPerformanceListener(host=self.broker_host, port=self.broker_port, - username=self.broker_username, - password=self.broker_password, - bot_id=bot) - hbot_listener.start() - self.active_bots[bot] = { - "bot_name": bot, - "broker_client": BotCommands(host=self.broker_host, port=self.broker_port, - username=self.broker_username, password=self.broker_password, - bot_id=bot), - "broker_listener": hbot_listener, - } + try: + # Get bots from Docker containers + docker_bots = await self.get_active_containers() + + # Get bots from MQTT messages (auto-discovered) + mqtt_bots = self.mqtt_manager.get_discovered_bots(timeout_seconds=30) # 30 second timeout + + # Combine both sources + all_active_bots = set([bot for bot in docker_bots + mqtt_bots if not self.is_bot_stopping(bot)]) + + # Remove bots that are no longer active + for bot_name in list(self.active_bots): + if bot_name not in all_active_bots: + self.mqtt_manager.clear_bot_data(bot_name) + del self.active_bots[bot_name] + + # Add new bots + for bot_name in all_active_bots: + if bot_name not in self.active_bots: + self.active_bots[bot_name] = { + "bot_name": bot_name, + "status": "connected", + "source": "docker" if bot_name in docker_bots else "mqtt", + } + # Subscribe to this specific bot's topics + await self.mqtt_manager.subscribe_to_bot(bot_name) + + except Exception as e: + logger.error(f"Error in update_active_bots: {e}", exc_info=True) + await asyncio.sleep(sleep_time) # Interact with a specific 
bot - def start_bot(self, bot_name, **kwargs): - if bot_name in self.active_bots: - self.active_bots[bot_name]["broker_listener"].start() - return self.active_bots[bot_name]["broker_client"].start(**kwargs) + async def start_bot(self, bot_name, **kwargs): + """ + Start a bot with optional script. + Maintains backward compatibility with kwargs. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} + + # Create StartCommandMessage.Request format + data = { + "log_level": kwargs.get("log_level"), + "script": kwargs.get("script"), + "conf": kwargs.get("conf"), + "is_quickstart": kwargs.get("is_quickstart", False), + "async_backend": kwargs.get("async_backend", True), + } - def stop_bot(self, bot_name, **kwargs): - if bot_name in self.active_bots: - self.active_bots[bot_name]["broker_listener"].stop() - return self.active_bots[bot_name]["broker_client"].stop(**kwargs) + success = await self.mqtt_manager.publish_command(bot_name, "start", data) + return {"success": success} - def import_strategy_for_bot(self, bot_name, strategy, **kwargs): - if bot_name in self.active_bots: - return self.active_bots[bot_name]["broker_client"].import_strategy(strategy, **kwargs) + async def stop_bot(self, bot_name, **kwargs): + """ + Stop a bot. + Maintains backward compatibility with kwargs. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} - def configure_bot(self, bot_name, params, **kwargs): - if bot_name in self.active_bots: - return self.active_bots[bot_name]["broker_client"].config(params, **kwargs) + # Create StopCommandMessage.Request format + data = { + "skip_order_cancellation": kwargs.get("skip_order_cancellation", False), + "async_backend": kwargs.get("async_backend", True), + } - def get_bot_history(self, bot_name, **kwargs): - if bot_name in self.active_bots: - return self.active_bots[bot_name]["broker_client"].history(**kwargs) + success = await self.mqtt_manager.publish_command(bot_name, "stop", data) + + # Clear performance data after stop command to immediately reflect stopped status + if success: + self.mqtt_manager.clear_bot_performance(bot_name) + + return {"success": success} + + async def import_strategy_for_bot(self, bot_name, strategy, **kwargs): + """ + Import a strategy configuration for a bot. + Maintains backward compatibility. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} + + # Create ImportCommandMessage.Request format + data = {"strategy": strategy} + success = await self.mqtt_manager.publish_command(bot_name, "import_strategy", data) + return {"success": success} + + async def configure_bot(self, bot_name, params, **kwargs): + """ + Configure bot parameters. + Maintains backward compatibility. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} + + # Create ConfigCommandMessage.Request format + data = {"params": params} + success = await self.mqtt_manager.publish_command(bot_name, "config", data) + return {"success": success} + + async def get_bot_history(self, bot_name, **kwargs): + """ + Request bot trading history and wait for the response. + Maintains backward compatibility. 
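+
+        Example (illustrative sketch; `orchestrator` is an instance of this
+        class and the bot name is hypothetical):
+
+            history = await orchestrator.get_bot_history("hummingbot-bot1", days=7, timeout=10.0)
+            if history["success"]:
+                print(history["data"])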
+ """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} + + # Create HistoryCommandMessage.Request format + data = { + "days": kwargs.get("days", 0), + "verbose": kwargs.get("verbose", False), + "precision": kwargs.get("precision"), + "async_backend": kwargs.get("async_backend", False), + } + + # Use the new RPC method to wait for response + timeout = kwargs.get("timeout", 30.0) # Default 30 second timeout + response = await self.mqtt_manager.publish_command_and_wait(bot_name, "history", data, timeout=timeout) + + if response is None: + return { + "success": False, + "message": f"No response received from {bot_name} within {timeout} seconds", + "timeout": True, + } + + return {"success": True, "data": response} @staticmethod def determine_controller_performance(controllers_performance): @@ -140,10 +227,7 @@ def determine_controller_performance(controllers_performance): try: # Check if all the metrics are numeric _ = sum(metric for key, metric in performance.items() if key not in ("positions_summary", "close_type_counts")) - cleaned_performance[controller] = { - "status": "running", - "performance": performance - } + cleaned_performance[controller] = {"status": "running", "performance": performance} except Exception as e: cleaned_performance[controller] = { "status": "error", @@ -152,28 +236,73 @@ def determine_controller_performance(controllers_performance): return cleaned_performance def get_all_bots_status(self): + # TODO: improve logic of bots state management + """Get status information for all active bots.""" all_bots_status = {} - for bot in self.active_bots: - all_bots_status[bot] = self.get_bot_status(bot) + for bot in [bot for bot in self.active_bots if not self.is_bot_stopping(bot)]: + status = self.get_bot_status(bot) + status["source"] = self.active_bots[bot].get("source", "unknown") + all_bots_status[bot] = status return all_bots_status def get_bot_status(self, bot_name): - if bot_name in self.active_bots: - try: - broker_listener = self.active_bots[bot_name]["broker_listener"] - controllers_performance = broker_listener.get_bot_performance() - performance = self.determine_controller_performance(controllers_performance) - error_logs = broker_listener.get_bot_error_logs() - general_logs = broker_listener.get_bot_general_logs() - status = "running" if len(performance) > 0 else "stopped" - return { - "status": status, - "performance": performance, - "error_logs": error_logs, - "general_logs": general_logs - } - except Exception as e: + """ + Get status information for a specific bot. 
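+
+        Example of a "running" response (shape mirrors the code below,
+        values illustrative):
+
+            {"status": "running", "performance": {...}, "error_logs": [...],
+             "general_logs": [...], "recently_active": True}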
+ """ + if bot_name not in self.active_bots: + return {"status": "not_found", "error": f"Bot {bot_name} not found"} + + try: + # Check if bot is currently being stopped and archived + if bot_name in self.stopping_bots: return { - "status": "error", - "error": str(e) + "status": "stopping", + "message": "Bot is currently being stopped and archived", + "performance": {}, + "error_logs": [], + "general_logs": [], + "recently_active": False, } + + # Get data from MQTT manager + controllers_performance = self.mqtt_manager.get_bot_performance(bot_name) + performance = self.determine_controller_performance(controllers_performance) + error_logs = self.mqtt_manager.get_bot_error_logs(bot_name) + general_logs = self.mqtt_manager.get_bot_logs(bot_name) + + # Check if bot has sent recent messages (within last 30 seconds) + discovered_bots = self.mqtt_manager.get_discovered_bots(timeout_seconds=30) + recently_active = bot_name in discovered_bots + + # Determine status based on performance data and recent activity + if len(performance) > 0 and recently_active: + status = "running" + elif len(performance) > 0 and not recently_active: + status = "idle" # Has performance data but no recent activity + else: + status = "stopped" + + return { + "status": status, + "performance": performance, + "error_logs": error_logs, + "general_logs": general_logs, + "recently_active": recently_active, + } + except Exception as e: + return {"status": "error", "error": str(e)} + + def set_bot_stopping(self, bot_name: str): + """Mark a bot as currently being stopped and archived.""" + self.stopping_bots.add(bot_name) + logger.info(f"Marked bot {bot_name} as stopping") + + def clear_bot_stopping(self, bot_name: str): + """Clear the stopping status for a bot.""" + self.stopping_bots.discard(bot_name) + logger.info(f"Cleared stopping status for bot {bot_name}") + + def is_bot_stopping(self, bot_name: str) -> bool: + """Check if a bot is currently being stopped.""" + return bot_name in self.stopping_bots + diff --git a/services/docker_service.py b/services/docker_service.py index 2232893c..69ac32e9 100644 --- a/services/docker_service.py +++ b/services/docker_service.py @@ -1,31 +1,65 @@ import logging import os import shutil +import time +import threading +from typing import Dict + +# Create module-specific logger +logger = logging.getLogger(__name__) import docker from docker.errors import DockerException from docker.types import LogConfig -from models import HummingbotInstanceConfig -from utils.file_system import FileSystemUtil - -file_system = FileSystemUtil() +from config import settings +from models import V2ScriptDeployment +from utils.file_system import fs_util -class DockerManager: +class DockerService: + # Class-level configuration for cleanup + PULL_STATUS_MAX_AGE_SECONDS = 3600 # Keep status for 1 hour + PULL_STATUS_MAX_ENTRIES = 100 # Maximum number of entries to keep + CLEANUP_INTERVAL_SECONDS = 300 # Run cleanup every 5 minutes + def __init__(self): self.SOURCE_PATH = os.getcwd() + self._pull_status: Dict[str, Dict] = {} + self._cleanup_thread = None + self._stop_cleanup = threading.Event() + try: self.client = docker.from_env() + # Start background cleanup thread + self._start_cleanup_thread() except DockerException as e: - logging.error(f"It was not possible to connect to Docker. Please make sure Docker is running. Error: {e}") + logger.error(f"It was not possible to connect to Docker. Please make sure Docker is running. 
Error: {e}") - def get_active_containers(self): + def get_active_containers(self, name_filter: str = None): try: - containers_info = [{"id": container.id, "name": container.name, "status": container.status} for - container in self.client.containers.list(filters={"status": "running"}) if - "hummingbot" in container.name and "broker" not in container.name] - return {"active_instances": containers_info} + all_containers = self.client.containers.list(filters={"status": "running"}) + if name_filter: + containers_info = [ + { + "id": container.id, + "name": container.name, + "status": container.status, + "image": container.image.tags[0] if container.image.tags else container.image.id[:12] + } + for container in all_containers if name_filter.lower() in container.name.lower() + ] + else: + containers_info = [ + { + "id": container.id, + "name": container.name, + "status": container.status, + "image": container.image.tags[0] if container.image.tags else container.image.id[:12] + } + for container in all_containers + ] + return containers_info except DockerException as e: return str(e) @@ -38,16 +72,42 @@ def get_available_images(self): def pull_image(self, image_name): try: - self.client.images.pull(image_name) + return self.client.images.pull(image_name) except DockerException as e: return str(e) - def get_exited_containers(self): + def pull_image_sync(self, image_name): + """Synchronous pull operation for background tasks""" try: - containers_info = [{"id": container.id, "name": container.name, "status": container.status} for - container in self.client.containers.list(filters={"status": "exited"}) if - "hummingbot" in container.name and "broker" not in container.name] - return {"exited_instances": containers_info} + result = self.client.images.pull(image_name) + return {"success": True, "image": image_name, "result": str(result)} + except DockerException as e: + return {"success": False, "error": str(e)} + + def get_exited_containers(self, name_filter: str = None): + try: + all_containers = self.client.containers.list(filters={"status": "exited"}, all=True) + if name_filter: + containers_info = [ + { + "id": container.id, + "name": container.name, + "status": container.status, + "image": container.image.tags[0] if container.image.tags else container.image.id[:12] + } + for container in all_containers if name_filter.lower() in container.name.lower() + ] + else: + containers_info = [ + { + "id": container.id, + "name": container.name, + "status": container.status, + "image": container.image.tags[0] if container.image.tags else container.image.id[:12] + } + for container in all_containers + ] + return containers_info except DockerException as e: return str(e) @@ -78,6 +138,21 @@ def start_container(self, container_name): except DockerException as e: return str(e) + def get_container_status(self, container_name): + """Get the status of a container""" + try: + container = self.client.containers.get(container_name) + return { + "success": True, + "state": { + "status": container.status, + "running": container.status == "running", + "exit_code": getattr(container.attrs.get("State", {}), "ExitCode", None) + } + } + except DockerException as e: + return {"success": False, "message": str(e)} + def remove_container(self, container_name, force=True): try: container = self.client.containers.get(container_name) @@ -86,9 +161,9 @@ def remove_container(self, container_name, force=True): except DockerException as e: return {"success": False, "message": str(e)} - def create_hummingbot_instance(self, config: 
HummingbotInstanceConfig): + def create_hummingbot_instance(self, config: V2ScriptDeployment): bots_path = os.environ.get('BOTS_PATH', self.SOURCE_PATH) # Default to 'SOURCE_PATH' if BOTS_PATH is not set - instance_name = f"hummingbot-{config.instance_name}" + instance_name = config.instance_name instance_dir = os.path.join("bots", 'instances', instance_name) if not os.path.exists(instance_dir): os.makedirs(instance_dir) @@ -97,11 +172,7 @@ def create_hummingbot_instance(self, config: HummingbotInstanceConfig): # Copy credentials to instance directory source_credentials_dir = os.path.join("bots", 'credentials', config.credentials_profile) - script_config_dir = os.path.join("bots", 'conf', 'scripts') - controllers_config_dir = os.path.join("bots", 'conf', 'controllers') destination_credentials_dir = os.path.join(instance_dir, 'conf') - destination_scripts_config_dir = os.path.join(instance_dir, 'conf', 'scripts') - destination_controllers_config_dir = os.path.join(instance_dir, 'conf', 'controllers') # Remove the destination directory if it already exists if os.path.exists(destination_credentials_dir): @@ -109,12 +180,53 @@ def create_hummingbot_instance(self, config: HummingbotInstanceConfig): # Copy the entire contents of source_credentials_dir to destination_credentials_dir shutil.copytree(source_credentials_dir, destination_credentials_dir) - shutil.copytree(script_config_dir, destination_scripts_config_dir) - shutil.copytree(controllers_config_dir, destination_controllers_config_dir) - conf_file_path = f"{instance_dir}/conf/conf_client.yml" - client_config = FileSystemUtil.read_yaml_file(conf_file_path) + + # Copy specific script config and referenced controllers if provided + if config.script_config: + script_config_dir = os.path.join("bots", 'conf', 'scripts') + controllers_config_dir = os.path.join("bots", 'conf', 'controllers') + destination_scripts_config_dir = os.path.join(instance_dir, 'conf', 'scripts') + destination_controllers_config_dir = os.path.join(instance_dir, 'conf', 'controllers') + + os.makedirs(destination_scripts_config_dir, exist_ok=True) + + # Copy the specific script config file + source_script_config_file = os.path.join(script_config_dir, config.script_config) + destination_script_config_file = os.path.join(destination_scripts_config_dir, config.script_config) + + if os.path.exists(source_script_config_file): + shutil.copy2(source_script_config_file, destination_script_config_file) + + # Load the script config to find referenced controllers + try: + # Path relative to fs_util base_path (which is "bots") + script_config_relative_path = f"conf/scripts/{config.script_config}" + script_config_content = fs_util.read_yaml_file(script_config_relative_path) + controllers_list = script_config_content.get('controllers_config', []) + + # If there are controllers referenced, copy them + if controllers_list: + os.makedirs(destination_controllers_config_dir, exist_ok=True) + + for controller_file in controllers_list: + source_controller_file = os.path.join(controllers_config_dir, controller_file) + destination_controller_file = os.path.join(destination_controllers_config_dir, controller_file) + + if os.path.exists(source_controller_file): + shutil.copy2(source_controller_file, destination_controller_file) + logger.info(f"Copied controller config: {controller_file}") + else: + logger.warning(f"Controller config file {controller_file} not found in {controllers_config_dir}") + + except Exception as e: + logger.error(f"Error reading script config file {config.script_config}: 
{e}") + else: + logger.warning(f"Script config file {config.script_config} not found in {script_config_dir}") + # Path relative to fs_util base_path (which is "bots") + conf_file_path = f"instances/{instance_name}/conf/conf_client.yml" + client_config = fs_util.read_yaml_file(conf_file_path) client_config['instance_id'] = instance_name - FileSystemUtil.dump_dict_to_yaml(conf_file_path, client_config) + fs_util.dump_dict_to_yaml(conf_file_path, client_config) # Set up Docker volumes volumes = { @@ -130,7 +242,7 @@ def create_hummingbot_instance(self, config: HummingbotInstanceConfig): # Set up environment variables environment = {} - password = os.environ.get('CONFIG_PASSWORD', "a") + password = settings.security.config_password if password: environment["CONFIG_PASSWORD"] = password @@ -142,6 +254,9 @@ def create_hummingbot_instance(self, config: HummingbotInstanceConfig): else: return {"success": False, "message": "Password not provided. We cannot start the bot without a password."} + if config.headless: + environment["HEADLESS"] = "true" + log_config = LogConfig( type="json-file", config={ @@ -163,3 +278,144 @@ def create_hummingbot_instance(self, config: HummingbotInstanceConfig): return {"success": True, "message": f"Instance {instance_name} created successfully."} except docker.errors.DockerException as e: return {"success": False, "message": str(e)} + + def _start_cleanup_thread(self): + """Start the background cleanup thread""" + if self._cleanup_thread is None or not self._cleanup_thread.is_alive(): + self._cleanup_thread = threading.Thread(target=self._periodic_cleanup, daemon=True) + self._cleanup_thread.start() + logger.info("Started Docker pull status cleanup thread") + + def _periodic_cleanup(self): + """Periodically clean up old pull status entries""" + while not self._stop_cleanup.is_set(): + try: + self._cleanup_old_pull_status() + except Exception as e: + logger.error(f"Error in cleanup thread: {e}") + + # Wait for the next cleanup interval + self._stop_cleanup.wait(self.CLEANUP_INTERVAL_SECONDS) + + def _cleanup_old_pull_status(self): + """Remove old entries to prevent memory growth""" + current_time = time.time() + to_remove = [] + + # Find entries older than max age + for image_name, status_info in self._pull_status.items(): + # Skip ongoing pulls + if status_info["status"] == "pulling": + continue + + # Check age of completed/failed operations + end_time = status_info.get("completed_at") or status_info.get("failed_at") + if end_time and (current_time - end_time > self.PULL_STATUS_MAX_AGE_SECONDS): + to_remove.append(image_name) + + # Remove old entries + for image_name in to_remove: + del self._pull_status[image_name] + logger.info(f"Cleaned up old pull status for {image_name}") + + # If still over limit, remove oldest completed/failed entries + if len(self._pull_status) > self.PULL_STATUS_MAX_ENTRIES: + completed_entries = [ + (name, info) for name, info in self._pull_status.items() + if info["status"] in ["completed", "failed"] + ] + # Sort by end time (oldest first) + completed_entries.sort( + key=lambda x: x[1].get("completed_at") or x[1].get("failed_at") or 0 + ) + + # Remove oldest entries to get under limit + excess_count = len(self._pull_status) - self.PULL_STATUS_MAX_ENTRIES + for i in range(min(excess_count, len(completed_entries))): + del self._pull_status[completed_entries[i][0]] + logger.info(f"Cleaned up excess pull status for {completed_entries[i][0]}") + + def pull_image_async(self, image_name: str): + """Start pulling a Docker image asynchronously with 
status tracking""" + # Check if pull is already in progress + if image_name in self._pull_status: + current_status = self._pull_status[image_name] + if current_status["status"] == "pulling": + return { + "message": f"Pull already in progress for {image_name}", + "status": "in_progress", + "started_at": current_status["started_at"], + "image_name": image_name + } + + # Start the pull in a background thread + threading.Thread(target=self._pull_image_with_tracking, args=(image_name,), daemon=True).start() + + return { + "message": f"Pull started for {image_name}", + "status": "started", + "image_name": image_name + } + + def _pull_image_with_tracking(self, image_name: str): + """Background task to pull Docker image with status tracking""" + try: + self._pull_status[image_name] = { + "status": "pulling", + "started_at": time.time(), + "progress": "Starting pull..." + } + + # Use the synchronous pull method + result = self.pull_image_sync(image_name) + + if result.get("success"): + self._pull_status[image_name] = { + "status": "completed", + "started_at": self._pull_status[image_name]["started_at"], + "completed_at": time.time(), + "result": result + } + else: + self._pull_status[image_name] = { + "status": "failed", + "started_at": self._pull_status[image_name]["started_at"], + "failed_at": time.time(), + "error": result.get("error", "Unknown error") + } + except Exception as e: + self._pull_status[image_name] = { + "status": "failed", + "started_at": self._pull_status[image_name].get("started_at", time.time()), + "failed_at": time.time(), + "error": str(e) + } + + def get_all_pull_status(self): + """Get status of all pull operations""" + operations = {} + for image_name, status_info in self._pull_status.items(): + status_copy = status_info.copy() + + # Add duration for each operation + start_time = status_copy.get("started_at") + if start_time: + if status_copy["status"] == "pulling": + status_copy["duration_seconds"] = round(time.time() - start_time, 2) + elif "completed_at" in status_copy: + status_copy["duration_seconds"] = round(status_copy["completed_at"] - start_time, 2) + elif "failed_at" in status_copy: + status_copy["duration_seconds"] = round(status_copy["failed_at"] - start_time, 2) + + operations[image_name] = status_copy + + return { + "pull_operations": operations, + "total_operations": len(operations) + } + + def cleanup(self): + """Clean up resources when shutting down""" + self._stop_cleanup.set() + if self._cleanup_thread: + self._cleanup_thread.join(timeout=1) diff --git a/services/funding_recorder.py b/services/funding_recorder.py new file mode 100644 index 00000000..a12ebcdf --- /dev/null +++ b/services/funding_recorder.py @@ -0,0 +1,142 @@ +import asyncio +import logging +from datetime import datetime +from decimal import Decimal, InvalidOperation +from typing import Dict, Optional + +from hummingbot.connector.connector_base import ConnectorBase +from hummingbot.core.event.event_forwarder import SourceInfoEventForwarder +from hummingbot.core.event.events import MarketEvent, FundingPaymentCompletedEvent + +from database import AsyncDatabaseManager, FundingRepository + + +class FundingRecorder: + """ + Records funding payment events and associates them with position data. + Follows the same pattern as OrdersRecorder for consistency. 
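+
+    Example wiring (illustrative sketch; `db_manager` and `connector` are
+    assumed to be created elsewhere in the service layer):
+
+        recorder = FundingRecorder(db_manager, "master_account", "binance_perpetual")
+        recorder.start(connector)  # subscribes to FundingPaymentCompleted events
+        ...
+        await recorder.stop()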
+ """ + + def __init__(self, db_manager: AsyncDatabaseManager, account_name: str, connector_name: str): + self.db_manager = db_manager + self.account_name = account_name + self.connector_name = connector_name + self._connector: Optional[ConnectorBase] = None + self.logger = logging.getLogger(__name__) + + # Create event forwarder for funding payments + self._funding_payment_forwarder = SourceInfoEventForwarder(self._did_funding_payment) + + # Event pairs mapping events to forwarders + self._event_pairs = [ + (MarketEvent.FundingPaymentCompleted, self._funding_payment_forwarder), + ] + + def start(self, connector: ConnectorBase): + """Start recording funding payments for the given connector""" + self._connector = connector + + # Subscribe to funding payment events + for event, forwarder in self._event_pairs: + connector.add_listener(event, forwarder) + + self.logger.info(f"FundingRecorder started for {self.account_name}/{self.connector_name}") + + async def stop(self): + """Stop recording funding payments""" + if self._connector: + for event, forwarder in self._event_pairs: + self._connector.remove_listener(event, forwarder) + self.logger.info(f"FundingRecorder stopped for {self.account_name}/{self.connector_name}") + + def _did_funding_payment(self, event_tag: int, market: ConnectorBase, event: FundingPaymentCompletedEvent): + """Handle funding payment events - called by SourceInfoEventForwarder""" + try: + asyncio.create_task(self._handle_funding_payment(event)) + except Exception as e: + self.logger.error(f"Error in _did_funding_payment: {e}") + + async def _handle_funding_payment(self, event: FundingPaymentCompletedEvent): + """Handle funding payment events""" + # Get current position data if available + position_data = None + if self._connector and hasattr(self._connector, 'account_positions'): + try: + positions = self._connector.account_positions + if positions: + for position in positions.values(): + if position.trading_pair == event.trading_pair: + position_data = { + "size": float(position.amount), + "side": position.position_side.name if hasattr(position.position_side, 'name') else str(position.position_side), + } + break + except Exception as e: + self.logger.warning(f"Could not get position data for funding payment: {e}") + + # Record the funding payment + await self.record_funding_payment(event, self.account_name, self.connector_name, position_data) + + async def record_funding_payment(self, event: FundingPaymentCompletedEvent, + account_name: str, connector_name: str, + position_data: Optional[Dict] = None): + """ + Record a funding payment event with optional position association. 
+
+        Args:
+            event: FundingPaymentCompletedEvent from Hummingbot
+            account_name: Account name
+            connector_name: Connector name
+            position_data: Optional position data at time of payment
+        """
+        try:
+            # Validate and convert funding data
+            funding_rate = Decimal(str(event.funding_rate))
+            funding_payment = Decimal(str(event.amount))
+
+            # Create funding payment record
+            funding_data = {
+                "funding_payment_id": f"{connector_name}_{event.trading_pair}_{event.timestamp.timestamp()}",
+                "timestamp": event.timestamp,
+                "account_name": account_name,
+                "connector_name": connector_name,
+                "trading_pair": event.trading_pair,
+                "funding_rate": float(funding_rate),
+                "funding_payment": float(funding_payment),
+                "fee_currency": getattr(event, 'fee_currency', 'USDT'),  # Default to USDT if not provided
+                "exchange_funding_id": getattr(event, 'exchange_funding_id', None),
+            }
+
+            # Add position data if provided
+            if position_data:
+                funding_data.update({
+                    "position_size": float(position_data.get("size", 0)),
+                    "position_side": position_data.get("side"),
+                })
+
+            # Save to database
+            async with self.db_manager.get_session() as session:
+                funding_repo = FundingRepository(session)
+
+                # Check if funding payment already exists
+                if await funding_repo.funding_payment_exists(funding_data["funding_payment_id"]):
+                    self.logger.info(f"Funding payment {funding_data['funding_payment_id']} already exists, skipping")
+                    return
+
+                # Use a distinct name for the persisted record so the Decimal
+                # payment amount above is not shadowed in the log message below
+                funding_record = await funding_repo.create_funding_payment(funding_data)
+                await session.commit()
+
+            self.logger.info(
+                f"Recorded funding payment for {account_name}/{connector_name}: "
+                f"{event.trading_pair} - Rate: {funding_rate}, Payment: {funding_payment} "
+                f"{funding_data['fee_currency']}"
+            )
+
+            return funding_record
+
+        except (ValueError, InvalidOperation) as e:
+            self.logger.error(f"Error processing funding payment for {event.trading_pair}: {e}, skipping update")
+            return
+        except Exception as e:
+            self.logger.error(f"Unexpected error recording funding payment: {e}")
+            return
\ No newline at end of file
diff --git a/services/market_data_feed_manager.py b/services/market_data_feed_manager.py
new file mode 100644
index 00000000..115734a3
--- /dev/null
+++ b/services/market_data_feed_manager.py
@@ -0,0 +1,601 @@
+import asyncio
+import time
+from typing import Dict, Optional, Callable, List
+import logging
+from enum import Enum
+
+from hummingbot.core.rate_oracle.rate_oracle import RateOracle
+from hummingbot.data_feed.candles_feed.data_types import CandlesConfig
+from hummingbot.data_feed.market_data_provider import MarketDataProvider
+
+
+class FeedType(Enum):
+    """Types of market data feeds that can be managed."""
+    CANDLES = "candles"
+    ORDER_BOOK = "order_book"
+    TRADES = "trades"
+    TICKER = "ticker"
+
+
+class MarketDataFeedManager:
+    """
+    Generic manager for the market data feed lifecycle with automatic cleanup.
+
+    This service wraps the MarketDataProvider and tracks when any type of market data feed
+    is last accessed. Feeds that haven't been accessed within the specified timeout period
+    are automatically stopped and cleaned up.
+    """
+
+    def __init__(self, market_data_provider: MarketDataProvider, rate_oracle: RateOracle, cleanup_interval: int = 300, feed_timeout: int = 600):
+        """
+        Initialize the MarketDataFeedManager.
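+
+        Example (illustrative sketch; `provider` and `oracle` are assumed to
+        be pre-built MarketDataProvider and RateOracle instances):
+
+            manager = MarketDataFeedManager(provider, oracle, cleanup_interval=300, feed_timeout=600)
+            manager.start()  # begins the periodic cleanup loop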
+ + Args: + market_data_provider: The underlying MarketDataProvider instance + cleanup_interval: How often to run cleanup (seconds, default: 5 minutes) + feed_timeout: How long to keep unused feeds alive (seconds, default: 10 minutes) + """ + self.market_data_provider = market_data_provider + self.rate_oracle = rate_oracle + self.cleanup_interval = cleanup_interval + self.feed_timeout = feed_timeout + self.last_access_times: Dict[str, float] = {} + self.feed_configs: Dict[str, tuple] = {} # Store feed configs for cleanup + self._cleanup_task: Optional[asyncio.Task] = None + self._is_running = False + self.logger = logging.getLogger(__name__) + + # Registry of cleanup functions for different feed types + self._cleanup_functions: Dict[FeedType, Callable] = { + FeedType.CANDLES: self._cleanup_candle_feed, + FeedType.ORDER_BOOK: self._cleanup_order_book_feed, + # Add more feed types as needed + } + + def start(self): + """Start the cleanup background task.""" + if not self._is_running: + self._is_running = True + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + self.rate_oracle.start() + self.logger.info(f"MarketDataFeedManager started with cleanup_interval={self.cleanup_interval}s, feed_timeout={self.feed_timeout}s") + + def stop(self): + """Stop the cleanup background task and all feeds.""" + self._is_running = False + if self._cleanup_task: + self._cleanup_task.cancel() + self._cleanup_task = None + + # Stop all feeds managed by the MarketDataProvider + self.market_data_provider.stop() + self.last_access_times.clear() + self.feed_configs.clear() + self.logger.info("MarketDataFeedManager stopped") + + def get_candles_feed(self, config: CandlesConfig): + """ + Get a candles feed and update its last access time. + + Args: + config: CandlesConfig for the desired feed + + Returns: + Candle feed instance + """ + feed_key = self._generate_feed_key(FeedType.CANDLES, config.connector, config.trading_pair, config.interval) + + # Update last access time and store config for cleanup + self.last_access_times[feed_key] = time.time() + self.feed_configs[feed_key] = (FeedType.CANDLES, config) + + # Get the feed from MarketDataProvider + feed = self.market_data_provider.get_candles_feed(config) + + self.logger.debug(f"Accessed candle feed: {feed_key}") + return feed + + def get_candles_df(self, connector_name: str, trading_pair: str, interval: str, max_records: int = 500): + """ + Get candles dataframe and update access time. + + Args: + connector_name: The connector name + trading_pair: The trading pair + interval: The candle interval + max_records: Maximum number of records + + Returns: + Candles dataframe + """ + config = CandlesConfig( + connector=connector_name, + trading_pair=trading_pair, + interval=interval, + max_records=max_records + ) + + feed_key = self._generate_feed_key(FeedType.CANDLES, connector_name, trading_pair, interval) + self.last_access_times[feed_key] = time.time() + self.feed_configs[feed_key] = (FeedType.CANDLES, config) + + # Use MarketDataProvider's convenience method + df = self.market_data_provider.get_candles_df(connector_name, trading_pair, interval, max_records) + + self.logger.debug(f"Accessed candle data: {feed_key}") + return df + + def get_order_book(self, connector_name: str, trading_pair: str): + """ + Get order book and update access time. 
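+
+        Example (illustrative):
+
+            book = manager.get_order_book("binance", "BTC-USDT")
+            bids_df, asks_df = book.snapshot  # pandas DataFrames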
+ + Args: + connector_name: The connector name + trading_pair: The trading pair + + Returns: + Order book instance + """ + feed_key = self._generate_feed_key(FeedType.ORDER_BOOK, connector_name, trading_pair) + + # Update last access time + self.last_access_times[feed_key] = time.time() + self.feed_configs[feed_key] = (FeedType.ORDER_BOOK, (connector_name, trading_pair)) + + # Get order book from MarketDataProvider + order_book = self.market_data_provider.get_order_book(connector_name, trading_pair) + + self.logger.debug(f"Accessed order book: {feed_key}") + return order_book + + def get_order_book_snapshot(self, connector_name: str, trading_pair: str): + """ + Get order book snapshot and update access time. + + Args: + connector_name: The connector name + trading_pair: The trading pair + + Returns: + Tuple of bid and ask DataFrames + """ + feed_key = self._generate_feed_key(FeedType.ORDER_BOOK, connector_name, trading_pair) + + # Update last access time + self.last_access_times[feed_key] = time.time() + self.feed_configs[feed_key] = (FeedType.ORDER_BOOK, (connector_name, trading_pair)) + + # Get order book snapshot from MarketDataProvider + snapshot = self.market_data_provider.get_order_book_snapshot(connector_name, trading_pair) + + self.logger.debug(f"Accessed order book snapshot: {feed_key}") + return snapshot + + async def get_trading_rules(self, connector_name: str, trading_pairs: Optional[List[str]] = None) -> Dict[str, Dict]: + """ + Get trading rules for specified trading pairs from a connector. + + Args: + connector_name: Name of the connector + trading_pairs: List of trading pairs to get rules for. If None, get all available. + + Returns: + Dictionary mapping trading pairs to their trading rules + """ + try: + # Access connector through MarketDataProvider's _rate_sources LazyDict + connector = self.market_data_provider._rate_sources[connector_name] + + # Check if trading rules are initialized, if not update them + if not connector.trading_rules or len(connector.trading_rules) == 0: + await connector._update_trading_rules() + + # Get trading rules + if trading_pairs: + # Get rules for specific trading pairs + result = {} + for trading_pair in trading_pairs: + if trading_pair in connector.trading_rules: + rule = connector.trading_rules[trading_pair] + result[trading_pair] = { + "min_order_size": float(rule.min_order_size), + "max_order_size": float(rule.max_order_size) if rule.max_order_size else None, + "min_price_increment": float(rule.min_price_increment), + "min_base_amount_increment": float(rule.min_base_amount_increment), + "min_quote_amount_increment": float(rule.min_quote_amount_increment), + "min_notional_size": float(rule.min_notional_size), + "min_order_value": float(rule.min_order_value), + "max_price_significant_digits": float(rule.max_price_significant_digits), + "supports_limit_orders": rule.supports_limit_orders, + "supports_market_orders": rule.supports_market_orders, + "buy_order_collateral_token": rule.buy_order_collateral_token, + "sell_order_collateral_token": rule.sell_order_collateral_token, + } + else: + result[trading_pair] = {"error": f"Trading pair {trading_pair} not found"} + else: + # Get all trading rules + result = {} + for trading_pair, rule in connector.trading_rules.items(): + result[trading_pair] = { + "min_order_size": float(rule.min_order_size), + "max_order_size": float(rule.max_order_size) if rule.max_order_size else None, + "min_price_increment": float(rule.min_price_increment), + "min_base_amount_increment": 
float(rule.min_base_amount_increment), + "min_quote_amount_increment": float(rule.min_quote_amount_increment), + "min_notional_size": float(rule.min_notional_size), + "min_order_value": float(rule.min_order_value), + "max_price_significant_digits": float(rule.max_price_significant_digits), + "supports_limit_orders": rule.supports_limit_orders, + "supports_market_orders": rule.supports_market_orders, + "buy_order_collateral_token": rule.buy_order_collateral_token, + "sell_order_collateral_token": rule.sell_order_collateral_token, + } + + self.logger.debug(f"Retrieved trading rules for {connector_name}: {len(result)} pairs") + return result + + except Exception as e: + self.logger.error(f"Error getting trading rules for {connector_name}: {e}") + return {"error": str(e)} + + async def get_prices(self, connector_name: str, trading_pairs: List[str]) -> Dict[str, float]: + """ + Get current prices for specified trading pairs. + + Args: + connector_name: Name of the connector + trading_pairs: List of trading pairs to get prices for + + Returns: + Dictionary mapping trading pairs to their current prices + """ + try: + # Access connector through MarketDataProvider's _rate_sources LazyDict + connector = self.market_data_provider._rate_sources[connector_name] + + # Get last traded prices + prices = await connector.get_last_traded_prices(trading_pairs) + + # Convert Decimal to float for JSON serialization + result = {pair: float(price) for pair, price in prices.items()} + + self.logger.debug(f"Retrieved prices for {connector_name}: {len(result)} pairs") + return result + + except Exception as e: + self.logger.error(f"Error getting prices for {connector_name}: {e}") + return {"error": str(e)} + + async def get_funding_info(self, connector_name: str, trading_pair: str) -> Dict: + """ + Get funding information for a perpetual trading pair. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair to get funding info for + + Returns: + Dictionary containing funding information + """ + try: + # Access connector through MarketDataProvider's _rate_sources LazyDict + connector = self.market_data_provider._rate_sources[connector_name] + + # Check if this is a perpetual connector and has funding info support + if hasattr(connector, '_orderbook_ds') and connector._orderbook_ds: + orderbook_ds = connector._orderbook_ds + + # Get funding info from the order book data source + funding_info = await orderbook_ds.get_funding_info(trading_pair) + + if funding_info: + result = { + "trading_pair": trading_pair, + "funding_rate": float(funding_info.rate) if funding_info.rate else None, + "next_funding_time": float(funding_info.next_funding_utc_timestamp) if funding_info.next_funding_utc_timestamp else None, + "mark_price": float(funding_info.mark_price) if funding_info.mark_price else None, + "index_price": float(funding_info.index_price) if funding_info.index_price else None, + } + + self.logger.debug(f"Retrieved funding info for {connector_name}/{trading_pair}") + return result + else: + return {"error": f"No funding info available for {trading_pair}"} + else: + return {"error": f"Funding info not supported for {connector_name}"} + + except Exception as e: + self.logger.error(f"Error getting funding info for {connector_name}/{trading_pair}: {e}") + return {"error": str(e)} + + async def get_order_book_data(self, connector_name: str, trading_pair: str, depth: int = 10) -> Dict: + """ + Get order book data using the connector's order book data source. 
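+
+        Example result (shape mirrors the code below, values illustrative):
+
+            {"trading_pair": "BTC-USDT",
+             "bids": [[64000.1, 0.5], ...], "asks": [[64000.9, 0.3], ...],
+             "timestamp": 1718000000.0}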
+ + Args: + connector_name: Name of the connector + trading_pair: Trading pair to get order book for + depth: Number of bid/ask levels to return + + Returns: + Dictionary containing bid and ask data + """ + try: + # Access connector through MarketDataProvider's _rate_sources LazyDict + connector = self.market_data_provider._rate_sources[connector_name] + + # Access the order book data source + if hasattr(connector, '_orderbook_ds') and connector._orderbook_ds: + orderbook_ds = connector._orderbook_ds + + # Get new order book using the data source method + order_book = await orderbook_ds.get_new_order_book(trading_pair) + snapshot = order_book.snapshot + + result = { + "trading_pair": trading_pair, + "bids": snapshot[0].loc[:(depth - 1), ["price", "amount"]].values.tolist(), + "asks": snapshot[1].loc[:(depth - 1), ["price", "amount"]].values.tolist(), + "timestamp": time.time() + } + + self.logger.debug(f"Retrieved order book for {connector_name}/{trading_pair}") + return result + else: + return {"error": f"Order book data source not available for {connector_name}"} + + except Exception as e: + self.logger.error(f"Error getting order book for {connector_name}/{trading_pair}: {e}") + return {"error": str(e)} + + async def get_order_book_query_result(self, connector_name: str, trading_pair: str, is_buy: bool, **kwargs) -> Dict: + """ + Generic method for order book queries using fresh OrderBook from data source. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair + is_buy: True for buy side, False for sell side + **kwargs: Additional parameters for specific query types + + Returns: + Dictionary containing query results + """ + try: + current_time = time.time() + + # Access connector through MarketDataProvider's _rate_sources LazyDict + connector = self.market_data_provider._rate_sources[connector_name] + + # Access the order book data source + if hasattr(connector, '_orderbook_ds') and connector._orderbook_ds: + orderbook_ds = connector._orderbook_ds + + # Get fresh order book using the data source method + order_book = await orderbook_ds.get_new_order_book(trading_pair) + + if 'volume' in kwargs: + # Get price for volume + result = order_book.get_price_for_volume(is_buy, kwargs['volume']) + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_volume": kwargs['volume'], + "result_price": float(result.result_price) if result.result_price else None, + "result_volume": float(result.result_volume) if result.result_volume else None, + "timestamp": current_time + } + + elif 'price' in kwargs: + # Get volume for price + result = order_book.get_volume_for_price(is_buy, kwargs['price']) + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_price": kwargs['price'], + "result_volume": float(result.result_volume) if result.result_volume else None, + "result_price": float(result.result_price) if result.result_price else None, + "timestamp": current_time + } + + elif 'quote_volume' in kwargs: + # Get price for quote volume + result = order_book.get_price_for_quote_volume(is_buy, kwargs['quote_volume']) + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_quote_volume": kwargs['quote_volume'], + "result_price": float(result.result_price) if result.result_price else None, + "result_volume": float(result.result_volume) if result.result_volume else None, + "timestamp": current_time + } + + elif 'quote_price' in kwargs: + # Get quote volume for price + result = order_book.get_quote_volume_for_price(is_buy, kwargs['quote_price']) + 
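+                # get_quote_volume_for_price returns None result fields when no
+                # resting volume exists at the queried price; in that case the
+                # branch below explains why the quote crossed the book and
+                # suggests the nearest actionable price from the best bid/ask.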
+ # Check if quote crosses the book (no available volume at this price) + if result.result_volume is None or result.result_price is None: + # Get current market prices for comparison + snapshot = order_book.snapshot + best_bid = float(snapshot[0].iloc[0]["price"]) if not snapshot[0].empty else None + best_ask = float(snapshot[1].iloc[0]["price"]) if not snapshot[1].empty else None + mid_price = (best_bid + best_ask) / 2 if best_bid and best_ask else None + + # Determine if quote crosses the book + query_price = float(kwargs['quote_price']) + crossed_reason = None + suggested_price = None + + if is_buy: + # For buy orders, crossing occurs when price > best_ask + if best_ask and query_price > best_ask: + crossed_reason = f"Buy price {query_price} exceeds best ask {best_ask}" + suggested_price = best_ask + elif best_bid and query_price < best_bid: + crossed_reason = f"Buy price {query_price} below best bid {best_bid} - no liquidity available" + suggested_price = best_bid + else: + # For sell orders, crossing occurs when price < best_bid + if best_bid and query_price < best_bid: + crossed_reason = f"Sell price {query_price} below best bid {best_bid}" + suggested_price = best_bid + elif best_ask and query_price > best_ask: + crossed_reason = f"Sell price {query_price} above best ask {best_ask} - no liquidity available" + suggested_price = best_ask + + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_price": query_price, + "result_volume": None, + "result_quote_volume": None, + "crossed_book": True, + "crossed_reason": crossed_reason, + "best_bid": best_bid, + "best_ask": best_ask, + "mid_price": mid_price, + "suggested_price": suggested_price, + "timestamp": current_time + } + + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_price": kwargs['quote_price'], + "result_quote_volume": float(result.result_volume) if result.result_volume else None, + "crossed_book": False, + "timestamp": current_time + } + + elif 'vwap_volume' in kwargs: + # Get VWAP for volume + result = order_book.get_vwap_for_volume(is_buy, kwargs['vwap_volume']) + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_volume": kwargs['vwap_volume'], + "average_price": float(result.result_price) if result.result_price else None, + "result_volume": float(result.result_volume) if result.result_volume else None, + "timestamp": current_time + } + else: + return {"error": "Invalid query parameters"} + else: + return {"error": f"Order book data source not available for {connector_name}"} + + except Exception as e: + self.logger.error(f"Error in order book query for {connector_name}/{trading_pair}: {e}") + return {"error": str(e)} + + async def _cleanup_loop(self): + """Background task that periodically cleans up unused feeds.""" + while self._is_running: + try: + await self._cleanup_unused_feeds() + await asyncio.sleep(self.cleanup_interval) + except asyncio.CancelledError: + break + except Exception as e: + self.logger.error(f"Error in cleanup loop: {e}", exc_info=True) + await asyncio.sleep(self.cleanup_interval) + + async def _cleanup_unused_feeds(self): + """Clean up feeds that haven't been accessed within the timeout period.""" + current_time = time.time() + feeds_to_remove = [] + + for feed_key, last_access_time in self.last_access_times.items(): + if current_time - last_access_time > self.feed_timeout: + feeds_to_remove.append(feed_key) + + for feed_key in feeds_to_remove: + try: + # Get feed type and config + feed_type, config = self.feed_configs[feed_key] + + # Use 
appropriate cleanup function + cleanup_func = self._cleanup_functions.get(feed_type) + if cleanup_func: + cleanup_func(config) + + # Remove from tracking + del self.last_access_times[feed_key] + del self.feed_configs[feed_key] + + self.logger.info(f"Cleaned up unused {feed_type.value} feed: {feed_key}") + + except Exception as e: + self.logger.error(f"Error cleaning up feed {feed_key}: {e}", exc_info=True) + + if feeds_to_remove: + self.logger.info(f"Cleaned up {len(feeds_to_remove)} unused market data feeds") + + def _cleanup_candle_feed(self, config: CandlesConfig): + """Clean up a candle feed.""" + self.market_data_provider.stop_candle_feed(config) + + def _cleanup_order_book_feed(self, config: tuple): + """Clean up an order book feed.""" + # Order books are typically managed by connectors, so we might not need explicit cleanup + # This is a placeholder for future implementation if needed + pass + + def _generate_feed_key(self, feed_type: FeedType, connector: str, trading_pair: str, interval: str = None) -> str: + """Generate a unique key for a market data feed.""" + if interval: + return f"{feed_type.value}_{connector}_{trading_pair}_{interval}" + else: + return f"{feed_type.value}_{connector}_{trading_pair}" + + def get_active_feeds_info(self) -> Dict[str, dict]: + """ + Get information about currently active feeds. + + Returns: + Dictionary with feed information including last access times and feed types + """ + current_time = time.time() + result = {} + + for feed_key, last_access in self.last_access_times.items(): + feed_type, config = self.feed_configs.get(feed_key, (None, None)) + result[feed_key] = { + "feed_type": feed_type.value if feed_type else "unknown", + "last_access_time": last_access, + "seconds_since_access": current_time - last_access, + "will_expire_in": max(0, self.feed_timeout - (current_time - last_access)), + "config": str(config) # String representation of config + } + + return result + + def manually_cleanup_feed(self, feed_type: FeedType, connector: str, trading_pair: str, interval: str = None): + """ + Manually cleanup a specific feed. 
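+        Useful when a caller already knows a feed is no longer needed and does
+        not want to wait for the timeout-based background cleanup loop.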
+ + Args: + feed_type: Type of feed to cleanup + connector: Connector name + trading_pair: Trading pair + interval: Interval (for candles only) + """ + feed_key = self._generate_feed_key(feed_type, connector, trading_pair, interval) + + if feed_key in self.feed_configs: + feed_type_obj, config = self.feed_configs[feed_key] + cleanup_func = self._cleanup_functions.get(feed_type_obj) + + if cleanup_func: + try: + cleanup_func(config) + del self.last_access_times[feed_key] + del self.feed_configs[feed_key] + self.logger.info(f"Manually cleaned up feed: {feed_key}") + except Exception as e: + self.logger.error(f"Error manually cleaning up feed {feed_key}: {e}", exc_info=True) + else: + self.logger.warning(f"No cleanup function for feed type: {feed_type}") + else: + self.logger.warning(f"Feed not found for cleanup: {feed_key}") \ No newline at end of file diff --git a/services/orders_recorder.py b/services/orders_recorder.py new file mode 100644 index 00000000..caa6d759 --- /dev/null +++ b/services/orders_recorder.py @@ -0,0 +1,376 @@ +import asyncio +import logging +import math +import time + +from typing import Any, Optional, Union +from datetime import datetime +from decimal import Decimal, InvalidOperation + +from hummingbot.core.event.event_forwarder import SourceInfoEventForwarder +from hummingbot.core.event.events import ( + TradeType, + BuyOrderCreatedEvent, + SellOrderCreatedEvent, + OrderFilledEvent, + MarketEvent +) +from hummingbot.connector.connector_base import ConnectorBase +from database import AsyncDatabaseManager, OrderRepository, TradeRepository + +# Initialize logger +logger = logging.getLogger(__name__) + + +class OrdersRecorder: + """ + Custom orders recorder that mimics Hummingbot's MarketsRecorder functionality + but uses our AsyncDatabaseManager for storage. 
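+
+    Illustrative lifecycle (a sketch; the account and connector names are
+    placeholders, and ``db_manager``/``connector`` are assumed to be an
+    already-configured AsyncDatabaseManager and a running connector)::
+
+        recorder = OrdersRecorder(db_manager, "master_account", "binance")
+        recorder.start(connector)   # subscribes the event forwarders
+        ...                         # orders flow and get recorded
+        await recorder.stop()       # detaches the event forwarders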
+ """ + + def __init__(self, db_manager: AsyncDatabaseManager, account_name: str, connector_name: str): + self.db_manager = db_manager + self.account_name = account_name + self.connector_name = connector_name + self._connector: Optional[ConnectorBase] = None + + # Create event forwarders similar to MarketsRecorder + self._create_order_forwarder = SourceInfoEventForwarder(self._did_create_order) + self._fill_order_forwarder = SourceInfoEventForwarder(self._did_fill_order) + self._cancel_order_forwarder = SourceInfoEventForwarder(self._did_cancel_order) + self._fail_order_forwarder = SourceInfoEventForwarder(self._did_fail_order) + self._complete_order_forwarder = SourceInfoEventForwarder(self._did_complete_order) + + # Event pairs mapping events to forwarders + self._event_pairs = [ + (MarketEvent.BuyOrderCreated, self._create_order_forwarder), + (MarketEvent.SellOrderCreated, self._create_order_forwarder), + (MarketEvent.OrderFilled, self._fill_order_forwarder), + (MarketEvent.OrderCancelled, self._cancel_order_forwarder), + (MarketEvent.OrderFailure, self._fail_order_forwarder), + (MarketEvent.BuyOrderCompleted, self._complete_order_forwarder), + (MarketEvent.SellOrderCompleted, self._complete_order_forwarder), + ] + + def start(self, connector: ConnectorBase): + """Start recording orders for the given connector""" + self._connector = connector + + # Subscribe to order events using the same pattern as MarketsRecorder + for event, forwarder in self._event_pairs: + connector.add_listener(event, forwarder) + logger.info(f"OrdersRecorder: Added listener for {event} with forwarder {forwarder}") + + # Debug: Check if listeners were actually added + if hasattr(connector, '_event_listeners'): + listeners = connector._event_listeners.get(event, []) + logger.info(f"OrdersRecorder: Event {event} now has {len(listeners)} listeners") + for i, listener in enumerate(listeners): + logger.info(f"OrdersRecorder: Listener {i}: {listener}") + + logger.info(f"OrdersRecorder started for {self.account_name}/{self.connector_name} with {len(self._event_pairs)} event listeners") + + # Debug: Print connector info + logger.info(f"OrdersRecorder: Connector type: {type(connector)}") + logger.info(f"OrdersRecorder: Connector name: {getattr(connector, 'name', 'unknown')}") + logger.info(f"OrdersRecorder: Connector ready: {getattr(connector, 'ready', 'unknown')}") + + # Test if forwarders are callable + for event, forwarder in self._event_pairs: + if callable(forwarder): + logger.info(f"OrdersRecorder: Forwarder for {event} is callable") + else: + logger.error(f"OrdersRecorder: Forwarder for {event} is NOT callable: {type(forwarder)}") + + async def stop(self): + """Stop recording orders""" + if self._connector: + # Remove all event listeners + for event, forwarder in self._event_pairs: + self._connector.remove_listener(event, forwarder) + + logger.info(f"OrdersRecorder stopped for {self.account_name}/{self.connector_name}") + + def _extract_error_message(self, event) -> str: + """Extract error message from various possible event attributes.""" + # Try different possible attribute names for error messages + for attr_name in ['error_message', 'message', 'reason', 'failure_reason', 'error']: + if hasattr(event, attr_name): + error_value = getattr(event, attr_name) + if error_value: + return str(error_value) + + # If no error message found, create a descriptive one + return f"Order failed: {event.__class__.__name__}" + + def _did_create_order(self, event_tag: int, market: ConnectorBase, event: Union[BuyOrderCreatedEvent, 
SellOrderCreatedEvent]):
+        """Handle order creation events - called by SourceInfoEventForwarder"""
+        logger.info(f"OrdersRecorder: _did_create_order called for order {getattr(event, 'order_id', 'unknown')}")
+        try:
+            # Determine trade type from event
+            trade_type = TradeType.BUY if isinstance(event, BuyOrderCreatedEvent) else TradeType.SELL
+            logger.info(f"OrdersRecorder: Creating task to handle order created - {trade_type} order")
+            asyncio.create_task(self._handle_order_created(event, trade_type))
+        except Exception as e:
+            logger.error(f"Error in _did_create_order: {e}")
+
+    def _did_fill_order(self, event_tag: int, market: ConnectorBase, event: OrderFilledEvent):
+        """Handle order fill events - called by SourceInfoEventForwarder"""
+        try:
+            asyncio.create_task(self._handle_order_filled(event))
+        except Exception as e:
+            logger.error(f"Error in _did_fill_order: {e}")
+
+    def _did_cancel_order(self, event_tag: int, market: ConnectorBase, event: Any):
+        """Handle order cancel events - called by SourceInfoEventForwarder"""
+        try:
+            asyncio.create_task(self._handle_order_cancelled(event))
+        except Exception as e:
+            logger.error(f"Error in _did_cancel_order: {e}")
+
+    def _did_fail_order(self, event_tag: int, market: ConnectorBase, event: Any):
+        """Handle order failure events - called by SourceInfoEventForwarder"""
+        try:
+            asyncio.create_task(self._handle_order_failed(event))
+        except Exception as e:
+            logger.error(f"Error in _did_fail_order: {e}")
+
+    def _did_complete_order(self, event_tag: int, market: ConnectorBase, event: Any):
+        """Handle order completion events - called by SourceInfoEventForwarder"""
+        try:
+            asyncio.create_task(self._handle_order_completed(event))
+        except Exception as e:
+            logger.error(f"Error in _did_complete_order: {e}")
+
+    async def _handle_order_created(self, event: Union[BuyOrderCreatedEvent, SellOrderCreatedEvent], trade_type: TradeType):
+        """Handle order creation events"""
+        logger.info(f"OrdersRecorder: _handle_order_created started for order {event.order_id}")
+        try:
+            async with self.db_manager.get_session_context() as session:
+                order_repo = OrderRepository(session)
+
+                # Check if order already exists first
+                existing_order = await order_repo.get_order_by_client_id(event.order_id)
+                if existing_order:
+                    logger.info(f"OrdersRecorder: Order {event.order_id} already exists with status {existing_order.status}")
+
+                    # Update exchange_order_id if we have it now and it was missing
+                    exchange_order_id = getattr(event, 'exchange_order_id', None)
+                    if exchange_order_id and not existing_order.exchange_order_id:
+                        existing_order.exchange_order_id = exchange_order_id
+                        logger.info(f"OrdersRecorder: Updated exchange_order_id to {exchange_order_id} for order {event.order_id}")
+
+                    # Update status if it's still in PENDING_CREATE or similar early state
+                    if existing_order.status in ["PENDING_CREATE", "PENDING"]:
+                        logger.info(f"OrdersRecorder: Updating status from {existing_order.status} to SUBMITTED for order {event.order_id}")
+                        existing_order.status = "SUBMITTED"
+
+                    await session.flush()
+                    return
+
+                order_data = {
+                    "client_order_id": event.order_id,
+                    "account_name": self.account_name,
+                    "connector_name": self.connector_name,
+                    "trading_pair": event.trading_pair,
+                    "trade_type": trade_type.name,
+                    "order_type": event.type.name if hasattr(event, 'type') else 'UNKNOWN',
+                    "amount": float(event.amount),
+                    "price": float(event.price) if event.price else None,
+                    "status": "SUBMITTED",
+                    "exchange_order_id": getattr(event, 'exchange_order_id', None)
+                }
+                await
order_repo.create_order(order_data) + + logger.info(f"OrdersRecorder: Successfully recorded order created: {event.order_id}") + except Exception as e: + logger.error(f"OrdersRecorder: Error recording order created: {e}") + + async def _handle_order_filled(self, event: OrderFilledEvent): + """Handle order fill events""" + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + trade_repo = TradeRepository(session) + + # Calculate fees + trade_fee_paid = 0 + trade_fee_currency = None + + if event.trade_fee: + try: + base_asset, quote_asset = event.trading_pair.split("-") + fee_in_quote = event.trade_fee.fee_amount_in_token( + trading_pair=event.trading_pair, + price=event.price, + order_amount=event.amount, + token=quote_asset, + exchange=self._connector, + ) + trade_fee_paid = float(fee_in_quote) + trade_fee_currency = quote_asset + except Exception as e: + logger.error(f"Error calculating trade fee: {e}") + trade_fee_paid = 0 + trade_fee_currency = None + + # Update order with fill information (handle potential NaN values like Hummingbot does) + try: + filled_amount = Decimal(str(event.amount)) + average_fill_price = Decimal(str(event.price)) + fee_paid_decimal = Decimal(str(trade_fee_paid)) if trade_fee_paid else None + + order = await order_repo.update_order_fill( + client_order_id=event.order_id, + filled_amount=filled_amount, + average_fill_price=average_fill_price, + fee_paid=fee_paid_decimal, + fee_currency=trade_fee_currency + ) + except (ValueError, InvalidOperation) as e: + logger.error(f"Error processing order fill for {event.order_id}: {e}, skipping update") + return + + # Create trade record using validated values + if order: + try: + # Validate all values before creating trade record + validated_timestamp = event.timestamp if event.timestamp and not math.isnan(event.timestamp) else time.time() + validated_fee = trade_fee_paid if trade_fee_paid and not math.isnan(trade_fee_paid) else 0 + + trade_data = { + "order_id": order.id, + "trade_id": f"{event.order_id}_{validated_timestamp}", + "timestamp": datetime.fromtimestamp(validated_timestamp), + "trading_pair": event.trading_pair, + "trade_type": event.trade_type.name, + "amount": float(filled_amount), # Use validated amount + "price": float(average_fill_price), # Use validated price + "fee_paid": validated_fee, + "fee_currency": trade_fee_currency + } + await trade_repo.create_trade(trade_data) + except (ValueError, TypeError) as e: + logger.error(f"Error creating trade record for {event.order_id}: {e}") + logger.error(f"Trade data that failed: timestamp={event.timestamp}, amount={event.amount}, price={event.price}, fee={trade_fee_paid}") + + logger.debug(f"Recorded order fill: {event.order_id} - {event.amount} @ {event.price}") + except Exception as e: + logger.error(f"Error recording order fill: {e}") + + async def _handle_order_cancelled(self, event: Any): + """Handle order cancellation events""" + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + await order_repo.update_order_status( + client_order_id=event.order_id, + status="CANCELLED" + ) + + logger.debug(f"Recorded order cancelled: {event.order_id}") + except Exception as e: + logger.error(f"Error recording order cancellation: {e}") + + def _get_order_details_from_connector(self, order_id: str) -> Optional[dict]: + """Try to get order details from connector's tracked orders""" + try: + if self._connector and hasattr(self._connector, 'in_flight_orders'): + 
in_flight_order = self._connector.in_flight_orders.get(order_id) + if in_flight_order: + return { + "trading_pair": in_flight_order.trading_pair, + "trade_type": in_flight_order.trade_type.name, + "order_type": in_flight_order.order_type.name, + "amount": float(in_flight_order.amount), + "price": float(in_flight_order.price) if in_flight_order.price else None + } + except Exception as e: + logger.error(f"Error getting order details from connector: {e}") + return None + + async def _handle_order_failed(self, event: Any): + """Handle order failure events""" + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + + # Check if order exists, if not try to get details from connector's tracked orders + existing_order = await order_repo.get_order_by_client_id(event.order_id) + if existing_order: + # Extract error message from various possible attributes + error_msg = self._extract_error_message(event) + + # Update existing order with failure status and error message + await order_repo.update_order_status( + client_order_id=event.order_id, + status="FAILED", + error_message=error_msg + ) + logger.info(f"Updated existing order {event.order_id} to FAILED status") + else: + # Try to get order details from connector's tracked orders + order_details = self._get_order_details_from_connector(event.order_id) + if order_details: + logger.info(f"Retrieved order details from connector for {event.order_id}: {order_details}") + + # Create order record as FAILED with available details + if order_details: + order_data = { + "client_order_id": event.order_id, + "account_name": self.account_name, + "connector_name": self.connector_name, + "trading_pair": order_details["trading_pair"], + "trade_type": order_details["trade_type"], + "order_type": order_details["order_type"], + "amount": order_details["amount"], + "price": order_details["price"], + "status": "FAILED", + "error_message": self._extract_error_message(event) + } + else: + # Fallback with minimal details + order_data = { + "client_order_id": event.order_id, + "account_name": self.account_name, + "connector_name": self.connector_name, + "trading_pair": "UNKNOWN", + "trade_type": "UNKNOWN", + "order_type": "UNKNOWN", + "amount": 0.0, + "price": None, + "status": "FAILED", + "error_message": self._extract_error_message(event) + } + + try: + await order_repo.create_order(order_data) + logger.info(f"Created failed order record for {event.order_id}") + except Exception as create_error: + # If creation fails due to duplicate key, try to update existing order + if "duplicate key" in str(create_error).lower() or "unique constraint" in str(create_error).lower(): + logger.info(f"Order {event.order_id} already exists, updating status to FAILED") + await order_repo.update_order_status( + client_order_id=event.order_id, + status="FAILED", + error_message=self._extract_error_message(event) + ) + else: + raise create_error + + except Exception as e: + logger.error(f"Error recording order failure: {e}") + + async def _handle_order_completed(self, event: Any): + """Handle order completion events""" + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + order = await order_repo.get_order_by_client_id(event.order_id) + if order: + order.status = "FILLED" + order.exchange_order_id = getattr(event, 'exchange_order_id', None) + + logger.debug(f"Recorded order completed: {event.order_id}") + except Exception as e: + logger.error(f"Error recording order completion: 
{e}") \ No newline at end of file diff --git a/set_environment.sh b/set_environment.sh deleted file mode 100644 index 18501980..00000000 --- a/set_environment.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -# Create or overwrite .env file -echo "Setting up .env file for the project... -By default, the current working directory will be used as the BOTS_PATH and the CONFIG_PASSWORD will be set to 'a'." - -# Asking for CONFIG_PASSWORD and BOTS_PATH -CONFIG_PASSWORD=a -USERNAME=admin -PASSWORD=admin -BOTS_PATH=$(pwd) - -# Write to .env file -echo "CONFIG_PASSWORD=$CONFIG_PASSWORD" > .env -echo "BOTS_PATH=$BOTS_PATH" >> .env -echo "USERNAME=$USERNAME" >> .env -echo "PASSWORD=$PASSWORD" >> .env diff --git a/setup.sh b/setup.sh new file mode 100755 index 00000000..6b491e3c --- /dev/null +++ b/setup.sh @@ -0,0 +1,158 @@ +#!/bin/bash + +# Backend API Setup Script +# This script creates a comprehensive .env file with all configuration options +# following the Pydantic Settings structure established in config.py + +set -e # Exit on any error + +# Colors for better output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +PURPLE='\033[0;35m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +echo "🚀 Backend API Setup" +echo "" + +echo -n "Config password [default: admin]: " +read CONFIG_PASSWORD +CONFIG_PASSWORD=${CONFIG_PASSWORD:-admin} + +echo -n "API username [default: admin]: " +read USERNAME +USERNAME=${USERNAME:-admin} + +echo -n "API password [default: admin]: " +read PASSWORD +PASSWORD=${PASSWORD:-admin} + +# Set paths and defaults +BOTS_PATH=$(pwd) + +# Use sensible defaults for everything else +DEBUG_MODE="false" +BROKER_HOST="localhost" +BROKER_PORT="1883" +BROKER_USERNAME="admin" +BROKER_PASSWORD="password" +DATABASE_URL="postgresql+asyncpg://hbot:hummingbot-api@localhost:5432/hummingbot_api" +CLEANUP_INTERVAL="300" +FEED_TIMEOUT="600" +AWS_API_KEY="" +AWS_SECRET_KEY="" +S3_BUCKET="" +LOGFIRE_ENV="dev" +BANNED_TOKENS='["NAV","ARS","ETHW","ETHF","NEWT"]' + +echo "" +echo -e "${GREEN}✅ Using sensible defaults for MQTT, Database, and other settings${NC}" + +echo "" +echo -e "${GREEN}📝 Creating .env file...${NC}" + +# Create .env file with proper structure and comments +cat > .env << EOF +# ================================================================= +# Backend API Environment Configuration +# Generated on: $(date) +# ================================================================= + +# ================================================================= +# 🔐 Security Configuration +# ================================================================= +USERNAME=$USERNAME +PASSWORD=$PASSWORD +DEBUG_MODE=$DEBUG_MODE +CONFIG_PASSWORD=$CONFIG_PASSWORD + +# ================================================================= +# 🔗 MQTT Broker Configuration (BROKER_*) +# ================================================================= +BROKER_HOST=$BROKER_HOST +BROKER_PORT=$BROKER_PORT +BROKER_USERNAME=$BROKER_USERNAME +BROKER_PASSWORD=$BROKER_PASSWORD + +# ================================================================= +# 💾 Database Configuration (DATABASE_*) +# ================================================================= +DATABASE_URL=$DATABASE_URL + +# ================================================================= +# 📊 Market Data Feed Manager Configuration (MARKET_DATA_*) +# ================================================================= +MARKET_DATA_CLEANUP_INTERVAL=$CLEANUP_INTERVAL +MARKET_DATA_FEED_TIMEOUT=$FEED_TIMEOUT + +# 
================================================================= +# ☁️ AWS Configuration (AWS_*) - Optional +# ================================================================= +AWS_API_KEY=$AWS_API_KEY +AWS_SECRET_KEY=$AWS_SECRET_KEY +AWS_S3_DEFAULT_BUCKET_NAME=$S3_BUCKET + +# ================================================================= +# ⚙️ Application Settings +# ================================================================= +LOGFIRE_ENVIRONMENT=$LOGFIRE_ENV +BANNED_TOKENS=$BANNED_TOKENS + +# ================================================================= +# 📁 Legacy Settings (maintained for backward compatibility) +# ================================================================= +BOTS_PATH=$BOTS_PATH + +EOF + +echo -e "${GREEN}✅ .env file created successfully!${NC}" +echo "" + +# Display configuration summary +echo -e "${BLUE}📋 Configuration Summary${NC}" +echo "=======================" +echo -e "${CYAN}Security:${NC} Username: $USERNAME, Debug: $DEBUG_MODE" +echo -e "${CYAN}Broker:${NC} $BROKER_HOST:$BROKER_PORT" +echo -e "${CYAN}Database:${NC} ${DATABASE_URL%%@*}@[hidden]" +echo -e "${CYAN}Market Data:${NC} Cleanup: ${CLEANUP_INTERVAL}s, Timeout: ${FEED_TIMEOUT}s" +echo -e "${CYAN}Environment:${NC} $LOGFIRE_ENV" + +if [ -n "$AWS_API_KEY" ]; then + echo -e "${CYAN}AWS:${NC} Configured with S3 bucket: $S3_BUCKET" +else + echo -e "${CYAN}AWS:${NC} Not configured (optional)" +fi + +echo "" +echo -e "${GREEN}🎉 Setup Complete!${NC}" +echo "" + +# Check if password verification file exists +if [ ! -f "bots/credentials/master_account/.password_verification" ]; then + echo -e "${YELLOW}📌 Note:${NC} Password verification file will be created on first startup" + echo -e " Location: ${BLUE}bots/credentials/master_account/.password_verification${NC}" + echo "" +fi + +echo -e "Next steps:" +echo "1. Review the .env file if needed: cat .env" +echo "2. Install dependencies: make install" +echo "3. 
Start the API: make run" +echo "" +echo -e "${PURPLE}💡 Pro tip:${NC} You can modify environment variables in .env file anytime" +echo -e "${PURPLE}📚 Documentation:${NC} Check config.py for all available settings" +echo -e "${PURPLE}🔒 Security:${NC} The password verification file secures bot credentials" +echo "" +echo -e "${GREEN}🐳 Starting required Docker containers and pulling Hummingbot image...${NC}" + +# Run docker operations in parallel +docker compose up emqx postgres -d & +docker pull hummingbot/hummingbot:latest & + +# Wait for both operations to complete +wait + +echo -e "${GREEN}✅ All Docker operations completed!${NC}" diff --git a/bots/data/.gitignore b/test/__init__.py similarity index 100% rename from bots/data/.gitignore rename to test/__init__.py diff --git a/services/bot_archiver.py b/utils/bot_archiver.py similarity index 100% rename from services/bot_archiver.py rename to utils/bot_archiver.py diff --git a/utils/connector_manager.py b/utils/connector_manager.py new file mode 100644 index 00000000..6ae0bad6 --- /dev/null +++ b/utils/connector_manager.py @@ -0,0 +1,528 @@ +import asyncio +import logging +import time +from decimal import Decimal +from typing import Dict, List, Optional + +# Create module-specific logger +logger = logging.getLogger(__name__) + +from hummingbot.client.config.client_config_map import ClientConfigMap +from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger +from hummingbot.client.config.config_helpers import ClientConfigAdapter, ReadOnlyClientConfigAdapter, get_connector_class +from hummingbot.client.settings import AllConnectorSettings +from hummingbot.connector.connector_base import ConnectorBase +from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, TradeType +from hummingbot.core.data_type.in_flight_order import InFlightOrder, OrderState +from hummingbot.core.utils.async_utils import safe_ensure_future + +from utils.file_system import fs_util +from utils.hummingbot_api_config_adapter import HummingbotAPIConfigAdapter +from utils.security import BackendAPISecurity + + +class ConnectorManager: + """ + Manages the creation and caching of exchange connectors. + Handles connector configuration and initialization. + This is the single source of truth for all connector instances. + """ + + def __init__(self, secrets_manager: ETHKeyFileSecretManger, db_manager=None): + self.secrets_manager = secrets_manager + self.db_manager = db_manager + self._connector_cache: Dict[str, ConnectorBase] = {} + self._orders_recorders: Dict[str, any] = {} + self._funding_recorders: Dict[str, any] = {} + self._status_polling_tasks: Dict[str, asyncio.Task] = {} + + async def get_connector(self, account_name: str, connector_name: str): + """ + Get the connector object for the specified account and connector. + Uses caching to avoid recreating connectors unnecessarily. + Ensures proper initialization including position mode setup. + + :param account_name: The name of the account. + :param connector_name: The name of the connector. + :return: The connector object. + """ + cache_key = f"{account_name}:{connector_name}" + + if cache_key in self._connector_cache: + return self._connector_cache[cache_key] + + # Create connector with full initialization + connector = await self._create_and_initialize_connector(account_name, connector_name) + return connector + + def _create_connector(self, account_name: str, connector_name: str): + """ + Create a new connector instance. + + :param account_name: The name of the account. 
+ :param connector_name: The name of the connector. + :return: The connector object. + """ + BackendAPISecurity.login_account(account_name=account_name, secrets_manager=self.secrets_manager) + client_config_map = ClientConfigAdapter(ClientConfigMap()) + conn_setting = AllConnectorSettings.get_connector_settings()[connector_name] + keys = BackendAPISecurity.api_keys(connector_name) + + # Debug logging + logger.info(f"Creating connector {connector_name} for account {account_name}") + logger.debug(f"API keys retrieved: {list(keys.keys()) if keys else 'None'}") + + read_only_config = ReadOnlyClientConfigAdapter.lock_config(client_config_map) + + init_params = conn_setting.conn_init_parameters( + trading_pairs=[], + trading_required=True, + api_keys=keys, + client_config_map=read_only_config, + ) + + # Debug logging + logger.debug(f"Init params keys: {list(init_params.keys())}") + + connector_class = get_connector_class(connector_name) + connector = connector_class(**init_params) + return connector + + def clear_cache(self, account_name: Optional[str] = None, connector_name: Optional[str] = None): + """ + Clear the connector cache. + + :param account_name: If provided, only clear cache for this account. + :param connector_name: If provided with account_name, only clear this specific connector. + """ + if account_name and connector_name: + cache_key = f"{account_name}:{connector_name}" + self._connector_cache.pop(cache_key, None) + elif account_name: + # Clear all connectors for this account + keys_to_remove = [k for k in self._connector_cache.keys() if k.startswith(f"{account_name}:")] + for key in keys_to_remove: + self._connector_cache.pop(key) + else: + # Clear entire cache + self._connector_cache.clear() + + @staticmethod + def get_connector_config_map(connector_name: str): + """ + Get the connector config map for the specified connector. + + :param connector_name: The name of the connector. + :return: The connector config map. + """ + connector_config = HummingbotAPIConfigAdapter(AllConnectorSettings.get_connector_config_keys(connector_name)) + return [key for key in connector_config.hb_config.__fields__.keys() if key != "connector"] + + async def update_connector_keys(self, account_name: str, connector_name: str, keys: dict): + """ + Update the API keys for a connector and refresh the connector instance. + + :param account_name: The name of the account. + :param connector_name: The name of the connector. + :param keys: Dictionary of API keys to update. + :return: The updated connector instance. + """ + BackendAPISecurity.login_account(account_name=account_name, secrets_manager=self.secrets_manager) + connector_config = HummingbotAPIConfigAdapter(AllConnectorSettings.get_connector_config_keys(connector_name)) + + for key, value in keys.items(): + setattr(connector_config, key, value) + + BackendAPISecurity.update_connector_keys(account_name, connector_config) + + # Re-decrypt all credentials to ensure the new keys are available + BackendAPISecurity.decrypt_all(account_name=account_name) + + # Clear the cache for this connector to force recreation with new keys + self.clear_cache(account_name, connector_name) + + # Create and return new connector instance + new_connector = await self.get_connector(account_name, connector_name) + + return new_connector + + def list_account_connectors(self, account_name: str) -> List[str]: + """ + List all initialized connectors for a specific account. + + :param account_name: The name of the account. + :return: List of connector names. 
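+
+        Example (illustrative; assumes an initialized ConnectorManager named
+        ``manager`` and that these connectors were previously created)::
+
+            >>> manager.list_account_connectors("master_account")
+            ['binance', 'binance_perpetual']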
+ """ + connectors = [] + for cache_key in self._connector_cache.keys(): + acc_name, conn_name = cache_key.split(":", 1) + if acc_name == account_name: + connectors.append(conn_name) + return connectors + + def get_all_connectors(self) -> Dict[str, Dict[str, ConnectorBase]]: + """ + Get all connectors organized by account. + + :return: Dictionary mapping account names to their connectors. + """ + result = {} + for cache_key, connector in self._connector_cache.items(): + account_name, connector_name = cache_key.split(":", 1) + if account_name not in result: + result[account_name] = {} + result[account_name][connector_name] = connector + return result + + def is_connector_initialized(self, account_name: str, connector_name: str) -> bool: + """ + Check if a connector is already initialized and cached. + + :param account_name: The name of the account. + :param connector_name: The name of the connector. + :return: True if the connector is initialized, False otherwise. + """ + cache_key = f"{account_name}:{connector_name}" + return cache_key in self._connector_cache + + async def _create_and_initialize_connector(self, account_name: str, connector_name: str) -> ConnectorBase: + """ + Create and fully initialize a connector with all necessary setup. + This includes creating the connector, starting its network, setting up order recording, + and configuring position mode for perpetual connectors. + + :param account_name: The name of the account. + :param connector_name: The name of the connector. + :return: The initialized connector instance. + """ + cache_key = f"{account_name}:{connector_name}" + # Create the base connector + connector = self._create_connector(account_name, connector_name) + + # Initialize symbol map + await connector._initialize_trading_pair_symbol_map() + + # Update trading rules + await connector._update_trading_rules() + + # Update initial balances + await connector._update_balances() + + # Set default position mode to HEDGE for perpetual connectors + if "_perpetual" in connector_name: + if PositionMode.HEDGE in connector.supported_position_modes(): + connector.set_position_mode(PositionMode.HEDGE) + await connector._update_positions() + + self._connector_cache[cache_key] = connector + + # Load existing orders from database before starting network + if self.db_manager: + await self._load_existing_orders_from_database(connector, account_name, connector_name) + + # Start order tracking if db_manager is available + if self.db_manager: + if cache_key not in self._orders_recorders: + # Import OrdersRecorder dynamically to avoid circular imports + from services.orders_recorder import OrdersRecorder + + # Create and start orders recorder + orders_recorder = OrdersRecorder(self.db_manager, account_name, connector_name) + orders_recorder.start(connector) + self._orders_recorders[cache_key] = orders_recorder + + # Start funding tracking for perpetual connectors + if "_perpetual" in connector_name and cache_key not in self._funding_recorders: + # Import FundingRecorder dynamically to avoid circular imports + from services.funding_recorder import FundingRecorder + + # Create and start funding recorder + funding_recorder = FundingRecorder(self.db_manager, account_name, connector_name) + funding_recorder.start(connector) + self._funding_recorders[cache_key] = funding_recorder + + # Start network manually without clock system + await self._start_connector_network(connector) + + # Perform initial update of connector state + await self._update_connector_state(connector, connector_name) + + 
logger.info(f"Initialized connector {connector_name} for account {account_name}") + return connector + + async def _start_connector_network(self, connector: ConnectorBase): + """ + Start connector network tasks manually without clock system. + Based on the original start_network method but without order book tracker. + """ + try: + # Stop any existing network tasks + await self._stop_connector_network(connector) + + # Start trading rules polling + connector._trading_rules_polling_task = safe_ensure_future(connector._trading_rules_polling_loop()) + + # Start trading fees polling + connector._trading_fees_polling_task = safe_ensure_future(connector._trading_fees_polling_loop()) + + # Start user stream tracker (websocket connection) + connector._user_stream_tracker_task = connector._create_user_stream_tracker_task() + + # Start user stream event listener + connector._user_stream_event_listener_task = safe_ensure_future(connector._user_stream_event_listener()) + + # Start lost orders update task + connector._lost_orders_update_task = safe_ensure_future(connector._lost_orders_update_polling_loop()) + + logger.info(f"Started connector network tasks for {connector}") + + except Exception as e: + logger.error(f"Error starting connector network: {e}") + raise + + async def _stop_connector_network(self, connector: ConnectorBase): + """ + Stop connector network tasks. + """ + try: + # Stop trading rules polling + if connector._trading_rules_polling_task: + connector._trading_rules_polling_task.cancel() + connector._trading_rules_polling_task = None + + # Stop trading fees polling + if connector._trading_fees_polling_task: + connector._trading_fees_polling_task.cancel() + connector._trading_fees_polling_task = None + + # Stop status polling + if connector._status_polling_task: + connector._status_polling_task.cancel() + connector._status_polling_task = None + + # Stop user stream tracker + if connector._user_stream_tracker_task: + connector._user_stream_tracker_task.cancel() + connector._user_stream_tracker_task = None + + # Stop user stream event listener + if connector._user_stream_event_listener_task: + connector._user_stream_event_listener_task.cancel() + connector._user_stream_event_listener_task = None + + # Stop lost orders update task + if connector._lost_orders_update_task: + connector._lost_orders_update_task.cancel() + connector._lost_orders_update_task = None + + except Exception as e: + logger.error(f"Error stopping connector network: {e}") + + async def _update_connector_state(self, connector: ConnectorBase, connector_name: str): + """ + Update connector state including balances, orders, positions, and trading rules. + This function can be called both during initialization and periodically. + """ + try: + # Update balances + await connector._update_balances() + + # Update trading rules + await connector._update_trading_rules() + + # Update positions for perpetual connectors + if "_perpetual" in connector_name: + await connector._update_positions() + + # Update order status for in-flight orders + if hasattr(connector, '_update_order_status') and connector.in_flight_orders: + await connector._update_order_status() + + logger.debug(f"Updated connector state for {connector_name}") + + except Exception as e: + logger.error(f"Error updating connector state for {connector_name}: {e}") + + async def update_all_connector_states(self): + """ + Update state for all cached connectors. + This can be called periodically to refresh connector data. 
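+
+        A sketch of periodic scheduling (assumes an initialized manager and a
+        running event loop)::
+
+            async def refresh_loop(manager: "ConnectorManager", interval: float = 60.0):
+                while True:
+                    await manager.update_all_connector_states()
+                    await asyncio.sleep(interval)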
+ """ + for cache_key, connector in self._connector_cache.items(): + account_name, connector_name = cache_key.split(":", 1) + try: + await self._update_connector_state(connector, connector_name) + except Exception as e: + logger.error(f"Error updating state for {account_name}/{connector_name}: {e}") + + async def _load_existing_orders_from_database(self, connector: ConnectorBase, account_name: str, connector_name: str): + """ + Load existing active orders from database and add them to connector's in_flight_orders. + This ensures that orders placed before an API restart can still be managed. + + :param connector: The connector instance to load orders into + :param account_name: The name of the account + :param connector_name: The name of the connector + """ + try: + # Import OrderRepository dynamically to avoid circular imports + from database import OrderRepository + + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + + # Get active orders from database for this account/connector + active_orders = await order_repo.get_active_orders(account_name=account_name, connector_name=connector_name) + + logger.info(f"Loading {len(active_orders)} existing active orders for {account_name}/{connector_name}") + + for order_record in active_orders: + try: + # Convert database order to InFlightOrder + in_flight_order = self._convert_db_order_to_in_flight_order(order_record) + + # Add to connector's in_flight_orders + connector.in_flight_orders[in_flight_order.client_order_id] = in_flight_order + + logger.debug(f"Loaded order {in_flight_order.client_order_id} from database into connector") + + except Exception as e: + logger.error(f"Error converting database order {order_record.client_order_id} to InFlightOrder: {e}") + continue + + logger.info( + f"Successfully loaded {len(connector.in_flight_orders)} in-flight orders for {account_name}/{connector_name}" + ) + + except Exception as e: + logger.error(f"Error loading existing orders from database for {account_name}/{connector_name}: {e}") + + def _convert_db_order_to_in_flight_order(self, order_record) -> InFlightOrder: + """ + Convert a database Order record to a Hummingbot InFlightOrder object. 
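+        Database status strings are mapped onto Hummingbot's OrderState enum;
+        unknown order or trade types fall back to LIMIT/BUY with a warning.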
+ + :param order_record: Database Order model instance + :return: InFlightOrder instance + """ + # Map database status to OrderState + status_mapping = { + "SUBMITTED": OrderState.PENDING_CREATE, + "OPEN": OrderState.OPEN, + "PARTIALLY_FILLED": OrderState.PARTIALLY_FILLED, + "FILLED": OrderState.FILLED, + "CANCELLED": OrderState.CANCELED, + "FAILED": OrderState.FAILED, + } + + # Get the appropriate OrderState + order_state = status_mapping.get(order_record.status, OrderState.PENDING_CREATE) + + # Convert string enums to proper enum instances + try: + order_type = OrderType[order_record.order_type] + except (KeyError, ValueError): + logger.warning(f"Unknown order type '{order_record.order_type}', defaulting to LIMIT") + order_type = OrderType.LIMIT + + try: + trade_type = TradeType[order_record.trade_type] + except (KeyError, ValueError): + logger.warning(f"Unknown trade type '{order_record.trade_type}', defaulting to BUY") + trade_type = TradeType.BUY + + # Convert creation timestamp - use order creation time or current time as fallback + creation_timestamp = order_record.created_at.timestamp() if order_record.created_at else time.time() + + # Create InFlightOrder instance + in_flight_order = InFlightOrder( + client_order_id=order_record.client_order_id, + trading_pair=order_record.trading_pair, + order_type=order_type, + trade_type=trade_type, + amount=Decimal(str(order_record.amount)), + creation_timestamp=creation_timestamp, + price=Decimal(str(order_record.price)) if order_record.price else None, + exchange_order_id=order_record.exchange_order_id, + initial_state=order_state, + leverage=1, # Default leverage + position=PositionAction.NIL, # Default position action + ) + + # Update current state and filled amount if order has progressed + in_flight_order.current_state = order_state + if order_record.filled_amount: + in_flight_order.executed_amount_base = Decimal(str(order_record.filled_amount)) + if order_record.average_fill_price: + in_flight_order.last_executed_quantity = Decimal(str(order_record.filled_amount or 0)) + in_flight_order.last_executed_price = Decimal(str(order_record.average_fill_price)) + + return in_flight_order + + async def stop_connector(self, account_name: str, connector_name: str): + """ + Stop a connector and its associated services. + + :param account_name: The name of the account. + :param connector_name: The name of the connector. 
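+
+        Example (illustrative; the account and connector names are
+        placeholders)::
+
+            await manager.stop_connector("master_account", "binance_perpetual")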
+ """ + cache_key = f"{account_name}:{connector_name}" + + # Stop order recorder if exists + if cache_key in self._orders_recorders: + try: + await self._orders_recorders[cache_key].stop() + del self._orders_recorders[cache_key] + logger.info(f"Stopped order recorder for {account_name}/{connector_name}") + except Exception as e: + logger.error(f"Error stopping order recorder for {account_name}/{connector_name}: {e}") + + # Stop funding recorder if exists + if cache_key in self._funding_recorders: + try: + await self._funding_recorders[cache_key].stop() + del self._funding_recorders[cache_key] + logger.info(f"Stopped funding recorder for {account_name}/{connector_name}") + except Exception as e: + logger.error(f"Error stopping funding recorder for {account_name}/{connector_name}: {e}") + + # Stop manual status polling task if exists + if cache_key in self._status_polling_tasks: + try: + self._status_polling_tasks[cache_key].cancel() + del self._status_polling_tasks[cache_key] + logger.info(f"Stopped manual status polling for {account_name}/{connector_name}") + except Exception as e: + logger.error(f"Error stopping manual status polling for {account_name}/{connector_name}: {e}") + + # Stop connector network if exists + if cache_key in self._connector_cache: + try: + connector = self._connector_cache[cache_key] + await self._stop_connector_network(connector) + logger.info(f"Stopped connector network for {account_name}/{connector_name}") + except Exception as e: + logger.error(f"Error stopping connector network for {account_name}/{connector_name}: {e}") + + async def stop_all_connectors(self): + """ + Stop all connectors and their associated services. + """ + # Get all account/connector pairs + pairs = [(k.split(":", 1)[0], k.split(":", 1)[1]) for k in self._connector_cache.keys()] + + # Stop each connector + for account_name, connector_name in pairs: + await self.stop_connector(account_name, connector_name) + + def list_available_credentials(self, account_name: str) -> List[str]: + """ + List all available connector credentials for an account. + + :param account_name: The name of the account. + :return: List of connector names that have credentials. 
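+
+        Example (illustrative; returned names depend on which ``.yml``
+        credential files exist under ``credentials/<account>/connectors``)::
+
+            >>> manager.list_available_credentials("master_account")
+            ['binance', 'kucoin']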
+ """ + try: + files = fs_util.list_files(f"credentials/{account_name}/connectors") + return [file.replace(".yml", "") for file in files if file.endswith(".yml")] + except FileNotFoundError: + return [] diff --git a/utils/etl_databases.py b/utils/etl_databases.py deleted file mode 100644 index 85efe9ee..00000000 --- a/utils/etl_databases.py +++ /dev/null @@ -1,391 +0,0 @@ -import os -import pandas as pd -import json -from typing import List, Dict, Any - -from hummingbot.core.data_type.common import TradeType -from hummingbot.strategy_v2.models.base import RunnableStatus -from hummingbot.strategy_v2.models.executors import CloseType -from hummingbot.strategy_v2.models.executors_info import ExecutorInfo -from sqlalchemy import create_engine, insert, text, MetaData, Table, Column, VARCHAR, INT, FLOAT, Integer, String, Float -from sqlalchemy.orm import sessionmaker - - -class HummingbotDatabase: - def __init__(self, db_path: str): - self.db_name = os.path.basename(db_path) - self.db_path = db_path - self.db_path = f'sqlite:///{os.path.join(db_path)}' - self.engine = create_engine(self.db_path, connect_args={'check_same_thread': False}) - self.session_maker = sessionmaker(bind=self.engine) - - @staticmethod - def _get_table_status(table_loader): - try: - data = table_loader() - return "Correct" if len(data) > 0 else f"Error - No records matched" - except Exception as e: - return f"Error - {str(e)}" - - @property - def status(self): - trade_fill_status = self._get_table_status(self.get_trade_fills) - orders_status = self._get_table_status(self.get_orders) - order_status_status = self._get_table_status(self.get_order_status) - executors_status = self._get_table_status(self.get_executors_data) - controller_status = self._get_table_status(self.get_controllers_data) - general_status = all(status == "Correct" for status in - [trade_fill_status, orders_status, order_status_status, executors_status, controller_status]) - status = {"db_name": self.db_name, - "db_path": self.db_path, - "trade_fill": trade_fill_status, - "orders": orders_status, - "order_status": order_status_status, - "executors": executors_status, - "general_status": general_status - } - return status - - def get_orders(self): - with self.session_maker() as session: - query = "SELECT * FROM 'Order'" - orders = pd.read_sql_query(text(query), session.connection()) - orders["market"] = orders["market"] - orders["amount"] = orders["amount"] / 1e6 - orders["price"] = orders["price"] / 1e6 - # orders['creation_timestamp'] = pd.to_datetime(orders['creation_timestamp'], unit="ms") - # orders['last_update_timestamp'] = pd.to_datetime(orders['last_update_timestamp'], unit="ms") - return orders - - def get_trade_fills(self): - groupers = ["config_file_path", "market", "symbol"] - float_cols = ["amount", "price", "trade_fee_in_quote"] - with self.session_maker() as session: - query = "SELECT * FROM TradeFill" - trade_fills = pd.read_sql_query(text(query), session.connection()) - trade_fills[float_cols] = trade_fills[float_cols] / 1e6 - trade_fills["cum_fees_in_quote"] = trade_fills.groupby(groupers)["trade_fee_in_quote"].cumsum() - trade_fills["trade_fee"] = trade_fills.groupby(groupers)["cum_fees_in_quote"].diff() - # trade_fills["timestamp"] = pd.to_datetime(trade_fills["timestamp"], unit="ms") - return trade_fills - - def get_order_status(self): - with self.session_maker() as session: - query = "SELECT * FROM OrderStatus" - order_status = pd.read_sql_query(text(query), session.connection()) - return order_status - - def get_executors_data(self) 
-> pd.DataFrame: - with self.session_maker() as session: - query = "SELECT * FROM Executors" - executors = pd.read_sql_query(text(query), session.connection()) - return executors - - def get_controllers_data(self) -> pd.DataFrame: - with self.session_maker() as session: - query = "SELECT * FROM Controllers" - controllers = pd.read_sql_query(text(query), session.connection()) - return controllers - - -class ETLPerformance: - def __init__(self, - db_path: str): - self.db_path = f'sqlite:///{os.path.join(db_path)}' - self.engine = create_engine(self.db_path, connect_args={'check_same_thread': False}) - self.session_maker = sessionmaker(bind=self.engine) - self.metadata = MetaData() - - @property - def executors_table(self): - return Table('executors', - MetaData(), - Column('id', String), - Column('timestamp', Integer), - Column('type', String), - Column('close_type', Integer), - Column('close_timestamp', Integer), - Column('status', String), - Column('config', String), - Column('net_pnl_pct', Float), - Column('net_pnl_quote', Float), - Column('cum_fees_quote', Float), - Column('filled_amount_quote', Float), - Column('is_active', Integer), - Column('is_trading', Integer), - Column('custom_info', String), - Column('controller_id', String)) - - @property - def trade_fill_table(self): - return Table( - 'trades', MetaData(), - Column('config_file_path', VARCHAR(255)), - Column('strategy', VARCHAR(255)), - Column('market', VARCHAR(255)), - Column('symbol', VARCHAR(255)), - Column('base_asset', VARCHAR(255)), - Column('quote_asset', VARCHAR(255)), - Column('timestamp', INT), - Column('order_id', VARCHAR(255)), - Column('trade_type', VARCHAR(255)), - Column('order_type', VARCHAR(255)), - Column('price', FLOAT), - Column('amount', FLOAT), - Column('leverage', INT), - Column('trade_fee', VARCHAR(255)), - Column('trade_fee_in_quote', FLOAT), - Column('exchange_trade_id', VARCHAR(255)), - Column('position', VARCHAR(255)), - ) - - @property - def orders_table(self): - return Table( - 'orders', MetaData(), - Column('client_order_id', VARCHAR(255)), - Column('config_file_path', VARCHAR(255)), - Column('strategy', VARCHAR(255)), - Column('market', VARCHAR(255)), - Column('symbol', VARCHAR(255)), - Column('base_asset', VARCHAR(255)), - Column('quote_asset', VARCHAR(255)), - Column('creation_timestamp', INT), - Column('order_type', VARCHAR(255)), - Column('amount', FLOAT), - Column('leverage', INT), - Column('price', FLOAT), - Column('last_status', VARCHAR(255)), - Column('last_update_timestamp', INT), - Column('exchange_order_id', VARCHAR(255)), - Column('position', VARCHAR(255)), - ) - - @property - def controllers_table(self): - return Table( - 'controllers', MetaData(), - Column('id', VARCHAR(255)), - Column('controller_id', INT), - Column('timestamp', FLOAT), - Column('type', VARCHAR(255)), - Column('config', String), - ) - - @property - def tables(self): - return [self.executors_table, self.trade_fill_table, self.orders_table, self.controllers_table] - - def create_tables(self): - with self.engine.connect(): - for table in self.tables: - table.create(self.engine) - - def insert_data(self, data): - if "executors" in data: - self.insert_executors(data["executors"]) - if "trade_fill" in data: - self.insert_trade_fill(data["trade_fill"]) - if "orders" in data: - self.insert_orders(data["orders"]) - if "controllers" in data: - self.insert_controllers(data["controllers"]) - - def insert_executors(self, executors): - with self.engine.connect() as conn: - for _, row in executors.iterrows(): - ins = 
self.executors_table.insert().values( - id=row["id"], - timestamp=row["timestamp"], - type=row["type"], - close_type=row["close_type"], - close_timestamp=row["close_timestamp"], - status=row["status"], - config=row["config"], - net_pnl_pct=row["net_pnl_pct"], - net_pnl_quote=row["net_pnl_quote"], - cum_fees_quote=row["cum_fees_quote"], - filled_amount_quote=row["filled_amount_quote"], - is_active=row["is_active"], - is_trading=row["is_trading"], - custom_info=row["custom_info"], - controller_id=row["controller_id"]) - conn.execute(ins) - conn.commit() - - def insert_trade_fill(self, trade_fill): - with self.engine.connect() as conn: - for _, row in trade_fill.iterrows(): - ins = insert(self.trade_fill_table).values( - config_file_path=row["config_file_path"], - strategy=row["strategy"], - market=row["market"], - symbol=row["symbol"], - base_asset=row["base_asset"], - quote_asset=row["quote_asset"], - timestamp=row["timestamp"], - order_id=row["order_id"], - trade_type=row["trade_type"], - order_type=row["order_type"], - price=row["price"], - amount=row["amount"], - leverage=row["leverage"], - trade_fee=row["trade_fee"], - trade_fee_in_quote=row["trade_fee_in_quote"], - exchange_trade_id=row["exchange_trade_id"], - position=row["position"], - ) - conn.execute(ins) - conn.commit() - - def insert_orders(self, orders): - with self.engine.connect() as conn: - for _, row in orders.iterrows(): - ins = insert(self.orders_table).values( - client_order_id=row["id"], - config_file_path=row["config_file_path"], - strategy=row["strategy"], - market=row["market"], - symbol=row["symbol"], - base_asset=row["base_asset"], - quote_asset=row["quote_asset"], - creation_timestamp=row["creation_timestamp"], - order_type=row["order_type"], - amount=row["amount"], - leverage=row["leverage"], - price=row["price"], - last_status=row["last_status"], - last_update_timestamp=row["last_update_timestamp"], - exchange_order_id=row["exchange_order_id"], - position=row["position"], - ) - conn.execute(ins) - conn.commit() - - def insert_controllers(self, controllers): - with self.engine.connect() as conn: - for _, row in controllers.iterrows(): - ins = insert(self.controllers_table).values( - id=row["id"], - controller_id=row["controller_id"], - timestamp=row["timestamp"], - type=row["type"], - config=row["config"], - ) - conn.execute(ins) - conn.commit() - - def load_executors(self): - with self.session_maker() as session: - query = "SELECT * FROM executors" - executors = pd.read_sql_query(text(query), session.connection()) - return executors - - def load_trade_fill(self): - with self.session_maker() as session: - query = "SELECT * FROM trades" - trade_fill = pd.read_sql_query(text(query), session.connection()) - return trade_fill - - def load_orders(self): - with self.session_maker() as session: - query = "SELECT * FROM orders" - orders = pd.read_sql_query(text(query), session.connection()) - return orders - - def load_controllers(self): - with self.session_maker() as session: - query = "SELECT * FROM controllers" - controllers = pd.read_sql_query(text(query), session.connection()) - return controllers - - -class PerformanceDataSource: - def __init__(self, executors_dict: Dict[str, Any]): - self.executors_dict = executors_dict - - @property - def executors_df(self): - executors = pd.DataFrame(self.executors_dict) - executors["custom_info"] = executors["custom_info"].apply( - lambda x: json.loads(x) if isinstance(x, str) else x) - executors["config"] = executors["config"].apply(lambda x: json.loads(x) if isinstance(x, str) 
else x) - executors["timestamp"] = executors["timestamp"].apply(lambda x: self.ensure_timestamp_in_seconds(x)) - executors["close_timestamp"] = executors["close_timestamp"].apply( - lambda x: self.ensure_timestamp_in_seconds(x)) - executors["trading_pair"] = executors["config"].apply(lambda x: x["trading_pair"]) - executors["exchange"] = executors["config"].apply(lambda x: x["connector_name"]) - executors["level_id"] = executors["config"].apply(lambda x: x.get("level_id")) - executors["bep"] = executors["custom_info"].apply(lambda x: x["current_position_average_price"]) - executors["order_ids"] = executors["custom_info"].apply(lambda x: x.get("order_ids")) - executors["close_price"] = executors["custom_info"].apply(lambda x: x.get("close_price", x["current_position_average_price"])) - executors["sl"] = executors["config"].apply(lambda x: x.get("stop_loss")).fillna(0) - executors["tp"] = executors["config"].apply(lambda x: x.get("take_profit")).fillna(0) - executors["tl"] = executors["config"].apply(lambda x: x.get("time_limit")).fillna(0) - return executors - - @property - def executor_info_list(self) -> List[ExecutorInfo]: - executors = self.apply_special_data_types(self.executors_df) - executor_values = [] - for index, row in executors.iterrows(): - executor_to_append = ExecutorInfo( - id=row["id"], - timestamp=row["timestamp"], - type=row["type"], - close_timestamp=row["close_timestamp"], - close_type=row["close_type"], - status=row["status"], - config=row["config"], - net_pnl_pct=row["net_pnl_pct"], - net_pnl_quote=row["net_pnl_quote"], - cum_fees_quote=row["cum_fees_quote"], - filled_amount_quote=row["filled_amount_quote"], - is_active=row["is_active"], - is_trading=row["is_trading"], - custom_info=row["custom_info"], - controller_id=row["controller_id"] - ) - executor_to_append.custom_info["side"] = row["side"] - executor_values.append(executor_to_append) - return executor_values - - def apply_special_data_types(self, executors): - executors["status"] = executors["status"].apply(lambda x: self.get_enum_by_value(RunnableStatus, int(x))) - executors["side"] = executors["config"].apply(lambda x: self.get_enum_by_value(TradeType, int(x["side"]))) - executors["close_type"] = executors["close_type"].apply(lambda x: self.get_enum_by_value(CloseType, int(x))) - executors["close_type_name"] = executors["close_type"].apply(lambda x: x.name) - executors["datetime"] = pd.to_datetime(executors.timestamp, unit="s") - executors["close_datetime"] = pd.to_datetime(executors["close_timestamp"], unit="s") - return executors - - @staticmethod - def get_enum_by_value(enum_class, value): - for member in enum_class: - if member.value == value: - return member - raise ValueError(f"No enum member with value {value}") - - @staticmethod - def ensure_timestamp_in_seconds(timestamp: float) -> float: - """ - Ensure the given timestamp is in seconds. - Args: - - timestamp (int): The input timestamp which could be in seconds, milliseconds, or microseconds. - Returns: - - int: The timestamp in seconds. - Raises: - - ValueError: If the timestamp is not in a recognized format. - """ - timestamp_int = int(float(timestamp)) - if timestamp_int >= 1e18: # Nanoseconds - return timestamp_int / 1e9 - elif timestamp_int >= 1e15: # Microseconds - return timestamp_int / 1e6 - elif timestamp_int >= 1e12: # Milliseconds - return timestamp_int / 1e3 - elif timestamp_int >= 1e9: # Seconds - return timestamp_int - else: - raise ValueError( - "Timestamp is not in a recognized format. 
Must be in seconds, milliseconds, microseconds or nanoseconds.") \ No newline at end of file diff --git a/utils/file_system.py b/utils/file_system.py index 040126f8..76250ee5 100644 --- a/utils/file_system.py +++ b/utils/file_system.py @@ -2,10 +2,13 @@ import inspect import logging import os + +# Create module-specific logger +logger = logging.getLogger(__name__) import shutil import sys from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Type import yaml from hummingbot.client.config.config_data_types import BaseClientModel @@ -19,25 +22,49 @@ class FileSystemUtil: """ FileSystemUtil provides utility functions for file and directory management, as well as dynamic loading of script configurations. + + All file operations are performed relative to the base_path unless an absolute path is provided. + Implements singleton pattern to ensure the same instance is reused. """ + _instance = None base_path: str = "bots" # Default base path + def __new__(cls, base_path: Optional[str] = None): + if cls._instance is None: + cls._instance = super(FileSystemUtil, cls).__new__(cls) + cls._instance.base_path = base_path if base_path else "bots" + return cls._instance + def __init__(self, base_path: Optional[str] = None): """ Initializes the FileSystemUtil with a base path. :param base_path: The base directory path for file operations. """ - if base_path: - self.base_path = base_path + # Singleton pattern - instance already configured in __new__ + pass + + def _get_full_path(self, path: str) -> str: + """ + Get the full path by combining base_path with relative path. + :param path: Relative or absolute path. + :return: Full absolute path. + """ + return path if os.path.isabs(path) else os.path.join(self.base_path, path) def list_files(self, directory: str) -> List[str]: """ Lists all files in a given directory. :param directory: The directory to list files from. :return: List of file names in the directory. + :raises FileNotFoundError: If the directory does not exist. + :raises PermissionError: If access is denied to the directory. """ excluded_files = ["__init__.py", "__pycache__", ".DS_Store", ".dockerignore", ".gitignore"] - dir_path = os.path.join(self.base_path, directory) + dir_path = self._get_full_path(directory) + if not os.path.exists(dir_path): + raise FileNotFoundError(f"Directory '{directory}' not found") + if not os.path.isdir(dir_path): + raise NotADirectoryError(f"Path '{directory}' is not a directory") return [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f)) and f not in excluded_files] def list_folders(self, directory: str) -> List[str]: @@ -45,62 +72,97 @@ def list_folders(self, directory: str) -> List[str]: Lists all folders in a given directory. :param directory: The directory to list folders from. :return: List of folder names in the directory. - """ - dir_path = os.path.join(self.base_path, directory) + :raises FileNotFoundError: If the directory does not exist. + :raises PermissionError: If access is denied to the directory. 
+ """ + dir_path = self._get_full_path(directory) + if not os.path.exists(dir_path): + raise FileNotFoundError(f"Directory '{directory}' not found") + if not os.path.isdir(dir_path): + raise NotADirectoryError(f"Path '{directory}' is not a directory") return [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))] - def create_folder(self, directory: str, folder_name: str): + def create_folder(self, directory: str, folder_name: str) -> None: """ Creates a folder in a specified directory. :param directory: The directory to create the folder in. :param folder_name: The name of the folder to be created. + :raises PermissionError: If permission is denied to create the folder. + :raises OSError: If there's an OS-level error creating the folder. """ - folder_path = os.path.join(self.base_path, directory, folder_name) + if not folder_name or '/' in folder_name or '\\' in folder_name: + raise ValueError(f"Invalid folder name: '{folder_name}'") + folder_path = self._get_full_path(os.path.join(directory, folder_name)) os.makedirs(folder_path, exist_ok=True) - def copy_folder(self, src: str, dest: str): + def copy_folder(self, src: str, dest: str) -> None: """ Copies a folder to a new location. :param src: The source folder to copy. :param dest: The destination folder to copy to. - """ - src_path = os.path.join(self.base_path, src) - dest_path = os.path.join(self.base_path, dest) - os.makedirs(dest_path, exist_ok=True) - for item in os.listdir(src_path): - s = os.path.join(src_path, item) - d = os.path.join(dest_path, item) - if os.path.isdir(s): - self.copy_folder(s, d) - else: - shutil.copy2(s, d) + :raises FileNotFoundError: If source folder doesn't exist. + :raises PermissionError: If permission is denied. + """ + src_path = self._get_full_path(src) + dest_path = self._get_full_path(dest) + + if not os.path.exists(src_path): + raise FileNotFoundError(f"Source folder '{src}' not found") + if not os.path.isdir(src_path): + raise NotADirectoryError(f"Source path '{src}' is not a directory") + + shutil.copytree(src_path, dest_path, dirs_exist_ok=True) - def copy_file(self, src: str, dest: str): + def copy_file(self, src: str, dest: str) -> None: """ Copies a file to a new location. :param src: The source file to copy. :param dest: The destination file to copy to. - """ - src_path = os.path.join(self.base_path, src) - dest_path = os.path.join(self.base_path, dest) + :raises FileNotFoundError: If source file doesn't exist. + :raises PermissionError: If permission is denied. + """ + src_path = self._get_full_path(src) + dest_path = self._get_full_path(dest) + + if not os.path.exists(src_path): + raise FileNotFoundError(f"Source file '{src}' not found") + if os.path.isdir(src_path): + raise IsADirectoryError(f"Source path '{src}' is a directory, not a file") + + # Ensure destination directory exists + dest_dir = os.path.dirname(dest_path) + os.makedirs(dest_dir, exist_ok=True) + shutil.copy2(src_path, dest_path) - def delete_folder(self, directory: str, folder_name: str): + def delete_folder(self, directory: str, folder_name: str) -> None: """ Deletes a folder in a specified directory. :param directory: The directory to delete the folder from. :param folder_name: The name of the folder to be deleted. - """ - folder_path = os.path.join(self.base_path, directory, folder_name) + :raises FileNotFoundError: If folder doesn't exist. + :raises PermissionError: If permission is denied. 
+ """ + folder_path = self._get_full_path(os.path.join(directory, folder_name)) + if not os.path.exists(folder_path): + raise FileNotFoundError(f"Folder '{folder_name}' not found in '{directory}'") + if not os.path.isdir(folder_path): + raise NotADirectoryError(f"Path '{folder_name}' is not a directory") shutil.rmtree(folder_path) - def delete_file(self, directory: str, file_name: str): + def delete_file(self, directory: str, file_name: str) -> None: """ Deletes a file in a specified directory. :param directory: The directory to delete the file from. :param file_name: The name of the file to be deleted. - """ - file_path = os.path.join(self.base_path, directory, file_name) + :raises FileNotFoundError: If file doesn't exist. + :raises PermissionError: If permission is denied. + """ + file_path = self._get_full_path(os.path.join(directory, file_name)) + if not os.path.exists(file_path): + raise FileNotFoundError(f"File '{file_name}' not found in '{directory}'") + if os.path.isdir(file_path): + raise IsADirectoryError(f"Path '{file_name}' is a directory, not a file") os.remove(file_path) def path_exists(self, path: str) -> bool: @@ -109,56 +171,101 @@ def path_exists(self, path: str) -> bool: :param path: The path to check. :return: True if the path exists, False otherwise. """ - return os.path.exists(os.path.join(self.base_path, path)) + return os.path.exists(self._get_full_path(path)) - def add_file(self, directory: str, file_name: str, content: str, override: bool = False): + def add_file(self, directory: str, file_name: str, content: str, override: bool = False) -> None: """ Adds a file to a specified directory. :param directory: The directory to add the file to. :param file_name: The name of the file to be added. :param content: The content to be written to the file. :param override: If True, override the file if it exists. - """ - file_path = os.path.join(self.base_path, directory, file_name) + :raises ValueError: If file_name is invalid. + :raises FileExistsError: If file exists and override is False. + :raises PermissionError: If permission is denied to write the file. + """ + if not file_name or '/' in file_name or '\\' in file_name: + raise ValueError(f"Invalid file name: '{file_name}'") + + dir_path = self._get_full_path(directory) + os.makedirs(dir_path, exist_ok=True) + + file_path = os.path.join(dir_path, file_name) if not override and os.path.exists(file_path): raise FileExistsError(f"File '{file_name}' already exists in '{directory}'.") - with open(file_path, 'w') as file: + + with open(file_path, 'w', encoding='utf-8') as file: file.write(content) - def append_to_file(self, directory: str, file_name: str, content: str): + def append_to_file(self, directory: str, file_name: str, content: str) -> None: """ Appends content to a specified file. :param directory: The directory containing the file. :param file_name: The name of the file to append to. :param content: The content to append to the file. - """ - file_path = os.path.join(self.base_path, directory, file_name) - with open(file_path, 'a') as file: + :raises FileNotFoundError: If file doesn't exist. + :raises PermissionError: If permission is denied. 
+ """ + file_path = self._get_full_path(os.path.join(directory, file_name)) + if not os.path.exists(file_path): + raise FileNotFoundError(f"File '{file_name}' not found in '{directory}'") + if os.path.isdir(file_path): + raise IsADirectoryError(f"Path '{file_name}' is a directory, not a file") + + with open(file_path, 'a', encoding='utf-8') as file: file.write(content) - @staticmethod - def dump_dict_to_yaml(filename, data_dict): + def read_file(self, file_path: str) -> str: + """ + Reads the content of a file. + :param file_path: The relative path to the file from base_path. + :return: The content of the file as a string. + :raises FileNotFoundError: If the file does not exist. + :raises PermissionError: If access is denied to the file. + :raises IsADirectoryError: If the path points to a directory. + """ + full_path = self._get_full_path(file_path) + if not os.path.exists(full_path): + raise FileNotFoundError(f"File '{file_path}' not found") + if os.path.isdir(full_path): + raise IsADirectoryError(f"Path '{file_path}' is a directory, not a file") + + with open(full_path, 'r', encoding='utf-8') as file: + return file.read() + + def dump_dict_to_yaml(self, filename: str, data_dict: dict) -> None: """ Dumps a dictionary to a YAML file. + :param filename: The file to dump the dictionary into (relative to base_path). :param data_dict: The dictionary to dump. - :param filename: The file to dump the dictionary into. + :raises PermissionError: If permission is denied to write the file. """ - with open(filename, 'w') as file: - yaml.dump(data_dict, file) + file_path = self._get_full_path(filename) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w', encoding='utf-8') as file: + yaml.dump(data_dict, file, default_flow_style=False, allow_unicode=True) - @staticmethod - def read_yaml_file(file_path): + def read_yaml_file(self, file_path: str) -> dict: """ Reads a YAML file and returns the data as a dictionary. - :param file_path: The path to the YAML file. + :param file_path: The path to the YAML file (relative to base_path or absolute). :return: Dictionary containing the YAML file data. - """ - with open(file_path, 'r') as file: - data = yaml.safe_load(file) - return data + :raises FileNotFoundError: If the file doesn't exist. + :raises yaml.YAMLError: If the YAML is invalid. + """ + full_path = self._get_full_path(file_path) if not os.path.isabs(file_path) else file_path + if not os.path.exists(full_path): + raise FileNotFoundError(f"YAML file '{file_path}' not found") + + with open(full_path, 'r', encoding='utf-8') as file: + try: + data = yaml.safe_load(file) + return data if data is not None else {} + except yaml.YAMLError as e: + raise yaml.YAMLError(f"Invalid YAML in file '{file_path}': {e}") @staticmethod - def load_script_config_class(script_name): + def load_script_config_class(script_name: str) -> Optional[Type[BaseClientModel]]: """ Dynamically loads a script's configuration class. :param script_name: The name of the script file (without the '.py' extension). 
@@ -176,14 +283,15 @@ def load_script_config_class(script_name): for _, cls in inspect.getmembers(script_module, inspect.isclass): if issubclass(cls, BaseClientModel) and cls is not BaseClientModel: return cls - except Exception as e: - print(f"Error loading script class: {e}") # Handle or log the error appropriately + except (ImportError, AttributeError, ModuleNotFoundError) as e: + logger.warning(f"Error loading script class for '{script_name}': {e}") return None @staticmethod - def load_controller_config_class(controller_type: str, controller_name: str): + def load_controller_config_class(controller_type: str, controller_name: str) -> Optional[Type]: """ Dynamically loads a controller's configuration class. + :param controller_type: The type of the controller. :param controller_name: The name of the controller file (without the '.py' extension). :return: The configuration class from the controller, or None if not found. """ @@ -201,51 +309,145 @@ def load_controller_config_class(controller_type: str, controller_name: str): or (issubclass(cls, MarketMakingControllerConfigBase) and cls is not MarketMakingControllerConfigBase)\ or (issubclass(cls, ControllerConfigBase) and cls is not ControllerConfigBase): return cls - except Exception as e: - print(f"Error loading controller class: {e}") + except (ImportError, AttributeError, ModuleNotFoundError) as e: + logger.warning(f"Error loading controller class for '{controller_type}.{controller_name}': {e}") + return None - @staticmethod - def ensure_file_and_dump_text(file_path, text): + def ensure_file_and_dump_text(self, file_path: str, text: str) -> None: """ - Ensures that the directory for the file exists, then dumps the dictionary to a YAML file. - :param file_path: The file path to dump the dictionary into. - :param text: The text to dump. + Ensures that the directory for the file exists, then writes text to a file. + :param file_path: The file path to write to (relative to base_path or absolute). + :param text: The text to write. + :raises PermissionError: If permission is denied. """ - os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, "w") as f: + full_path = self._get_full_path(file_path) if not os.path.isabs(file_path) else file_path + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, "w", encoding='utf-8') as f: f.write(text) - @staticmethod - # TODO: make paths relative - def get_connector_keys_path(account_name: str, connector_name: str) -> Path: - return Path(f"bots/credentials/{account_name}/connectors/{connector_name}.yml") + def get_connector_keys_path(self, account_name: str, connector_name: str) -> Path: + """ + Get the path to connector credentials file. + :param account_name: Name of the account. + :param connector_name: Name of the connector. + :return: Path to the connector credentials file. + """ + return Path("credentials") / account_name / "connectors" / f"{connector_name}.yml" - @staticmethod - def save_model_to_yml(yml_path: Path, cm: ClientConfigAdapter): + def save_model_to_yml(self, yml_path: str, cm: ClientConfigAdapter) -> None: + """ + Save a ClientConfigAdapter model to a YAML file. + :param yml_path: Path to the YAML file (relative to base_path or absolute). + :param cm: The ClientConfigAdapter to save. + :raises PermissionError: If permission is denied to write the file. 
+ """ try: + full_path = self._get_full_path(yml_path) cm_yml_str = cm.generate_yml_output_str_with_comments() - with open(yml_path, "w", encoding="utf-8") as outfile: + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, "w", encoding="utf-8") as outfile: outfile.write(cm_yml_str) except Exception as e: - logging.error("Error writing configs: %s" % (str(e),), exc_info=True) + logger.error(f"Error writing configs to '{yml_path}': {e}", exc_info=True) + raise - def list_databases(self): - archived_path = os.path.join(self.base_path, "archived") - archived_instances = self.list_folders("archived") + def get_base_path(self) -> str: + """ + Returns the base path for file operations + :return: The base path string + """ + return self.base_path + + def get_directory_creation_time(self, path): + """ + Get the creation time of a directory + :param path: The path to the directory + :return: ISO formatted creation time string or None if directory doesn't exist + """ + import os + import datetime + + full_path = self._get_full_path(path) + if not os.path.exists(full_path): + return None + + # Get creation time (platform dependent) + try: + # For Unix systems, use stat + creation_time = os.stat(full_path).st_ctime + # Convert to datetime + return datetime.datetime.fromtimestamp(creation_time).isoformat() + except Exception: + # Fallback + return "unknown" + + def list_directories(self, path): + """ + List all directories within a given path + :param path: The path to list directories from + :return: List of directory names + """ + import os + + full_path = self._get_full_path(path) + if not os.path.exists(full_path): + return [] + + try: + # Return only directories + return [d for d in os.listdir(full_path) if os.path.isdir(os.path.join(full_path, d))] + except Exception: + return [] + + def list_databases(self) -> List[str]: + """ + Lists all database files in archived instances + :return: List of database file paths + """ + try: + archived_instances = self.list_folders("archived") + except FileNotFoundError: + return [] + archived_databases = [] for archived_instance in archived_instances: - db_path = os.path.join(archived_path, archived_instance, "data") - archived_databases += [os.path.join(db_path, db_file) for db_file in os.listdir(db_path) - if db_file.endswith(".sqlite")] + db_path = self._get_full_path(os.path.join("archived", archived_instance, "data")) + try: + if os.path.exists(db_path): + archived_databases.extend([ + os.path.join(db_path, db_file) + for db_file in os.listdir(db_path) + if db_file.endswith(".sqlite") + ]) + except (OSError, PermissionError) as e: + logger.warning(f"Error accessing database path '{db_path}': {e}") return archived_databases - def list_checkpoints(self, full_path: bool): - dir_path = os.path.join(self.base_path, "data") - if full_path: - checkpoints = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if - os.path.isfile(os.path.join(dir_path, f)) - and f.startswith("checkpoint") and f.endswith(".sqlite")] - else: - checkpoints = [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f)) - and f.startswith("checkpoint") and f.endswith(".sqlite")] - return checkpoints + def list_checkpoints(self, full_path: bool = False) -> List[str]: + """ + Lists all checkpoint database files + :param full_path: If True, return full paths, otherwise just filenames + :return: List of checkpoint database files + """ + dir_path = self._get_full_path("data") + if not os.path.exists(dir_path): + return [] + + try: + files = 
diff --git a/utils/models.py b/utils/hummingbot_api_config_adapter.py similarity index 98% rename from utils/models.py rename to utils/hummingbot_api_config_adapter.py index 7e49da31..4dce67a3 100644 --- a/utils/models.py +++ b/utils/hummingbot_api_config_adapter.py @@ -4,7 +4,7 @@ from pydantic import SecretStr -class BackendAPIConfigAdapter(ClientConfigAdapter): +class HummingbotAPIConfigAdapter(ClientConfigAdapter): def _encrypt_secrets(self, conf_dict: Dict[str, Any]): from utils.security import BackendAPISecurity for attr, value in conf_dict.items(): diff --git a/utils/hummingbot_database_reader.py b/utils/hummingbot_database_reader.py new file mode 100644 index 00000000..6a57d260 --- /dev/null +++ b/utils/hummingbot_database_reader.py @@ -0,0 +1,308 @@ +import os +import pandas as pd +import json +from typing import List, Dict, Any + +from hummingbot.core.data_type.common import TradeType +from hummingbot.strategy_v2.models.base import RunnableStatus +from hummingbot.strategy_v2.models.executors import CloseType +from hummingbot.strategy_v2.models.executors_info import ExecutorInfo +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker + + +class HummingbotDatabase: + def __init__(self, db_path: str): + self.db_name = os.path.basename(db_path) + self.db_path = db_path + self.db_uri = f"sqlite:///{db_path}" + self.engine = create_engine(self.db_uri, connect_args={'check_same_thread': False}) + self.session_maker = sessionmaker(bind=self.engine) + + @staticmethod + def _get_table_status(table_loader): + try: + data = table_loader() + return "Correct" if len(data) > 0 else "Error - No records matched" + except Exception as e: + return f"Error - {str(e)}" + + @property + def status(self): + trade_fill_status = self._get_table_status(self.get_trade_fills) + orders_status = self._get_table_status(self.get_orders) + order_status_status = self._get_table_status(self.get_order_status) + executors_status = self._get_table_status(self.get_executors_data) + controller_status = self._get_table_status(self.get_controllers_data) + positions_status = self._get_table_status(self.get_positions) + general_status = all(status == "Correct" for status in + [trade_fill_status, orders_status, order_status_status, executors_status, controller_status, positions_status]) + status = {"db_name": self.db_name, + "db_path": self.db_path, + "trade_fill": trade_fill_status, + "orders": orders_status, + "order_status": order_status_status, + "executors": executors_status, + "controllers": controller_status, + "positions": positions_status, + "general_status": general_status + } + return status + + def get_orders(self): + with self.session_maker() as session: + query = "SELECT * FROM 'Order'" + orders = pd.read_sql_query(text(query), session.connection()) + orders["amount"] = orders["amount"] / 1e6 + orders["price"] = orders["price"] / 1e6 + orders.rename(columns={"market": "connector_name", "symbol": "trading_pair"}, inplace=True)
+ return orders + + def get_trade_fills(self): + groupers = ["config_file_path", "connector_name", "trading_pair"] + float_cols = ["amount", "price", "trade_fee_in_quote"] + with self.session_maker() as session: + query = "SELECT * FROM TradeFill" + trade_fills = pd.read_sql_query(text(query), session.connection()) + trade_fills.rename(columns={"market": "connector_name", "symbol": "trading_pair"}, inplace=True) + trade_fills[float_cols] = trade_fills[float_cols] / 1e6 + trade_fills["cum_fees_in_quote"] = trade_fills.groupby(groupers)["trade_fee_in_quote"].cumsum() + trade_fills["trade_fee"] = trade_fills.groupby(groupers)["cum_fees_in_quote"].diff() + return trade_fills + + def get_order_status(self): + with self.session_maker() as session: + query = "SELECT * FROM OrderStatus" + order_status = pd.read_sql_query(text(query), session.connection()) + return order_status + + def get_executors_data(self) -> pd.DataFrame: + with self.session_maker() as session: + query = "SELECT * FROM Executors" + executors = pd.read_sql_query(text(query), session.connection()) + return executors + + def get_controllers_data(self) -> pd.DataFrame: + with self.session_maker() as session: + query = "SELECT * FROM Controllers" + controllers = pd.read_sql_query(text(query), session.connection()) + return controllers + + def get_positions(self) -> pd.DataFrame: + with self.session_maker() as session: + query = "SELECT * FROM Position" + positions = pd.read_sql_query(text(query), session.connection()) + # Convert decimal fields from stored format (divide by 1e6) + decimal_cols = ["volume_traded_quote", "amount", "breakeven_price", "unrealized_pnl_quote", "cum_fees_quote"] + positions[decimal_cols] = positions[decimal_cols] / 1e6 + return positions + + def calculate_trade_based_performance(self) -> pd.DataFrame: + """ + Calculate trade-based performance metrics using vectorized pandas operations. + + Returns: + DataFrame with rolling performance metrics calculated per trading pair. 
+ """ + # Get trade fills data + trades = self.get_trade_fills() + + if len(trades) == 0: + return pd.DataFrame() + + # Sort by timestamp to ensure proper rolling calculation + trades = trades.sort_values(['trading_pair', 'connector_name', 'timestamp']).copy() + + # Create buy/sell indicator columns + trades['is_buy'] = (trades['trade_type'].str.upper() == 'BUY').astype(int) + trades['is_sell'] = (trades['trade_type'].str.upper() == 'SELL').astype(int) + + # Calculate buy and sell amounts and values vectorized + trades['buy_amount'] = trades['amount'] * trades['is_buy'] + trades['sell_amount'] = trades['amount'] * trades['is_sell'] + trades['buy_value'] = trades['price'] * trades['amount'] * trades['is_buy'] + trades['sell_value'] = trades['price'] * trades['amount'] * trades['is_sell'] + + # Group by trading_pair and connector_name for rolling calculations + grouper = ['trading_pair', 'connector_name'] + + # Calculate cumulative volumes and values + trades['buy_volume'] = trades.groupby(grouper)['buy_amount'].cumsum() + trades['sell_volume'] = trades.groupby(grouper)['sell_amount'].cumsum() + trades['buy_value_cum'] = trades.groupby(grouper)['buy_value'].cumsum() + trades['sell_value_cum'] = trades.groupby(grouper)['sell_value'].cumsum() + + # Calculate average prices (avoid division by zero) + trades['buy_avg_price'] = trades['buy_value_cum'] / trades['buy_volume'].replace(0, pd.NA) + trades['sell_avg_price'] = trades['sell_value_cum'] / trades['sell_volume'].replace(0, pd.NA) + + # Forward fill average prices within each group to handle NaN values + trades['buy_avg_price'] = trades.groupby(grouper)['buy_avg_price'].ffill().fillna(0) + trades['sell_avg_price'] = trades.groupby(grouper)['sell_avg_price'].ffill().fillna(0) + + # Calculate net position + trades['net_position'] = trades['buy_volume'] - trades['sell_volume'] + + # Calculate realized PnL + trades['realized_trade_pnl_pct'] = ( + (trades['sell_avg_price'] - trades['buy_avg_price']) / trades['buy_avg_price'] + ).fillna(0) + + # Matched volume for realized PnL (minimum of buy and sell volumes) + trades['matched_volume'] = pd.concat([trades['buy_volume'], trades['sell_volume']], axis=1).min(axis=1) + trades['realized_trade_pnl_quote'] = trades['realized_trade_pnl_pct'] * trades['matched_volume'] * trades['buy_avg_price'] + + # Calculate unrealized PnL based on position direction + # For long positions (net_position > 0): use current price vs buy_avg_price + # For short positions (net_position < 0): use sell_avg_price vs current price + trades['unrealized_trade_pnl_pct'] = 0.0 + + # Long positions + long_mask = trades['net_position'] > 0 + trades.loc[long_mask, 'unrealized_trade_pnl_pct'] = ( + (trades.loc[long_mask, 'price'] - trades.loc[long_mask, 'buy_avg_price']) / + trades.loc[long_mask, 'buy_avg_price'] + ).fillna(0) + + # Short positions + short_mask = trades['net_position'] < 0 + trades.loc[short_mask, 'unrealized_trade_pnl_pct'] = ( + (trades.loc[short_mask, 'sell_avg_price'] - trades.loc[short_mask, 'price']) / + trades.loc[short_mask, 'sell_avg_price'] + ).fillna(0) + + # Calculate unrealized PnL in quote currency + trades['unrealized_trade_pnl_quote'] = 0.0 + + # Long positions: use buy_avg_price as reference + long_mask = trades['net_position'] > 0 + trades.loc[long_mask, 'unrealized_trade_pnl_quote'] = ( + trades.loc[long_mask, 'unrealized_trade_pnl_pct'] * + trades.loc[long_mask, 'net_position'].abs() * + trades.loc[long_mask, 'buy_avg_price'] + ) + + # Short positions: use sell_avg_price as reference + short_mask = 
trades['net_position'] < 0 + trades.loc[short_mask, 'unrealized_trade_pnl_quote'] = ( + trades.loc[short_mask, 'unrealized_trade_pnl_pct'] * + trades.loc[short_mask, 'net_position'].abs() * + trades.loc[short_mask, 'sell_avg_price'] + ) + + # Fees are already in trade_fee_in_quote column + trades['fees_quote'] = trades['trade_fee_in_quote'] + + # Calculate net PnL + trades['net_pnl_quote'] = ( + trades['realized_trade_pnl_quote'] + + trades['unrealized_trade_pnl_quote'] - + trades['fees_quote'] + ) + + # Calculate cumulative volume in quote currency + trades['volume_quote'] = trades['price'] * trades['amount'] + trades['cum_volume_quote'] = trades.groupby(grouper)['volume_quote'].cumsum() + + # Select and return relevant columns + result_columns = [ + 'timestamp', 'price', 'amount', 'trade_type', 'trading_pair', 'connector_name', + 'buy_avg_price', 'buy_volume', 'sell_avg_price', 'sell_volume', + 'net_position', 'realized_trade_pnl_pct', 'realized_trade_pnl_quote', + 'unrealized_trade_pnl_pct', 'unrealized_trade_pnl_quote', + 'fees_quote', 'net_pnl_quote', 'volume_quote', 'cum_volume_quote' + ] + + return trades[result_columns].sort_values('timestamp') + + + +class PerformanceDataSource: + def __init__(self, executors_dict: Dict[str, Any]): + self.executors_dict = executors_dict + + @property + def executors_df(self): + executors = pd.DataFrame(self.executors_dict) + executors["custom_info"] = executors["custom_info"].apply( + lambda x: json.loads(x) if isinstance(x, str) else x) + executors["config"] = executors["config"].apply(lambda x: json.loads(x) if isinstance(x, str) else x) + executors["timestamp"] = executors["timestamp"].apply(lambda x: self.ensure_timestamp_in_seconds(x)) + executors["close_timestamp"] = executors["close_timestamp"].apply( + lambda x: self.ensure_timestamp_in_seconds(x)) + executors["trading_pair"] = executors["config"].apply(lambda x: x["trading_pair"]) + executors["exchange"] = executors["config"].apply(lambda x: x["connector_name"]) + executors["level_id"] = executors["config"].apply(lambda x: x.get("level_id")) + executors["bep"] = executors["custom_info"].apply(lambda x: x["current_position_average_price"]) + executors["order_ids"] = executors["custom_info"].apply(lambda x: x.get("order_ids")) + executors["close_price"] = executors["custom_info"].apply(lambda x: x.get("close_price", x["current_position_average_price"])) + executors["sl"] = executors["config"].apply(lambda x: x.get("stop_loss")).fillna(0) + executors["tp"] = executors["config"].apply(lambda x: x.get("take_profit")).fillna(0) + executors["tl"] = executors["config"].apply(lambda x: x.get("time_limit")).fillna(0) + return executors + + @property + def executor_info_list(self) -> List[ExecutorInfo]: + executors = self.apply_special_data_types(self.executors_df) + executor_values = [] + for index, row in executors.iterrows(): + executor_to_append = ExecutorInfo( + id=row["id"], + timestamp=row["timestamp"], + type=row["type"], + close_timestamp=row["close_timestamp"], + close_type=row["close_type"], + status=row["status"], + config=row["config"], + net_pnl_pct=row["net_pnl_pct"], + net_pnl_quote=row["net_pnl_quote"], + cum_fees_quote=row["cum_fees_quote"], + filled_amount_quote=row["filled_amount_quote"], + is_active=row["is_active"], + is_trading=row["is_trading"], + custom_info=row["custom_info"], + controller_id=row["controller_id"] + ) + executor_to_append.custom_info["side"] = row["side"] + executor_values.append(executor_to_append) + return executor_values + + def 
apply_special_data_types(self, executors): + executors["status"] = executors["status"].apply(lambda x: self.get_enum_by_value(RunnableStatus, int(x))) + executors["side"] = executors["config"].apply(lambda x: self.get_enum_by_value(TradeType, int(x["side"]))) + executors["close_type"] = executors["close_type"].apply(lambda x: self.get_enum_by_value(CloseType, int(x))) + executors["close_type_name"] = executors["close_type"].apply(lambda x: x.name) + executors["datetime"] = pd.to_datetime(executors.timestamp, unit="s") + executors["close_datetime"] = pd.to_datetime(executors["close_timestamp"], unit="s") + return executors + + @staticmethod + def get_enum_by_value(enum_class, value): + for member in enum_class: + if member.value == value: + return member + raise ValueError(f"No enum member with value {value}") + + @staticmethod + def ensure_timestamp_in_seconds(timestamp: float) -> float: + """ + Ensure the given timestamp is in seconds. + Args: + - timestamp (float): The input timestamp, which may be in seconds, milliseconds, microseconds, or nanoseconds. + Returns: + - float: The timestamp in seconds. + Raises: + - ValueError: If the timestamp is not in a recognized format. + """ + timestamp_int = int(float(timestamp)) + if timestamp_int >= 1e18: # Nanoseconds + return timestamp_int / 1e9 + elif timestamp_int >= 1e15: # Microseconds + return timestamp_int / 1e6 + elif timestamp_int >= 1e12: # Milliseconds + return timestamp_int / 1e3 + elif timestamp_int >= 1e9: # Seconds + return timestamp_int + else: + raise ValueError( + "Timestamp is not in a recognized format. Must be in seconds, milliseconds, microseconds, or nanoseconds.") \ No newline at end of file
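Note: ensure_timestamp_in_seconds infers the unit of a raw timestamp from its magnitude and normalizes it to seconds. A quick illustration of the thresholds (the input values below are arbitrary examples):

    from utils.hummingbot_database_reader import PerformanceDataSource

    norm = PerformanceDataSource.ensure_timestamp_in_seconds
    print(norm(1_700_000_000))          # already seconds -> 1700000000
    print(norm(1_700_000_000_000))      # milliseconds   -> 1700000000.0
    print(norm(1_700_000_000_000_000))  # microseconds   -> 1700000000.0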
+ """ + + def __init__(self, host: str, port: int, username: str, password: str): + self.host = host + self.port = port + self.username = username + self.password = password + + # Message handlers by topic pattern + self._handlers: Dict[str, Callable] = {} + + # Bot data storage + self._bot_performance: Dict[str, Dict] = defaultdict(dict) + self._bot_logs: Dict[str, deque] = defaultdict(lambda: deque(maxlen=100)) + self._bot_error_logs: Dict[str, deque] = defaultdict(lambda: deque(maxlen=100)) + + # Auto-discovered bots + self._discovered_bots: Dict[str, float] = {} # bot_id: last_seen_timestamp + + # Message deduplication tracking + self._processed_messages: Dict[str, float] = {} # message_hash: timestamp + self._message_ttl = 300 # 5 minutes TTL for processed messages + + # Connection state + self._connected = False + self._reconnect_interval = 5 # seconds + self._client: Optional[aiomqtt.Client] = None + self._tasks: Set[asyncio.Task] = set() + + # RPC response tracking + self._pending_responses: Dict[str, asyncio.Future] = {} # reply_to_topic: future + + # Subscriptions to restore on reconnect + self._subscriptions = [ + ("hbot/+/log", 1), # Log messages + ("hbot/+/notify", 1), # Notifications + ("hbot/+/status_updates", 1), # Status updates + ("hbot/+/events", 1), # Internal events + ("hbot/+/hb", 1), # Heartbeats + ("hbot/+/performance", 1), # Performance metrics + ("hbot/+/external/event/+", 1), # External events + ("hummingbot-api/response/+", 1), # RPC responses to our reply_to topics + ] + + if username: + logger.info(f"MQTT client configured for user: {username}") + else: + logger.info("MQTT client configured without authentication") + + @asynccontextmanager + async def _get_client(self): + """Get MQTT client for a single connection attempt.""" + client_id = f"hummingbot-api-{int(time.time())}" + + # Create client with credentials if provided + if self.username and self.password: + client = aiomqtt.Client( + hostname=self.host, + port=self.port, + username=self.username, + password=self.password, + identifier=client_id, + keepalive=60, + ) + else: + client = aiomqtt.Client(hostname=self.host, port=self.port, identifier=client_id, keepalive=60) + + async with client: + self._connected = True + logger.info(f"✓ Connected to MQTT broker at {self.host}:{self.port}") + + # Subscribe to topics + for topic, qos in self._subscriptions: + await client.subscribe(topic, qos=qos) + yield client + + # Cleanup on exit + self._connected = False + + async def _handle_messages(self): + """Main message handling loop with reconnection.""" + while True: + try: + async with self._get_client() as client: + self._client = client + async for message in client.messages: + await self._process_message(message) + except aiomqtt.MqttError as error: + logger.error(f'MQTT disconnected during message iteration: "{error}". Reconnecting...') + await asyncio.sleep(self._reconnect_interval) + except Exception as e: + logger.error(f"Unexpected error in message handler: {e}. 
Reconnecting...") + await asyncio.sleep(self._reconnect_interval) + + async def _process_message(self, message): + """Process incoming MQTT message.""" + try: + topic = str(message.topic) + + # Check if this is an RPC response to our hummingbot-api + if topic.startswith("hummingbot-api/response/"): + await self._handle_rpc_response(topic, message) + return + + topic_parts = topic.split("/") + + # Check if this matches namespace/instance_id/channel pattern + if len(topic_parts) >= 3: + namespace, bot_id, channel = topic_parts[0], topic_parts[1], "/".join(topic_parts[2:]) + # Only process if it's the expected namespace + if namespace == "hbot": + # Auto-discover bot + self._discovered_bots[bot_id] = time.time() + # Parse message + try: + data = json.loads(message.payload.decode("utf-8")) + except json.JSONDecodeError: + data = message.payload.decode("utf-8") + + # Route to appropriate handler based on Hummingbot's topics + if channel == "log": + await self._handle_log(bot_id, data) + elif channel == "notify": + await self._handle_notify(bot_id, data) + elif channel == "status_updates": + await self._handle_status(bot_id, data) + elif channel == "hb": # heartbeat + await self._handle_heartbeat(bot_id, data) + elif channel == "events": + await self._handle_events(bot_id, data) + elif channel == "performance": + await self._handle_performance(bot_id, data) + elif channel.startswith("response/"): + await self._handle_command_response(bot_id, channel, data) + elif channel.startswith("external/event/"): + await self._handle_external_event(bot_id, channel, data) + elif channel in ["history", "start", "stop", "config", "import_strategy"]: + # These are command channels - responses should come on response/* topics + logger.debug(f"Command channel '{channel}' for bot {bot_id} - waiting for response") + else: + logger.info(f"Unknown channel '{channel}' for bot {bot_id}") + + # Call custom handlers + for pattern, handler in self._handlers.items(): + if self._match_topic(pattern, topic): + if asyncio.iscoroutinefunction(handler): + await handler(bot_id, channel, data) + else: + # Run sync handler in executor + await asyncio.get_event_loop().run_in_executor(None, handler, bot_id, channel, data) + except Exception as e: + logger.error(f"Error processing message from {message.topic}: {e}", exc_info=True) + + def _match_topic(self, pattern: str, topic: str) -> bool: + """Check if topic matches pattern (supports + wildcard).""" + pattern_parts = pattern.split("/") + topic_parts = topic.split("/") + + if len(pattern_parts) != len(topic_parts): + return False + + for p, t in zip(pattern_parts, topic_parts): + if p != "+" and p != t: + return False + return True + + async def _handle_performance(self, bot_id: str, data: Any): + """Handle performance updates.""" + if isinstance(data, dict): + for controller_id, performance in data.items(): + if bot_id not in self._bot_performance: + self._bot_performance[bot_id] = {} + self._bot_performance[bot_id][controller_id] = performance + + async def _handle_log(self, bot_id: str, data: Any): + """Handle log messages with deduplication.""" + # Create a unique message identifier for deduplication + if isinstance(data, dict): + level = data.get("level_name") or data.get("levelname") or data.get("level", "INFO") + message = data.get("msg") or data.get("message", "") + timestamp = data.get("timestamp") or data.get("time") or time.time() + + # Create hash for deduplication (bot_id + message + timestamp within 1 second) + message_hash = f"{bot_id}:{message}:{int(timestamp)}" + 
elif isinstance(data, str): + message = data + timestamp = time.time() + level = "INFO" + + # Create hash for string messages + message_hash = f"{bot_id}:{message}:{int(timestamp)}" + else: + return # Skip invalid data + + # Check for duplicates + current_time = time.time() + if message_hash in self._processed_messages: + # Skip duplicate message + logger.debug(f"Skipping duplicate log message from {bot_id}: {message[:50]}...") + return + + # Clean up old message hashes (older than TTL) + expired_hashes = [h for h, t in self._processed_messages.items() if current_time - t > self._message_ttl] + for h in expired_hashes: + del self._processed_messages[h] + + # Record this message as processed + self._processed_messages[message_hash] = current_time + + # Process the message + if isinstance(data, dict): + # Normalize the log entry + log_entry = { + "level_name": level, + "msg": message, + "timestamp": timestamp, + **data, # Include all original fields + } + + if level.upper() == "ERROR": + self._bot_error_logs[bot_id].append(log_entry) + else: + self._bot_logs[bot_id].append(log_entry) + elif isinstance(data, str): + # Handle plain string logs + log_entry = {"level_name": "INFO", "msg": data, "timestamp": timestamp} + self._bot_logs[bot_id].append(log_entry) + + async def _handle_notify(self, bot_id: str, data: Any): + """Handle notification messages.""" + # Store notifications if needed + + async def _handle_status(self, bot_id: str, data: Any): + """Handle status updates.""" + # Store latest status + + async def _handle_heartbeat(self, bot_id: str, data: Any): + """Handle heartbeat messages.""" + self._discovered_bots[bot_id] = time.time() # Update last seen + + async def _handle_events(self, bot_id: str, data: Any): + """Handle internal events.""" + # Process events as needed + + async def _handle_external_event(self, bot_id: str, channel: str, data: Any): + """Handle external events.""" + event_type = channel.split("/")[-1] + + async def _handle_rpc_response(self, topic: str, message): + """Handle RPC responses on hummingbot-api/response/* topics.""" + try: + # Parse the response data + try: + data = json.loads(message.payload.decode("utf-8")) + except json.JSONDecodeError: + data = message.payload.decode("utf-8") + + # Check if we have a pending response for this topic + if topic in self._pending_responses: + future = self._pending_responses.pop(topic) + if not future.done(): + future.set_result(data) + else: + logger.warning(f"No pending RPC response found for topic: {topic}") + + except Exception as e: + logger.error(f"Error handling RPC response on {topic}: {e}", exc_info=True) + + async def _handle_command_response(self, bot_id: str, channel: str, data: Any): + """Handle command responses (legacy - keeping for backward compatibility).""" + # Extract command from response channel (e.g., response/start/1234567890 or response/history) + channel_parts = channel.split("/") + if len(channel_parts) >= 2: + command = channel_parts[1] + + async def start(self): + """Start the MQTT client.""" + try: + # Create and store the main message handling task + task = asyncio.create_task(self._handle_messages()) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + + logger.info("MQTT client started") + + # Wait a bit for connection to establish + for i in range(10): + if self._connected: + logger.info("MQTT connection established successfully") + break + await asyncio.sleep(0.5) + else: + logger.warning("MQTT connection not established after 5 seconds") + + except Exception as e: + 
logger.error(f"Failed to start MQTT client: {e}", exc_info=True) + + async def stop(self): + """Stop the MQTT client.""" + self._connected = False + + # Cancel all running tasks + for task in self._tasks: + task.cancel() + + # Wait for all tasks to complete + await asyncio.gather(*self._tasks, return_exceptions=True) + + logger.info("MQTT client stopped") + + async def publish_command_and_wait( + self, bot_id: str, command: str, data: Dict[str, Any], timeout: float = 30.0, qos: int = 1 + ) -> Optional[Any]: + """ + Publish a command to a bot and wait for the response. + + :param bot_id: The bot instance ID + :param command: The command to send + :param data: Command data + :param timeout: Timeout in seconds to wait for response + :param qos: Quality of Service level + :return: Response data if received, None if timeout or error + """ + if not self._connected or not self._client: + logger.error("Not connected to MQTT broker") + return None + + # Generate unique reply_to topic + timestamp = int(time.time() * 1000) + reply_to_topic = f"hummingbot-api/response/{timestamp}" + + # Create a future to track the response using the reply_to topic as key + future = asyncio.Future() + self._pending_responses[reply_to_topic] = future + + try: + # Send the command with custom reply_to + success = await self._publish_command_with_reply_to(bot_id, command, data, reply_to_topic, qos) + if not success: + self._pending_responses.pop(reply_to_topic, None) + return None + + # Wait for response with timeout + try: + response = await asyncio.wait_for(future, timeout=timeout) + return response + except asyncio.TimeoutError: + logger.warning(f"⏰ Timeout waiting for response from {bot_id} for command '{command}' on {reply_to_topic}") + self._pending_responses.pop(reply_to_topic, None) + return None + + except Exception as e: + logger.error(f"Error sending command and waiting for response: {e}") + self._pending_responses.pop(reply_to_topic, None) + return None + + async def _publish_command_with_reply_to( + self, bot_id: str, command: str, data: Dict[str, Any], reply_to: str, qos: int = 1 + ) -> bool: + """ + Publish a command to a bot with custom reply_to topic. + + :param bot_id: The bot instance ID + :param command: The command to send + :param data: Command data + :param reply_to: Custom reply_to topic + :param qos: Quality of Service level + :return: True if published successfully + """ + if not self._connected or not self._client: + logger.error("Not connected to MQTT broker") + return False + + # Convert dots to slashes for MQTT topic + mqtt_bot_id = bot_id.replace(".", "/") + + # Use the correct topic for each command + topic = f"hbot/{mqtt_bot_id}/{command}" + + # Create the full RPC message structure with custom reply_to + message = { + "header": { + "timestamp": int(time.time() * 1000), # Milliseconds + "reply_to": reply_to, # Custom reply_to topic + "msg_id": int(time.time() * 1000), + "node_id": "hummingbot-api", + "agent": "hummingbot-api", + "properties": {}, + }, + "data": data or {}, + } + + try: + await self._client.publish(topic, payload=json.dumps(message), qos=qos) + return True + except Exception as e: + logger.error(f"Failed to publish command to {bot_id}: {e}") + return False + + async def publish_command(self, bot_id: str, command: str, data: Dict[str, Any], qos: int = 1) -> bool: + """ + Publish a command to a bot using proper RPCMessage Request format. 
+ + :param bot_id: The bot instance ID + :param command: The command to send + :param data: Command data (should match the specific CommandMessage.Request structure) + :param qos: Quality of Service level + :return: True if published successfully + """ + if not self._connected or not self._client: + logger.error("Not connected to MQTT broker") + return False + + # Convert dots to slashes for MQTT topic + mqtt_bot_id = bot_id.replace(".", "/") + + # Use the correct topic for each command + topic = f"hbot/{mqtt_bot_id}/{command}" + + # Create the full RPC message structure as expected by commlib + # Based on RPCClient._prepare_request method + message = { + "header": { + "timestamp": int(time.time() * 1000), # Milliseconds + "reply_to": f"hummingbot-api-response-{int(time.time() * 1000)}", # Unique response topic + "msg_id": int(time.time() * 1000), + "node_id": "hummingbot-api", + "agent": "hummingbot-api", + "properties": {}, + }, + "data": data or {}, + } + + try: + await self._client.publish(topic, payload=json.dumps(message), qos=qos) + return True + except Exception as e: + logger.error(f"Failed to publish command to {bot_id}: {e}") + return False + + def add_handler(self, topic_pattern: str, handler: Callable): + """ + Add a custom message handler for a topic pattern. + + :param topic_pattern: Topic pattern (supports + wildcard) + :param handler: Callback function(bot_id, channel, data) - can be sync or async + """ + self._handlers[topic_pattern] = handler + + def remove_handler(self, topic_pattern: str): + """Remove a message handler.""" + self._handlers.pop(topic_pattern, None) + + def get_bot_performance(self, bot_id: str) -> Dict[str, Any]: + """Get performance data for a bot.""" + return self._bot_performance.get(bot_id, {}) + + def get_bot_logs(self, bot_id: str) -> list: + """Get recent logs for a bot.""" + return list(self._bot_logs.get(bot_id, [])) + + def get_bot_error_logs(self, bot_id: str) -> list: + """Get recent error logs for a bot.""" + return list(self._bot_error_logs.get(bot_id, [])) + + def clear_bot_data(self, bot_id: str): + """Clear stored data for a bot.""" + self._bot_performance.pop(bot_id, None) + self._bot_logs.pop(bot_id, None) + self._bot_error_logs.pop(bot_id, None) + self._discovered_bots.pop(bot_id, None) + + def clear_bot_performance(self, bot_id: str): + """Clear only performance data for a bot (useful when bot is stopped).""" + self._bot_performance.pop(bot_id, None) + + @property + def is_connected(self) -> bool: + """Check if connected to MQTT broker.""" + return self._connected + + def get_discovered_bots(self, timeout_seconds: int = 300) -> list: + """Get list of auto-discovered bots. 
+ + :param timeout_seconds: Consider bots inactive after this many seconds without messages + :return: List of active bot IDs + """ + current_time = time.time() + active_bots = [ + bot_id for bot_id, last_seen in self._discovered_bots.items() if current_time - last_seen < timeout_seconds + ] + return active_bots + + async def subscribe_to_bot(self, bot_id: str): + """Subscribe to all topics for a specific bot.""" + if self._connected and self._client: + # Convert dots to slashes for MQTT topic + mqtt_bot_id = bot_id.replace(".", "/") + + # Subscribe to all topics for this specific bot + topic = f"hbot/{mqtt_bot_id}/#" + await self._client.subscribe(topic, qos=1) + else: + logger.warning(f"Cannot subscribe to bot {bot_id} - not connected to MQTT") + + +if __name__ == "__main__": + # Example usage + import sys + + # For Windows compatibility + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + logging.basicConfig(level=logging.INFO) + + async def main(): + mqtt_manager = MQTTManager(host="localhost", port=1883, username="", password="") + + await mqtt_manager.start() + + try: + # Keep running to listen for messages + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + await mqtt_manager.stop() + + asyncio.run(main())
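Note: publish_command_and_wait implements a small request/response (RPC) pattern over MQTT: each request embeds a unique reply_to topic under hummingbot-api/response/, the manager subscribes to hummingbot-api/response/+ up front, and _handle_rpc_response resolves the pending future when the bot answers. A minimal round-trip sketch, assuming a broker on localhost and a running bot with the hypothetical ID "hummingbot-1" that answers the "history" command:

    import asyncio

    from utils.mqtt_manager import MQTTManager

    async def main():
        mqtt = MQTTManager(host="localhost", port=1883, username="", password="")
        await mqtt.start()
        # Returns the decoded response payload, or None after the default 30 s timeout
        response = await mqtt.publish_command_and_wait(bot_id="hummingbot-1", command="history", data={})
        print(response)
        await mqtt.stop()

    asyncio.run(main())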
diff --git a/utils/security.py b/utils/security.py index 22c89aa4..c64bd823 100644 --- a/utils/security.py +++ b/utils/security.py @@ -10,14 +10,12 @@ ) from hummingbot.client.config.security import Security -from config import PASSWORD_VERIFICATION_PATH -from utils.file_system import FileSystemUtil -from utils.models import BackendAPIConfigAdapter +from config import settings +from utils.hummingbot_api_config_adapter import HummingbotAPIConfigAdapter +from utils.file_system import fs_util class BackendAPISecurity(Security): - fs_util = FileSystemUtil(base_path="bots/credentials") - @classmethod def login_account(cls, account_name: str, secrets_manager: BaseSecretsManager) -> bool: if not cls.validate_password(secrets_manager): @@ -30,10 +28,10 @@ def login_account(cls, account_name: str, secrets_manager: BaseSecretsManager) - def decrypt_all(cls, account_name: str = "master_account"): cls._secure_configs.clear() cls._decryption_done.clear() - encrypted_files = [file for file in cls.fs_util.list_files(directory=f"{account_name}/connectors") if + encrypted_files = [file for file in fs_util.list_files(directory=f"credentials/{account_name}/connectors") if file.endswith(".yml")] for file in encrypted_files: - path = Path(cls.fs_util.base_path + f"/{account_name}/connectors/" + file) + path = Path(fs_util.base_path + f"/credentials/{account_name}/connectors/" + file) cls.decrypt_connector_config(path) cls._decryption_done.set() @@ -43,36 +41,33 @@ def decrypt_connector_config(cls, file_path: Path): cls._secure_configs[connector_name] = cls.load_connector_config_map_from_file(file_path) @classmethod - def load_connector_config_map_from_file(cls, yml_path: Path) -> BackendAPIConfigAdapter: + def load_connector_config_map_from_file(cls, yml_path: Path) -> HummingbotAPIConfigAdapter: config_data = read_yml_file(yml_path) connector_name = connector_name_from_file(yml_path) hb_config = get_connector_hb_config(connector_name).model_validate(config_data) - config_map = BackendAPIConfigAdapter(hb_config) + config_map = HummingbotAPIConfigAdapter(hb_config) config_map.decrypt_all_secure_data() return config_map @classmethod def update_connector_keys(cls, account_name: str, connector_config: ClientConfigAdapter): connector_name = connector_config.connector - file_path = cls.fs_util.get_connector_keys_path(account_name=account_name, connector_name=connector_name) + file_path = fs_util.get_connector_keys_path(account_name=account_name, connector_name=connector_name) cm_yml_str = connector_config.generate_yml_output_str_with_comments() - cls.fs_util.ensure_file_and_dump_text(file_path, cm_yml_str) + fs_util.ensure_file_and_dump_text(str(file_path), cm_yml_str) update_connector_hb_config(connector_config) cls._secure_configs[connector_name] = connector_config @staticmethod def new_password_required() -> bool: - return not PASSWORD_VERIFICATION_PATH.exists() - - @staticmethod - def store_password_verification(secrets_manager: BaseSecretsManager): - encrypted_word = secrets_manager.encrypt_secret_value(PASSWORD_VERIFICATION_WORD, PASSWORD_VERIFICATION_WORD) - FileSystemUtil.ensure_file_and_dump_text(PASSWORD_VERIFICATION_PATH, encrypted_word) + full_path = fs_util._get_full_path(settings.app.password_verification_path) + return not Path(full_path).exists() @staticmethod def validate_password(secrets_manager: BaseSecretsManager) -> bool: valid = False - with open(PASSWORD_VERIFICATION_PATH, "r") as f: + full_path = fs_util._get_full_path(settings.app.password_verification_path) + with open(full_path, "r") as f: encrypted_word = f.read() try: decrypted_word = secrets_manager.decrypt_secret_value(PASSWORD_VERIFICATION_WORD, encrypted_word) @@ -81,3 +76,8 @@ def validate_password(secrets_manager: BaseSecretsManager) -> bool: if str(e) != "MAC mismatch": raise e return valid + + @staticmethod + def store_password_verification(secrets_manager: BaseSecretsManager): + encrypted_word = secrets_manager.encrypt_secret_value(PASSWORD_VERIFICATION_WORD, PASSWORD_VERIFICATION_WORD) + fs_util.ensure_file_and_dump_text(settings.app.password_verification_path, encrypted_word)
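Note: the password flow above follows Hummingbot's verification-word scheme: store_password_verification encrypts a known word with the user's password, and validate_password later decrypts it to check the password without ever storing the password itself. A minimal sketch of the round trip, assuming hummingbot's ETHKeyFileSecretManger (upstream spelling) as the BaseSecretsManager implementation and a hypothetical password:

    from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger

    from utils.security import BackendAPISecurity

    secrets_manager = ETHKeyFileSecretManger(password="example-password")  # hypothetical password
    if BackendAPISecurity.new_password_required():
        # First run: persist the encrypted verification word
        BackendAPISecurity.store_password_verification(secrets_manager)
    # Later runs: a wrong password fails the MAC check and returns False
    assert BackendAPISecurity.validate_password(secrets_manager)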