diff --git a/.env.template b/.env.template index 3e57dd5d..0da5c78e 100644 --- a/.env.template +++ b/.env.template @@ -51,6 +51,10 @@ TLS_SSL_VERIFICATION=0 # if need to disable tls/ssl verification APP_PROBE_STATUSES=Running,Deleted # PROBE_PF=38123 # specify a port for port-forwaring, or for more apps: PROBE_PF=sp-status:38121,shiny-probe-test:38123 +# Rate limiting configuration for the auth endpoint +AUTH_RATE_LIMIT_VALUE= +AUTH_RATE_LIMIT_WHITELIST= + # Invenio API INVENIO_URL= INVENIO_API_TOKEN= diff --git a/studio/settings.py b/studio/settings.py index b94c0729..923c4154 100644 --- a/studio/settings.py +++ b/studio/settings.py @@ -352,6 +352,11 @@ "DEFAULT_PARSER_CLASSES": ("rest_framework.parsers.JSONParser",), } +# Rate limit whitelist for certain IP ranges on the auth endpoint +AUTH_RATE_LIMIT_VALUE = os.environ.get("AUTH_RATE_LIMIT_VALUE", None) +AUTH_RATE_LIMIT_WHITELIST_RAW = os.environ.get("AUTH_RATE_LIMIT_WHITELIST", "") +AUTH_RATE_LIMIT_WHITELIST = [ip.strip() for ip in AUTH_RATE_LIMIT_WHITELIST_RAW.split(",")] + # Tagulous serialization settings SERIALIZATION_MODULES = { "xml": "tagulous.serializers.xml_serializer", diff --git a/studio/throttle.py b/studio/throttle.py new file mode 100644 index 00000000..2ba8bc28 --- /dev/null +++ b/studio/throttle.py @@ -0,0 +1,54 @@ +from ipaddress import ip_address, ip_network +from typing import Any + +from django.conf import settings +from django.http import HttpRequest +from rest_framework.throttling import UserRateThrottle +from rest_framework.views import APIView + + +class WhitelistThrottleFilter(UserRateThrottle): + """ + Custom throttle filter that whitelists certain IP ranges + """ + + rate = getattr(settings, "AUTH_RATE_LIMIT_VALUE", None) + + def get_ident(self, request: HttpRequest) -> Any: + """ + Extract the real client IP from proxy headers + """ + + # Try X-Forwarded-For first (standard proxy header) + xff = request.META.get("HTTP_X_FORWARDED_FOR") + if xff: + ip = xff.split(",")[0].strip() + return ip + + # Try X-Real-IP (nginx specific) + real_ip = request.META.get("HTTP_X_REAL_IP") + if real_ip: + return real_ip + + # Fallback to Django's standard remote address + fallback = request.META.get("REMOTE_ADDR", "unknown") + return fallback + + def allow_request(self, request: HttpRequest, view: APIView) -> Any: + # If no rate is configured, throttling is disabled entirely + if not self.rate: + return True + + whitelist_range = getattr(settings, "AUTH_RATE_LIMIT_WHITELIST", None) + + # If whitelist is configured, check if IP is whitelisted + if whitelist_range: + incoming_ip = self.get_ident(request) + for network in whitelist_range if isinstance(whitelist_range, list) else [whitelist_range]: + try: + if ip_address(incoming_ip) in ip_network(network): + return True # Whitelisted, allow through + except ValueError: + continue # Skip invalid network/IP formats + + return super().allow_request(request, view) diff --git a/studio/views.py b/studio/views.py index a9562266..fd346319 100644 --- a/studio/views.py +++ b/studio/views.py @@ -23,6 +23,7 @@ from common.tasks import send_email_task from models.models import Model from projects.models import Project +from studio.throttle import WhitelistThrottleFilter from studio.utils import get_logger from .helpers import do_delete_account @@ -112,6 +113,7 @@ class AuthView(APIView): authentication_classes = [ModifiedSessionAuthentication, TokenAuthentication] permission_classes = [IsAuthenticated, AccessPermission] content_negotiation_class = IgnoreClientContentNegotiation + throttle_classes = [WhitelistThrottleFilter] def get(self, request: Response, format: str | None = None) -> Response: content = {