diff --git a/api/app/limiter_v2/README.md b/api/app/limiter_v2/README.md new file mode 100644 index 000000000..ba037dd67 --- /dev/null +++ b/api/app/limiter_v2/README.md @@ -0,0 +1,64 @@ +# Limiter V2 + +## Usage + +note: + +**limiter_v2.limit decorator must be used before cache.cached decorator** + + +```python +from limiter_v2 import limiter_v2, require_api_key, get_limits + + +# require_api_key, user default limits +@explorer_namespace.route("/v1/some_resource") +class SomeResource(Resource): + @require_api_key + @cache.cached(timeout=300, query_string=True) + def get(self): + return {"message": "Hello, world!"}, 200 + + +# require_api_key, user custom limits +@explorer_namespace.route("/v1/some_resource") +class SomeResource(Resource): + @require_api_key + @limiter_v2.limit(get_limits) + @cache.cached(timeout=300, query_string=True) + def get(self): + return {"message": "Hello, world!"}, 200 + + +# require_api_key, user custom limits with cost +@explorer_namespace.route("/v1/some_resource") +class SomeResource(Resource): + @require_api_key + @limiter_v2.limit(get_limits, cost=2) + @cache.cached(timeout=300, query_string=True) + def get(self): + return {"message": "Hello, world!"}, 200 + +# require_api_key, user custom limits with cost, if user has no limits, use default limits +@explorer_namespace.route("/v1/some_resource") +class SomeResource(Resource): + @require_api_key + @limiter_v2.limit(get_limits, cost=2, override_defaults=False) + @cache.cached(timeout=300, query_string=True) + def get(self): + return {"message": "Hello, world!"}, 200 +``` + +## add new limits + +generate new api key and insert into db + +limits format: +``` +1/second, 100/hour +1000/day +30 per minute +``` +multiple limits are supported, use comma to separate + + diff --git a/api/app/limiter_v2/limiter.py b/api/app/limiter_v2/limiter.py new file mode 100644 index 000000000..7a2b118a4 --- /dev/null +++ b/api/app/limiter_v2/limiter.py @@ -0,0 +1,65 @@ +from datetime import datetime, timezone +from functools import wraps + +from flask import jsonify, make_response, request +from flask_limiter import Limiter + +from api.app import cache +from common.models import db +from common.models.limiter import ApiKey + + +def get_header_api_key(): + return request.headers.get("X-API-KEY", "") + + +limiter_v2 = Limiter( + key_func=get_header_api_key, + default_limits=["100 per hour"], + storage_uri="memory://", +) + + +def require_api_key(f): + @wraps(f) + def decorated(*args, **kwargs): + api_key = get_header_api_key() + if not get_api_key(api_key): + return make_response(jsonify({"error": "Invalid API key"}), 403) + return f(*args, **kwargs) + + return decorated + + +def get_api_key(api_key): + cache_key = f"ak_{api_key}" + api_key_from_cache = cache.cache.get(cache_key) + if api_key_from_cache: + # if id is -1, api key not found in db + if api_key_from_cache.id == -1: + return None + return api_key_from_cache + + api_key_from_db = ( + db.session.query(ApiKey) + .filter(ApiKey.api_key == api_key, ApiKey.expires_at > datetime.now(timezone.utc)) + .first() + ) + + if api_key_from_db: + cache.cache.set(cache_key, api_key_from_db, 600) + return api_key_from_db + + # if api key not found in db, set it in cache to avoid future db hits + cache.cache.set(cache_key, ApiKey(id=-1), 300) + return None + + +def get_limits(): + api_key = get_header_api_key() + api_key_model = get_api_key(api_key) + + if api_key_model: + return api_key_model.limits + + return [] diff --git a/api/app/main.py b/api/app/main.py index f5f902887..80d8bc75e 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -7,6 +7,7 @@ from api.app.cache import cache, redis_db from api.app.limiter import limiter +from api.app.limiter_v2.limiter import limiter_v2 from common.models import db from common.utils.config import get_config from common.utils.exception_control import APIError @@ -59,6 +60,7 @@ # Rate limit limiter.init_app(app) +limiter_v2.init_app(app) # ma.init_app(app) CORS(app) diff --git a/common/models/limiter.py b/common/models/limiter.py new file mode 100644 index 000000000..06f3d5fe4 --- /dev/null +++ b/common/models/limiter.py @@ -0,0 +1,16 @@ +from sqlalchemy import TEXT, Column, func +from sqlalchemy.dialects.postgresql import BIGINT, JSONB, TIMESTAMP, VARCHAR + +from common.models import HemeraModel + + +class ApiKey(HemeraModel): + __tablename__ = "api_key" + + id = Column(BIGINT, primary_key=True, autoincrement=True) + api_key = Column(VARCHAR(255), unique=True) + limits = Column(TEXT) + expires_at = Column(TIMESTAMP) + description = Column(VARCHAR(255)) + created_at = Column(TIMESTAMP, default=func.now()) + updated_at = Column(TIMESTAMP, default=func.now(), onupdate=func.now())