diff --git a/.dockerignore b/.dockerignore index a4edae32..90968381 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,3 +5,13 @@ static/*.fits.gz venv/ .git/ viz-example/ +tests/ +__pycache__/ +*.pyc +*.pyo +*.pyd +.pytest_cache/ +.coverage +.idea/ +.vscode/ +*.egg-info/ diff --git a/.github/workflows/fastapi-tests.yml b/.github/workflows/fastapi-tests.yml new file mode 100644 index 00000000..effdfcf1 --- /dev/null +++ b/.github/workflows/fastapi-tests.yml @@ -0,0 +1,103 @@ +name: FastAPI Tests + +on: + push: + branches: [ master, fastapi ] + pull_request: + branches: [ master, fastapi ] + +jobs: + test: + runs-on: ubuntu-latest + timeout-minutes: 30 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python for test dependencies + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install test dependencies + run: | + python -m pip install --upgrade pip + pip install -r tests/requirements.txt + + - name: Set up environment variables + run: | + echo "DB_USER=treasuremap" >> $GITHUB_ENV + echo "DB_PWD=treasuremap" >> $GITHUB_ENV + echo "DB_NAME=treasuremap" >> $GITHUB_ENV + echo "MAIL_PASSWORD=dummy" >> $GITHUB_ENV + echo "RECAPTCHA_PUBLIC_KEY=dummy" >> $GITHUB_ENV + echo "RECAPTCHA_PRIVATE_KEY=dummy" >> $GITHUB_ENV + echo "ZENODO_ACCESS_KEY=dummy" >> $GITHUB_ENV + echo "AWS_ACCESS_KEY_ID=dummy" >> $GITHUB_ENV + echo "AWS_SECRET_ACCESS_KEY=dummy" >> $GITHUB_ENV + + - name: Build FastAPI Docker image + run: | + docker build -f server/Dockerfile -t gwtm_fastapi:latest . + + - name: Start database with Docker Compose + run: | + docker compose up -d db + + - name: Wait for database to be ready + run: | + timeout 120 bash -c 'until docker compose exec -T db pg_isready -U treasuremap -d treasuremap; do sleep 5; done' + + - name: Enable PostGIS extension + run: | + docker compose exec -T db psql -U treasuremap -d treasuremap -c "CREATE EXTENSION IF NOT EXISTS postgis;" + + - name: Initialize database schema with FastAPI models + run: | + docker run --rm --network gwtm_default \ + -e DB_USER=treasuremap \ + -e DB_PWD=treasuremap \ + -e DB_NAME=treasuremap \ + -e DB_HOST=gwtm_db \ + -e DB_PORT=5432 \ + gwtm_fastapi:latest \ + python -c "from server.db.init_db import create_database_tables; create_database_tables()" + + - name: Load test data + run: | + docker compose exec -T db psql -U treasuremap -d treasuremap < tests/test-data.sql + + - name: Start FastAPI server + run: | + docker run -d --name fastapi-server \ + --network gwtm_default \ + -e DB_USER=treasuremap \ + -e DB_PWD=treasuremap \ + -e DB_NAME=treasuremap \ + -e DB_HOST=gwtm_db \ + -e DB_PORT=5432 \ + -p 8000:8000 \ + gwtm_fastapi:latest + + - name: Wait for FastAPI server to be ready + run: | + timeout 120 bash -c 'until curl -f http://localhost:8000/health; do echo "Waiting for FastAPI..."; sleep 5; done' + + - name: Run FastAPI tests + run: | + python -m pytest tests/fastapi/ -v --disable-warnings + env: + API_BASE_URL: http://localhost:8000 + DB_HOST: localhost + DB_PORT: 5432 + DB_NAME: treasuremap + DB_USER: treasuremap + DB_PWD: treasuremap + + - name: Cleanup + if: always() + run: | + docker stop fastapi-server || true + docker rm fastapi-server || true + docker compose down -v \ No newline at end of file diff --git a/.gitignore b/.gitignore index dec22deb..35f5ed6a 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,5 @@ environment_variables.sh envars.sh test *.DS_Store -deploy/ \ No newline at end of file +deploy/ +*venv/ diff --git a/README.md b/README.md index 
f0647e12..ed7fe595 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,23 @@ # GW Treasure Map Website environment +## Quick Start + +**For the modern FastAPI backend (recommended):** +The FastAPI application requires database and cache services. Use Skaffold for the complete development environment: +```bash +cd gwtm-helm +skaffold dev # Starts full stack including FastAPI, database, and cache +``` +FastAPI will be available at http://localhost:8000 with API docs at http://localhost:8000/docs + +See the [FastAPI README](server/README.md) for detailed setup instructions and testing. + +**For the legacy Flask application:** +```bash +python gwtm.wsgi # Development server on :5000 +``` + ### Step-by-step installation ### Python: diff --git a/gwtm-helm/fastapi-README.md b/gwtm-helm/fastapi-README.md new file mode 100644 index 00000000..2b03d3ac --- /dev/null +++ b/gwtm-helm/fastapi-README.md @@ -0,0 +1,81 @@ +{{- /* This file is for documentation only and should not be processed as a template */ -}} +# FastAPI Helm Templates + +These templates are used to deploy the FastAPI backend service of the GWTM application. + +## Templates + +- `deployment.yaml`: Defines the Kubernetes Deployment for the FastAPI service. +- `service.yaml`: Defines the Kubernetes Service for the FastAPI service. +- `configmap.yaml`: Contains configuration data for the FastAPI service. + +## Configuration + +Configuration for the FastAPI service is defined in the `values.yaml` file under the `fastapi` key: + +```yaml +fastapi: + name: fastapi-backend + replicas: 2 + image: + repository: gwtm-fastapi + tag: latest + pullPolicy: IfNotPresent + service: + port: 8000 + targetPort: 8000 + readinessProbe: + enabled: true + path: /docs + initialDelaySeconds: 10 + periodSeconds: 5 + livenessProbe: + enabled: true + path: /docs + initialDelaySeconds: 30 + periodSeconds: 15 + resources: + limits: + cpu: 500m + memory: 512Mi + requests: + cpu: 200m + memory: 256Mi +``` + +## Ingress Configuration + +The FastAPI service is exposed through the following routes in the Ingress: + +- `/api/v1/*`: API endpoints +- `/docs`: Swagger UI documentation +- `/redoc`: ReDoc documentation +- `/openapi.json`: OpenAPI schema +- `/health`: Health check endpoint + +## Environment Variables + +The FastAPI service uses environment variables from the `secrets.yaml` template, which includes: + +- Database credentials +- Mail configuration +- AWS/Azure credentials +- Other application-specific settings + +## Usage + +To deploy the FastAPI service, include these templates in your Helm installation: + +```bash +helm install gwtm ./gwtm-helm +``` + +You can customize the deployment by overriding values: + +```bash +helm install gwtm ./gwtm-helm --set fastapi.replicas=3 --set fastapi.image.tag=v1.0.0 +``` + +## Health Checks + +The FastAPI service includes readiness and liveness probes that check the `/docs` endpoint to verify that the service is running correctly. \ No newline at end of file diff --git a/gwtm-helm/restore-db b/gwtm-helm/restore-db index a75104d8..7f3366ee 100755 --- a/gwtm-helm/restore-db +++ b/gwtm-helm/restore-db @@ -20,6 +20,7 @@ kubectl cp $DUMP_FILE $POSTGRES_POD:/tmp/dump.sql -n $NAMESPACE # Execute restore echo "Restoring database..." + kubectl -n $NAMESPACE exec -it $POSTGRES_POD -- bash -c "PGPASSWORD=$DB_PASSWORD psql -U $DB_USER -f /tmp/dump.sql -a" echo "Restore completed!" 
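As a usage sketch for the `gwtm-helm/restore-db` script touched above: the FastAPI testing docs later in this change invoke it with a SQL dump path, roughly as below. The pod, namespace, and credential variables are assumed to be resolved inside the script itself.

```bash
# Load the test fixtures into the in-cluster Postgres via the helper script;
# the dump-file argument mirrors the invocation documented in server/README.md.
./gwtm-helm/restore-db tests/test-data.sql
```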
diff --git a/gwtm-helm/skaffold.yaml b/gwtm-helm/skaffold.yaml index e3e3c2c6..caaeb57a 100644 --- a/gwtm-helm/skaffold.yaml +++ b/gwtm-helm/skaffold.yaml @@ -1,59 +1,102 @@ -apiVersion: skaffold/v2beta28 +apiVersion: skaffold/v4beta13 kind: Config metadata: name: gwtm build: + artifacts: + - image: gwtm + context: .. + docker: {} + - image: gwtm-fastapi + context: .. + docker: + dockerfile: server/Dockerfile + sync: + manual: + - src: "server/routes/**/*.py" + dest: "/app/server/routes/" + - src: "server/db/**/*.py" + dest: "/app/server/db/" + - src: "server/schemas/**/*.py" + dest: "/app/server/schemas/" + - src: "server/services/**/*.py" + dest: "/app/server/services/" + - src: "server/utils/**/*.py" + dest: "/app/server/utils/" + - src: "server/auth/**/*.py" + dest: "/app/server/auth/" + - src: "server/core/**/*.py" + dest: "/app/server/core/" + - src: "server/main.py" + dest: "/app/server/" + - src: "server/config.py" + dest: "/app/server/" local: push: false - artifacts: - - image: gwtm - context: .. - docker: - dockerfile: Dockerfile +manifests: + helm: + releases: + - name: gwtm + chartPath: . + valuesFiles: + - values-dev.yaml + setValues: + backend.image.repository: gwtm + backend.image.tag: latest + backend.livenessProbe.enabled: "true" + backend.readinessProbe.enabled: "true" + cache.livenessProbe.enabled: "true" + cache.persistence.enabled: "false" + cache.readinessProbe.enabled: "true" + database.initScripts.enabled: "true" + database.livenessProbe.enabled: "true" + database.persistence.enabled: "false" + database.readinessProbe.enabled: "true" + global.createNamespace: "true" + global.namespace: gwtm + createNamespace: true + wait: true + upgradeOnChange: true deploy: helm: releases: - - name: gwtm - chartPath: . - createNamespace: true - valuesFiles: - - values-dev.yaml - setValues: - global.namespace: gwtm - global.createNamespace: true - - # Database values - database.initScripts.enabled: true - database.livenessProbe.enabled: true - database.readinessProbe.enabled: true - database.persistence.enabled: false - - # Cache values - cache.livenessProbe.enabled: true - cache.readinessProbe.enabled: true - cache.persistence.enabled: false - - # Backend values - backend.readinessProbe.enabled: true - backend.livenessProbe.enabled: true - - # Image settings - backend.image.repository: gwtm - backend.image.tag: latest - wait: true - upgradeOnChange: true + - name: gwtm + chartPath: . 
+ valuesFiles: + - values-dev.yaml + setValues: + backend.image.repository: gwtm + backend.image.tag: latest + backend.livenessProbe.enabled: "true" + backend.readinessProbe.enabled: "true" + cache.livenessProbe.enabled: "true" + cache.persistence.enabled: "false" + cache.readinessProbe.enabled: "true" + database.initScripts.enabled: "true" + database.livenessProbe.enabled: "true" + database.persistence.enabled: "false" + database.readinessProbe.enabled: "true" + global.createNamespace: "true" + global.namespace: gwtm + createNamespace: true + wait: true + upgradeOnChange: true portForward: -# Direct backend access on local port 8080 -- resourceType: service - resourceName: flask-backend - namespace: gwtm - port: 8080 - localPort: 8080 - address: 0.0.0.0 -# Frontend on local port 8081 -- resourceType: service - resourceName: frontend - namespace: gwtm - port: 80 - localPort: 8081 - address: 0.0.0.0 + - resourceType: service + resourceName: flask-backend + namespace: gwtm + port: 8080 + address: 0.0.0.0 + localPort: 8080 + - resourceType: service + resourceName: frontend + namespace: gwtm + port: 80 + address: 0.0.0.0 + localPort: 8081 + - resourceType: service + resourceName: fastapi-backend + namespace: gwtm + port: 8000 + address: 0.0.0.0 + localPort: 8000 diff --git a/gwtm-helm/start-dev.sh b/gwtm-helm/start-dev.sh new file mode 100755 index 00000000..bd73c689 --- /dev/null +++ b/gwtm-helm/start-dev.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# Remove namespace if it exists +kubectl delete namespace gwtm + +# Create namespace if it doesn't exist +#kubectl create namespace gwtm 2>/dev/null || true + +# Clean up any existing deployments +kubectl delete deployment -n gwtm --all 2>/dev/null || true + +# Run skaffold with simple options +echo "Starting skaffold development environment..." +skaffold dev diff --git a/gwtm-helm/templates/backend/create-database-tables.yaml b/gwtm-helm/templates/backend/create-database-tables.yaml new file mode 100644 index 00000000..e8be8a9b --- /dev/null +++ b/gwtm-helm/templates/backend/create-database-tables.yaml @@ -0,0 +1,72 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: database-setup + namespace: {{ .Values.global.namespace }} + labels: + {{- include "gwtm.labels" . 
| nindent 4 }} +spec: + template: + spec: + initContainers: + - name: wait-for-db + image: postgres:14-alpine + command: ['sh', '-c', 'echo "Waiting for database to be ready..."; + until pg_isready -h ${DB_HOST} -p ${DB_PORT} -U ${DB_USER} -d ${DB_NAME}; do + echo "$(date) - Waiting for database at ${DB_HOST}:${DB_PORT}..."; + sleep 5; + done; + echo "Database is ready!"'] + env: + - name: DB_USER + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-user + - name: DB_PASSWORD + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-password + - name: DB_NAME + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-name + - name: DB_HOST + value: {{ .Values.database.name }} + - name: DB_PORT + value: "{{ .Values.database.service.port }}" + - name: PGPASSWORD + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-password + containers: + - name: db-setup + image: "{{ .Values.backend.image.repository }}:{{ .Values.backend.image.tag }}" + command: ["python", "-c"] + args: + - | + from src.models import create_database_tables; create_database_tables() + env: + - name: DB_USER + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-user + - name: DB_PASSWORD + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-password + - name: DB_NAME + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-name + - name: DB_HOST + value: {{ .Values.database.name }} + - name: DB_PORT + value: "{{ .Values.database.service.port }}" + restartPolicy: OnFailure diff --git a/gwtm-helm/templates/fastapi/configmap.yaml b/gwtm-helm/templates/fastapi/configmap.yaml new file mode 100644 index 00000000..c8804cc3 --- /dev/null +++ b/gwtm-helm/templates/fastapi/configmap.yaml @@ -0,0 +1,15 @@ +# templates/fastapi/configmap.yaml +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ .Values.fastapi.name }}-config + namespace: {{ .Values.global.namespace }} + labels: + {{- include "gwtm.labels" . | nindent 4 }} + app: {{ .Values.fastapi.name }} +data: + APP_NAME: "GWTM FastAPI" + DEBUG: "{{ .Values.global.environment | eq "development" | ternary "True" "False" }}" + CORS_ORIGINS: '["https://{{ .Values.ingress.host }}", "http://{{ .Values.ingress.host }}", "http://localhost:3000", "http://localhost:5173"]' + BASE_URL: "https://{{ .Values.ingress.host }}" + # Add any additional configuration for FastAPI here \ No newline at end of file diff --git a/gwtm-helm/templates/fastapi/deployment.yaml b/gwtm-helm/templates/fastapi/deployment.yaml new file mode 100644 index 00000000..fe399174 --- /dev/null +++ b/gwtm-helm/templates/fastapi/deployment.yaml @@ -0,0 +1,203 @@ +# templates/fastapi/deployment.yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ .Values.fastapi.name }} + namespace: {{ .Values.global.namespace }} + labels: + {{- include "gwtm.labels" . 
| nindent 4 }} + app: {{ .Values.fastapi.name }} +spec: + replicas: {{ .Values.fastapi.replicas }} + selector: + matchLabels: + app: {{ .Values.fastapi.name }} + template: + metadata: + labels: + app: {{ .Values.fastapi.name }} + spec: + initContainers: + - name: wait-for-db + image: postgres:14-alpine + command: ['sh', '-c', 'echo "Testing connection to ${DB_HOST}:${DB_PORT}..."; + export PGPASSWORD=${DB_PWD}; + until pg_isready -h ${DB_HOST} -p ${DB_PORT} -U ${DB_USER} -d ${DB_NAME}; do + echo "$(date) - Waiting for database at ${DB_HOST}:${DB_PORT}..."; + echo "Trying to ping database..."; + nc -v -z -w2 ${DB_HOST} ${DB_PORT} || echo "Network connection failed"; + sleep 5; + done; + echo "Database is ready!"'] + env: + - name: DB_USER + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-user + - name: DB_PWD + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-password + - name: DB_NAME + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-name + - name: DB_HOST + value: {{ .Values.database.name }} + - name: DB_PORT + value: "{{ .Values.database.service.port }}" + - name: PGPASSWORD + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-password + containers: + - name: fastapi-app + image: "{{ .Values.fastapi.image.repository }}:{{ .Values.fastapi.image.tag }}" + imagePullPolicy: {{ .Values.fastapi.image.pullPolicy }} + command: ["/bin/bash", "-c"] + args: + - | + echo "Starting FastAPI application..." 
+ echo "Database settings: DB_HOST=$DB_HOST, DB_PORT=$DB_PORT, DB_USER=$DB_USER, DB_NAME=$DB_NAME" + echo "Redis URL: $REDIS_URL" + cd /app + uvicorn server.main:app --host 0.0.0.0 --port {{ .Values.fastapi.service.targetPort }} --workers {{ .Values.fastapi.workers }} + ports: + - containerPort: {{ .Values.fastapi.service.targetPort }} + env: + - name: DB_USER + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-user + - name: DB_PWD + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-password + - name: DB_NAME + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: db-name + - name: DB_HOST + value: {{ .Values.database.name }} + - name: DB_PORT + value: "{{ .Values.database.service.port }}" + - name: REDIS_URL + value: "redis://{{ .Values.cache.name }}:{{ .Values.cache.service.port }}/0" + - name: MAIL_PASSWORD + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: mail-password + - name: MAIL_USERNAME + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: MAIL_USERNAME + - name: MAIL_DEFAULT_SENDER + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: MAIL_DEFAULT_SENDER + - name: MAIL_SERVER + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: MAIL_SERVER + - name: MAIL_PORT + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: MAIL_PORT + - name: RECAPTCHA_PUBLIC_KEY + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: recaptcha-public-key + - name: RECAPTCHA_PRIVATE_KEY + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: recaptcha-private-key + - name: ZENODO_ACCESS_KEY + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: zenodo-access-key + - name: AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: aws-access-key-id + - name: AWS_SECRET_ACCESS_KEY + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: aws-secret-access-key + - name: AWS_DEFAULT_REGION + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: AWS_DEFAULT_REGION + - name: AWS_BUCKET + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: AWS_BUCKET + - name: AZURE_ACCOUNT_NAME + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: AZURE_ACCOUNT_NAME + - 
name: AZURE_ACCOUNT_KEY + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: AZURE_ACCOUNT_KEY + - name: STORAGE_BUCKET_SOURCE + valueFrom: + secretKeyRef: + name: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + key: STORAGE_BUCKET_SOURCE + - name: DEBUG + value: "{{ .Values.global.environment | eq "development" | ternary "True" "False" }}" + {{- if .Values.fastapi.extraEnv }} + {{- toYaml .Values.fastapi.extraEnv | nindent 8 }} + {{- end }} + {{- if .Values.fastapi.readinessProbe.enabled }} + readinessProbe: + httpGet: + path: {{ .Values.fastapi.readinessProbe.path }} + port: {{ .Values.fastapi.service.targetPort }} + initialDelaySeconds: {{ .Values.fastapi.readinessProbe.initialDelaySeconds }} + periodSeconds: {{ .Values.fastapi.readinessProbe.periodSeconds }} + timeoutSeconds: 5 + successThreshold: 1 + failureThreshold: 3 + {{- end }} + {{- if .Values.fastapi.livenessProbe.enabled }} + livenessProbe: + httpGet: + path: {{ .Values.fastapi.livenessProbe.path }} + port: {{ .Values.fastapi.service.targetPort }} + initialDelaySeconds: {{ .Values.fastapi.livenessProbe.initialDelaySeconds }} + periodSeconds: {{ .Values.fastapi.livenessProbe.periodSeconds }} + timeoutSeconds: 5 + successThreshold: 1 + failureThreshold: 3 + {{- end }} + resources: + {{- toYaml .Values.fastapi.resources | nindent 12 }} + volumes: + - name: db-user + secret: + secretName: {{ if .Values.global.useGeneratedSecrets }}{{ .Release.Name }}-secrets{{ else }}gwtm-secrets{{ end }} + items: + - key: db-user + path: db-user \ No newline at end of file diff --git a/gwtm-helm/templates/fastapi/service.yaml b/gwtm-helm/templates/fastapi/service.yaml new file mode 100644 index 00000000..10047134 --- /dev/null +++ b/gwtm-helm/templates/fastapi/service.yaml @@ -0,0 +1,15 @@ +# templates/fastapi/service.yaml +apiVersion: v1 +kind: Service +metadata: + name: {{ .Values.fastapi.name }} + namespace: {{ .Values.global.namespace }} + labels: + {{- include "gwtm.labels" . 
| nindent 4 }} + app: {{ .Values.fastapi.name }} +spec: + selector: + app: {{ .Values.fastapi.name }} + ports: + - port: {{ .Values.fastapi.service.port }} + targetPort: {{ .Values.fastapi.service.targetPort }} \ No newline at end of file diff --git a/gwtm-helm/templates/ingress.yaml b/gwtm-helm/templates/ingress.yaml index 02fea006..4b5ad10b 100644 --- a/gwtm-helm/templates/ingress.yaml +++ b/gwtm-helm/templates/ingress.yaml @@ -27,6 +27,41 @@ spec: - host: {{ .Values.ingress.host | quote }} http: paths: + - path: /api/v1 + pathType: Prefix + backend: + service: + name: {{ .Values.fastapi.name }} + port: + number: {{ .Values.fastapi.service.port }} + - path: /docs + pathType: Prefix + backend: + service: + name: {{ .Values.fastapi.name }} + port: + number: {{ .Values.fastapi.service.port }} + - path: /redoc + pathType: Prefix + backend: + service: + name: {{ .Values.fastapi.name }} + port: + number: {{ .Values.fastapi.service.port }} + - path: /openapi.json + pathType: Exact + backend: + service: + name: {{ .Values.fastapi.name }} + port: + number: {{ .Values.fastapi.service.port }} + - path: /health + pathType: Exact + backend: + service: + name: {{ .Values.fastapi.name }} + port: + number: {{ .Values.fastapi.service.port }} - path: / pathType: Prefix backend: diff --git a/gwtm-helm/values-dev.yaml b/gwtm-helm/values-dev.yaml index 69b1ec55..cd722bf5 100644 --- a/gwtm-helm/values-dev.yaml +++ b/gwtm-helm/values-dev.yaml @@ -52,9 +52,25 @@ backend: readinessProbe: enabled: false +fastapi: + replicas: 1 + resources: + limits: + cpu: 500m + memory: 500Mi + requests: + cpu: 200m + memory: 300Mi + readinessProbe: + enabled: false + workers: 1 + frontend: replicas: 1 database: persistence: + enabled: false size: 1Gi + initScripts: + enabled: true diff --git a/gwtm-helm/values.yaml b/gwtm-helm/values.yaml index 96fb6525..1946412f 100644 --- a/gwtm-helm/values.yaml +++ b/gwtm-helm/values.yaml @@ -32,6 +32,35 @@ backend: cpu: 200m memory: 256Mi +fastapi: + name: fastapi-backend + replicas: 2 + image: + repository: gwtm-fastapi + tag: latest + pullPolicy: IfNotPresent + service: + port: 8000 + targetPort: 8000 + readinessProbe: + enabled: false + path: /docs + initialDelaySeconds: 10 + periodSeconds: 5 + livenessProbe: + enabled: false + path: /docs + initialDelaySeconds: 30 + periodSeconds: 15 + resources: + limits: + cpu: 500m + memory: 512Mi + requests: + cpu: 200m + memory: 256Mi + workers: 4 + frontend: name: frontend replicas: 2 diff --git a/requirements.txt b/requirements.txt index b9a3404a..3c1aae32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ Flask-Login==0.6.3 Flask-Mail==0.10.0 Flask-SQLAlchemy==2.5.1 Flask-WTF +flask-caching GeoAlchemy2>=0.13.0 healpy pyjwt diff --git a/server/Dockerfile b/server/Dockerfile new file mode 100644 index 00000000..5be46c62 --- /dev/null +++ b/server/Dockerfile @@ -0,0 +1,30 @@ +FROM python:3.11-slim + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + postgresql-client \ + libpq-dev \ + gcc \ + python3-dev \ + libgeos-dev \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 +ENV PYTHONPATH="/app:${PYTHONPATH}" + +# Copy requirements and install dependencies +COPY server/requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY server/ /app/server/ + +# Set the working directory to the server directory +WORKDIR /app + +# Command to run the application +CMD 
["uvicorn", "server.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/server/README.md b/server/README.md new file mode 100644 index 00000000..00f3c643 --- /dev/null +++ b/server/README.md @@ -0,0 +1,351 @@ +# GWTM FastAPI Backend + +This directory contains the FastAPI implementation of the GWTM backend API. + +## Overview + +The FastAPI implementation provides a modern, high-performance REST API for the Gravitational-Wave Treasure Map application. It is designed to be a drop-in replacement for the Flask-based API with improved performance, better type checking, automatic documentation, and a more modular structure. + +## Directory Structure + +``` +server/ +├── auth/ # Authentication utilities +│ └── auth.py # JWT token handling and user authentication +├── core/ # Core functionality and shared components +│ └── enums/ # Enumeration types (bandpass, depth_unit, etc.) +├── db/ # Database models and configuration +│ ├── models/ # SQLAlchemy ORM models +│ ├── config.py # Database configuration +│ ├── database.py # Database connection and session management +│ └── utils.py # Database utility functions +├── routes/ # API route definitions +│ ├── pointing/ # Pointing management routes +│ │ ├── router.py # Consolidated pointing router +│ │ ├── create_pointings.py # POST /pointings endpoint +│ │ ├── get_pointings.py # GET /pointings endpoint +│ │ ├── update_pointings.py # POST /update_pointings endpoint +│ │ ├── cancel_all.py # POST /cancel_all endpoint +│ │ ├── request_doi.py # POST /request_doi endpoint +│ │ └── test_refactoring.py # GET /test_refactoring endpoint +│ ├── instrument/ # Instrument management routes +│ │ ├── router.py # Consolidated instrument router +│ │ ├── get_instruments.py # GET /instruments endpoint +│ │ ├── get_footprints.py # GET /footprints endpoint +│ │ ├── create_instrument.py# POST /instruments endpoint +│ │ └── create_footprint.py # POST /footprints endpoint +│ ├── admin/ # Admin-related routes +│ │ ├── router.py # Consolidated admin router +│ │ └── fixdata.py # GET/POST /fixdata endpoint +│ ├── candidate/ # Candidate management routes +│ │ ├── router.py # Consolidated candidate router +│ │ ├── get_candidates.py # GET /candidate endpoint +│ │ ├── create_candidates.py# POST /candidate endpoint +│ │ ├── update_candidate.py # PUT /candidate endpoint +│ │ └── delete_candidates.py# DELETE /candidate endpoint +│ ├── doi/ # DOI request routes +│ │ ├── router.py # Consolidated DOI router +│ │ ├── get_doi_pointings.py# GET /doi_pointings endpoint +│ │ ├── get_author_groups.py# GET /doi_author_groups endpoint +│ │ └── get_authors.py # GET /doi_authors/{group_id} endpoint +│ ├── gw_alert/ # GW alert management routes +│ │ ├── router.py # Consolidated GW alert router +│ │ ├── query_alerts.py # GET /query_alerts endpoint +│ │ ├── post_alert.py # POST /post_alert endpoint +│ │ ├── get_skymap.py # GET /gw_skymap endpoint +│ │ ├── get_contour.py # GET /gw_contour endpoint +│ │ ├── get_grb_moc.py # GET /grb_moc_file endpoint +│ │ └── delete_test_alerts.py # POST /del_test_alerts endpoint +│ ├── gw_galaxy/ # Galaxy catalog management routes +│ │ ├── router.py # Consolidated galaxy router +│ │ ├── get_event_galaxies.py # GET /event_galaxies endpoint +│ │ ├── post_event_galaxies.py # POST /event_galaxies endpoint +│ │ ├── remove_event_galaxies.py # DELETE /remove_event_galaxies endpoint +│ │ └── get_glade.py # GET /glade endpoint +│ ├── icecube/ # IceCube neutrino event routes +│ │ ├── router.py # Consolidated IceCube router +│ │ └── 
post_icecube_notice.py # POST /post_icecube_notice endpoint +│ ├── event/ # Event candidate management routes +│ │ ├── router.py # Consolidated event router +│ │ ├── utils.py # Utility functions for event routes +│ │ ├── get_candidate_events.py # GET /candidate/event endpoint +│ │ ├── create_candidate_event.py # POST /candidate/event endpoint +│ │ ├── update_candidate_event.py # PUT /candidate/event/{candidate_id} endpoint +│ │ └── delete_candidate_event.py # DELETE /candidate/event/{candidate_id} endpoint +│ └── ui/ # UI-specific endpoints (AJAX helpers) +│ ├── router.py # Consolidated UI router +│ ├── alert_instruments_footprints.py # GET /ajax_alertinstruments_footprints +│ ├── preview_footprint.py # GET /ajax_preview_footprint +│ ├── resend_verification_email.py # POST /ajax_resend_verification_email +│ ├── coverage_calculator.py # POST /ajax_coverage_calculator +│ ├── spectral_range_from_bands.py # GET /ajax_update_spectral_range_from_selected_bands +│ ├── pointing_from_id.py # GET /ajax_pointingfromid +│ ├── grade_calculator.py # POST /ajax_grade_calculator +│ ├── icecube_notice.py # GET /ajax_icecube_notice +│ ├── event_galaxies.py # GET /ajax_event_galaxies +│ ├── scimma_xrt.py # GET /ajax_scimma_xrt +│ ├── candidate_fetch.py # GET /ajax_candidate +│ ├── request_doi.py # GET /ajax_request_doi +│ └── alert_type.py # GET /ajax_alerttype +├── schemas/ # Pydantic schemas for validation +│ ├── candidate.py # Candidate schemas +│ ├── doi.py # DOI schemas +│ ├── glade.py # GLADE catalog schemas +│ ├── gw_alert.py # GW alert schemas +│ ├── gw_galaxy.py # Galaxy schemas +│ ├── icecube.py # IceCube schemas +│ ├── instrument.py# Instrument schemas +│ ├── pointing.py # Pointing schemas +│ └── users.py # User schemas +├── utils/ # Utility functions +│ ├── email.py # Email utilities +│ ├── error_handling.py # Error handling utilities +│ ├── function.py # General utility functions +│ ├── gwtm_io.py # File I/O utilities +│ ├── pointing.py # Pointing validation and creation utilities +│ └── spectral.py # Spectral range calculations and conversions +├── config.py # Application configuration +├── main.py # FastAPI application entry point +├── requirements.txt # Python dependencies +└── Dockerfile # Docker configuration for deployment +``` + +## Development + +### Preferred Development Setup with Skaffold + +The recommended way to run the development server is using Skaffold, which manages the full application stack including database, cache, and all services: + +1. **Prerequisites:** + - [Skaffold](https://skaffold.dev/docs/install/) installed + - [kubectl](https://kubernetes.io/docs/tasks/tools/) configured for local cluster (minikube, kind, or Docker Desktop) + - [Helm](https://helm.sh/docs/intro/install/) installed + +2. **Start the development environment:** + ```bash + cd gwtm-helm + skaffold dev + ``` + + This will: + - Build Docker images for all services + - Deploy the complete stack to your local Kubernetes cluster + - Watch for file changes and automatically rebuild/redeploy + - Forward ports to access services locally + +3. **Access the services:** + - FastAPI server: http://localhost:8000 + - API documentation: http://localhost:8000/docs + - Database: localhost:5432 + - Redis cache: localhost:6379 + +4. 
**Load test data** (required for running tests): + ```bash + kubectl exec -it deployment/database -- psql -U treasuremap -d treasuremap_dev -f /docker-entrypoint-initdb.d/test-data.sql + ``` + +### Alternative: Local Development + +If you prefer to run only the FastAPI server locally: + +1. **Set up the database:** + - Ensure PostgreSQL is running with PostGIS extension + - Create a database named `treasuremap_dev` + - Load test data: `psql -U treasuremap -d treasuremap_dev -f tests/test-data.sql` + +2. **Create a virtual environment:** + ```bash + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + ``` + +3. **Install dependencies:** + ```bash + pip install -r server/requirements.txt + ``` + +4. **Set environment variables:** + Create a `.env` file in the server directory: + ``` + DB_USER=treasuremap + DB_PWD=your_password + DB_NAME=treasuremap_dev + DB_HOST=localhost + DB_PORT=5432 + DEBUG=True + ``` + +5. **Run the development server:** + ```bash + cd server + uvicorn main:app --reload --host 0.0.0.0 --port 8000 + ``` + +## Testing + +The FastAPI application includes comprehensive tests to ensure functionality and compatibility with the Flask implementation. + +### Prerequisites for Testing + +1. **Install test dependencies:** + ```bash + pip install -r tests/requirements.txt + ``` + +2. **Ensure FastAPI server is running:** + - Using Skaffold: `skaffold dev` (FastAPI available at http://localhost:8000) + - Using local setup: `uvicorn main:app --reload` (from server directory) + +3. **Load test data:** Tests automatically load fresh test data before running, but you can also load manually: + ```bash + # Using Skaffold/Kubernetes + gwtm-helm/restore-db tests/test-data.sql + + # Using local database + psql -U treasuremap -d treasuremap_dev -f tests/test-data.sql + ``` + +### Running Tests + +**From the project root directory:** + +```bash +# Run all FastAPI tests +python -m pytest tests/fastapi/ -v --disable-warnings + +# Run specific test modules +python -m pytest tests/fastapi/test_pointing.py -v +python -m pytest tests/fastapi/test_instrument.py -v +python -m pytest tests/fastapi/test_ui.py -v + +# Run with more verbose output +python -m pytest tests/fastapi/ -vv + +# Run with coverage reporting +python -m pytest tests/fastapi/ -v --cov=server +``` + +**Using the test script:** + +```bash +# Run all tests +./tests/run-fastapi-tests.sh + +# Run specific module (e.g., pointing tests) +./tests/run-fastapi-tests.sh pointing +``` + +### Test Configuration + +Tests are configured in `tests/fastapi/conftest.py` which: +- Automatically waits for the FastAPI server to be ready +- Loads fresh test data before running tests +- Provides common fixtures for API URLs, headers, and test tokens + +### Environment Variables + +Set these environment variables if your setup differs from defaults: + +```bash +export API_BASE_URL="http://localhost:8000" # FastAPI server URL +export DB_HOST="localhost" # Database host +export DB_PORT="5432" # Database port +export DB_NAME="treasuremap" # Database name +export DB_USER="treasuremap" # Database user +export DB_PWD="treasuremap" # Database password +``` + +### Test Categories + +- **`test_admin.py`** - Administrative functions and user management +- **`test_candidate.py`** - Candidate management and CRUD operations +- **`test_doi.py`** - DOI request and author management +- **`test_event.py`** - GW event and alert querying +- **`test_gw_alert.py`** - GW alert management +- **`test_gw_galaxy.py`** - Galaxy catalog management +- 
**`test_icecube.py`** - IceCube neutrino event integration +- **`test_instrument.py`** - Instrument management and validation +- **`test_pointing.py`** - Pointing CRUD operations and validation +- **`test_ui.py`** - UI-specific endpoints and AJAX helpers + +### Test Data + +The test suite uses predefined data from `tests/test-data.sql` which includes: +- Sample GW alerts with known GraceIDs +- Test users with various permission levels +- Sample instruments and their configurations +- Test pointings and observations +- Galaxy catalog entries + +**Note:** Always ensure test data is loaded before running tests, as many tests depend on specific entries existing in the database. + +## Deployment + +### Production Deployment with Helm + +The recommended deployment method is using the Helm chart which deploys the complete GWTM stack: + +```bash +# Deploy to production +cd gwtm-helm +helm install gwtm . -f values-prod.yaml + +# Deploy to development/staging +helm install gwtm-dev . -f values-dev.yaml +``` + +The Helm chart includes: +- FastAPI backend service +- PostgreSQL database with PostGIS +- Redis cache +- Frontend service +- Ingress configuration +- Persistent volumes for data storage + +For detailed deployment configuration, see `gwtm-helm/README.md`. + +### Development Deployment with Skaffold + +For development environments with automatic rebuilds: + +```bash +cd gwtm-helm +skaffold run # Deploy once +# or +skaffold dev # Deploy with file watching and auto-rebuild +``` + +### Using Docker (Standalone) + +To run just the FastAPI service in a container: + +```bash +# Build the Docker image +docker build -f server/Dockerfile -t gwtm-fastapi . + +# Run with required environment variables +docker run -p 8000:8000 \ + -e DB_HOST=your-postgres-host \ + -e DB_USER=treasuremap \ + -e DB_PWD=your_password \ + -e DB_NAME=treasuremap \ + gwtm-fastapi +``` + +**Note:** The FastAPI service requires a PostgreSQL database with PostGIS extension and Redis cache for full functionality. + +## API Documentation + +The API documentation is automatically generated and available at `/docs` when the application is running. It provides: + +- Interactive API documentation +- Request/response examples +- Schema definitions +- Authentication information + +## Authentication + +The API uses JWT-based authentication. To authenticate: + +1. Send a POST request to `/api/v1/login` with username and password +2. 
Use the returned token in the `Authorization` header as `Bearer ` for protected endpoints diff --git a/server/auth/__init__.py b/server/auth/__init__.py new file mode 100644 index 00000000..09d381f3 --- /dev/null +++ b/server/auth/__init__.py @@ -0,0 +1 @@ +# Auth package initialization diff --git a/server/auth/auth.py b/server/auth/auth.py new file mode 100644 index 00000000..9a87e296 --- /dev/null +++ b/server/auth/auth.py @@ -0,0 +1,134 @@ +from fastapi import Depends, HTTPException, status +from fastapi.security import APIKeyHeader, OAuth2PasswordBearer +from sqlalchemy.orm import Session +from server.db.database import get_db +from server.db.models import UserGroups, Groups +from server.db.models.users import Users +from typing import Optional +from datetime import datetime, timedelta +import jwt + +from server.config import settings + +# Define the API key header +api_key_header = APIKeyHeader(name="api_token", auto_error=False) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) + + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """ + Create a JWT access token + + Args: + data: Data to encode in the token + expires_delta: Optional expiration time + + Returns: + JWT token as a string + """ + to_encode = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta( + minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES + ) + + to_encode.update({"exp": expire}) + + encoded_jwt = jwt.encode( + to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM + ) + + return encoded_jwt + + +def decode_token(token: str) -> dict: + """ + Decode a JWT token + + Args: + token: JWT token + + Returns: + Decoded payload + """ + try: + payload = jwt.decode( + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] + ) + return payload + except jwt.PyJWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" + ) + + +def get_current_user( + api_token: str = Depends(api_key_header), db: Session = Depends(get_db) +) -> Optional[Users]: + """ + Validate API token and return the associated user + """ + if not api_token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API token is required", + ) + + user = db.query(Users).filter(Users.api_token == api_token).first() + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API token", + ) + + # Log user action + # Implementation of useractions logging will be added later + + return user + + +def verify_admin( + user: Users = Depends(get_current_user), db: Session = Depends(get_db) +) -> Users: + """ + Check if the user belongs to the admin group. 
+ """ + admin_group = db.query(Groups).filter(Groups.name == "admin").first() + if not admin_group: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin group does not exist", + ) + + user_group = ( + db.query(UserGroups) + .filter(UserGroups.userid == user.id, UserGroups.groupid == admin_group.id) + .first() + ) + + if not user_group: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only admins can access this endpoint", + ) + + return user + + +def log_user_action( + user: Users, + request_path: str, + method: str, + ip_address: str, + json_data=None, + db: Session = Depends(get_db), +): + """ + Log user actions for auditing + Will be implemented with full UserAction model + """ + # This will be implemented when the UserAction model is fully ported + pass diff --git a/server/config.py b/server/config.py new file mode 100644 index 00000000..b5151b13 --- /dev/null +++ b/server/config.py @@ -0,0 +1,115 @@ +import os +import json +from functools import lru_cache +from typing import List, Optional +from pydantic import Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application settings using Pydantic BaseSettings for environment variable loading.""" + + # Application settings + APP_NAME: str = "GWTM API" + DEBUG: bool = Field(False, env="DEBUG") + + # Database settings + DB_USER: str = Field("treasuremap", env="DB_USER") + DB_PWD: str = Field("", env="DB_PWD") + DB_NAME: str = Field("treasuremap_dev", env="DB_NAME") + DB_HOST: str = Field("localhost", env="DB_HOST") + DB_PORT: str = Field("5432", env="DB_PORT") + + # Email settings + MAIL_USERNAME: str = Field("gwtreasuremap@gmail.com", env="MAIL_USERNAME") + MAIL_DEFAULT_SENDER: str = Field( + "gwtreasuremap@gmail.com", env="MAIL_DEFAULT_SENDER" + ) + MAIL_PASSWORD: str = Field("", env="MAIL_PASSWORD") + MAIL_SERVER: str = Field("", env="MAIL_SERVER") + MAIL_PORT: int = Field(465, env="MAIL_PORT") + MAIL_USE_TLS: bool = Field(False, env="MAIL_USE_TLS") + MAIL_USE_SSL: bool = Field(True, env="MAIL_USE_SSL") + + # Admin settings + ADMINS: str = Field("gwtreasuremap@gmail.com", env="ADMINS") + + # Security settings + SECRET_KEY: str = Field( + default_factory=lambda: os.urandom(16).hex(), env="SECRET_KEY" + ) + JWT_SECRET_KEY: str = Field( + default_factory=lambda: os.urandom(16).hex(), env="JWT_SECRET_KEY" + ) + JWT_ALGORITHM: str = Field("HS256", env="JWT_ALGORITHM") + JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = Field( + 30, env="JWT_ACCESS_TOKEN_EXPIRE_MINUTES" + ) + + # External services + RECAPTCHA_PUBLIC_KEY: str = Field("", env="RECAPTCHA_PUBLIC_KEY") + RECAPTCHA_PRIVATE_KEY: str = Field("", env="RECAPTCHA_PRIVATE_KEY") + ZENODO_ACCESS_KEY: str = Field("", env="ZENODO_ACCESS_KEY") + + # AWS settings + AWS_ACCESS_KEY_ID: str = Field("", env="AWS_ACCESS_KEY_ID") + AWS_SECRET_ACCESS_KEY: str = Field("", env="AWS_SECRET_ACCESS_KEY") + AWS_DEFAULT_REGION: str = Field("us-east-2", env="AWS_DEFAULT_REGION") + AWS_BUCKET: str = Field("gwtreasuremap", env="AWS_BUCKET") + + # Azure settings + AZURE_ACCOUNT_NAME: str = Field("", env="AZURE_ACCOUNT_NAME") + AZURE_ACCOUNT_KEY: str = Field("", env="AZURE_ACCOUNT_KEY") + + # Storage settings + STORAGE_BUCKET_SOURCE: str = Field("s3", env="STORAGE_BUCKET_SOURCE") + + # Development settings + DEVELOPMENT_MODE: bool = Field(False, env="DEVELOPMENT_MODE") + DEVELOPMENT_STORAGE_DIR: str = Field("./dev_storage", env="DEVELOPMENT_STORAGE_DIR") + + # CORS settings + CORS_ORIGINS: 
List[str] = ["*"] + CORS_METHODS: List[str] = ["*"] + CORS_HEADERS: List[str] = ["*"] + + @field_validator("CORS_ORIGINS", "CORS_METHODS", "CORS_HEADERS", mode="before") + @classmethod + def parse_cors_list(cls, v): + """Parse CORS settings from JSON string or return as-is if already a list.""" + if isinstance(v, str): + try: + # Try to parse as JSON array + return json.loads(v) + except json.JSONDecodeError: + # If JSON parsing fails, treat as comma-separated string + return [item.strip() for item in v.split(",") if item.strip()] + return v + + # Database URL + @property + def SQLALCHEMY_DATABASE_URI(self) -> str: + """Generate the database URI from component settings.""" + return f"postgresql://{self.DB_USER}:{self.DB_PWD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}" + + # Admin emails + @property + def ADMIN_EMAILS(self) -> list: + """Convert comma-separated ADMINS string to a list.""" + return [email.strip() for email in self.ADMINS.split(",")] + + class Config: + case_sensitive = True + + +@lru_cache() +def get_settings() -> Settings: + """ + Get cached settings singleton. + Uses lru_cache to avoid reading env variables on every call. + """ + return Settings() + + +# Export the settings instance for easy importing +settings = get_settings() diff --git a/server/core/enums/__init__.py b/server/core/enums/__init__.py new file mode 100644 index 00000000..a07347fe --- /dev/null +++ b/server/core/enums/__init__.py @@ -0,0 +1,10 @@ +# server/core/enums/__init__.py + +from .bandpass import Bandpass +from .gwgalaxyscoretype import GwGalaxyScoreType +from .instrumenttype import InstrumentType +from .energyunits import EnergyUnits +from .wavelengthunits import WavelengthUnits +from .pointingstatus import PointingStatus +from .frequencyunits import FrequencyUnits +from .depthunit import DepthUnit diff --git a/server/core/enums/bandpass.py b/server/core/enums/bandpass.py new file mode 100644 index 00000000..53709d3c --- /dev/null +++ b/server/core/enums/bandpass.py @@ -0,0 +1,37 @@ +from enum import IntEnum + + +class Bandpass(IntEnum): + """Enumeration for bandpasses.""" + + U = 1 + B = 2 + V = 3 + R = 4 + I = 5 # noqa: E741 + J = 6 + H = 7 + K = 8 + u = 9 + g = 10 + r = 11 + i = 12 + z = 13 + UVW1 = 14 + UVW2 = 15 + UVM2 = 16 + XRT = 17 + clear = 18 + open = 19 + UHF = 20 + VHF = 21 + L = 22 + S = 23 + C = 24 + X = 25 + other = 26 + TESS = 27 + BAT = 28 + HESS = 29 + WISEL = 30 + q = 31 diff --git a/server/core/enums/depthunit.py b/server/core/enums/depthunit.py new file mode 100644 index 00000000..e61b99a3 --- /dev/null +++ b/server/core/enums/depthunit.py @@ -0,0 +1,15 @@ +from enum import IntEnum + + +class DepthUnit(IntEnum): + """Enumeration for depth units.""" + + ab_mag = 1 + vega_mag = 2 + flux_erg = 3 + flux_jy = 4 + + def __str__(self) -> str: + """Return a formatted string representation of the depth unit.""" + split_name = str(self.name).split("_") + return str.upper(split_name[0]) + " " + split_name[1] diff --git a/server/core/enums/energyunits.py b/server/core/enums/energyunits.py new file mode 100644 index 00000000..ac2bf7dd --- /dev/null +++ b/server/core/enums/energyunits.py @@ -0,0 +1,25 @@ +from enum import IntEnum + + +class EnergyUnits(IntEnum): + """Enumeration for energy units.""" + + eV = 1 + keV = 2 + MeV = 3 + GeV = 4 + TeV = 5 + + @staticmethod + def get_scale(unit): + """Return the scale factor for the given energy unit.""" + if unit == EnergyUnits.eV: + return 1.0 + if unit == EnergyUnits.keV: + return 1000.0 + if unit == EnergyUnits.MeV: + return 1000000.0 + if 
unit == EnergyUnits.GeV: + return 1000000000.0 + if unit == EnergyUnits.TeV: + return 1000000000000.0 diff --git a/server/core/enums/frequencyunits.py b/server/core/enums/frequencyunits.py new file mode 100644 index 00000000..fff01559 --- /dev/null +++ b/server/core/enums/frequencyunits.py @@ -0,0 +1,25 @@ +from enum import IntEnum + + +class FrequencyUnits(IntEnum): + """Enumeration for frequency units.""" + + Hz = 1 + kHz = 2 + GHz = 3 + MHz = 4 + THz = 5 + + @staticmethod + def get_scale(unit): + """Return the scale factor for the given frequency unit.""" + if unit == FrequencyUnits.Hz: + return 1.0 + if unit == FrequencyUnits.kHz: + return 1000.0 + if unit == FrequencyUnits.MHz: + return 1000000.0 + if unit == FrequencyUnits.GHz: + return 1000000000.0 + if unit == FrequencyUnits.THz: + return 1000000000000.0 diff --git a/server/core/enums/gwgalaxyscoretype.py b/server/core/enums/gwgalaxyscoretype.py new file mode 100644 index 00000000..e828b574 --- /dev/null +++ b/server/core/enums/gwgalaxyscoretype.py @@ -0,0 +1,7 @@ +from enum import IntEnum + + +class GwGalaxyScoreType(IntEnum): + """Enumeration for GW galaxy score types.""" + + default = 1 diff --git a/server/core/enums/instrumenttype.py b/server/core/enums/instrumenttype.py new file mode 100644 index 00000000..1cd16dd9 --- /dev/null +++ b/server/core/enums/instrumenttype.py @@ -0,0 +1,8 @@ +from enum import IntEnum + + +class InstrumentType(IntEnum): + """Enumeration for instrument types.""" + + photometric = 1 + spectroscopic = 2 diff --git a/server/core/enums/pointingstatus.py b/server/core/enums/pointingstatus.py new file mode 100644 index 00000000..51b83ad6 --- /dev/null +++ b/server/core/enums/pointingstatus.py @@ -0,0 +1,9 @@ +from enum import IntEnum + + +class PointingStatus(IntEnum): + """Enumeration for pointing statuses.""" + + planned = 1 + completed = 2 + cancelled = 3 diff --git a/server/core/enums/wavelengthunits.py b/server/core/enums/wavelengthunits.py new file mode 100644 index 00000000..9ee07571 --- /dev/null +++ b/server/core/enums/wavelengthunits.py @@ -0,0 +1,19 @@ +from enum import IntEnum + + +class WavelengthUnits(IntEnum): + """Enumeration for wavelength units.""" + + nanometer = 1 + angstrom = 2 + micron = 3 + + @staticmethod + def get_scale(unit): + """Return the scale factor for the given wavelength unit.""" + if unit == WavelengthUnits.nanometer: + return 10.0 + if unit == WavelengthUnits.angstrom: + return 1.0 + if unit == WavelengthUnits.micron: + return 10000.0 diff --git a/server/db/config.py b/server/db/config.py new file mode 100644 index 00000000..6eae5ea1 --- /dev/null +++ b/server/db/config.py @@ -0,0 +1,4 @@ +from server.config import settings + +# Database configuration from central settings +DATABASE_URL = settings.SQLALCHEMY_DATABASE_URI diff --git a/server/db/database.py b/server/db/database.py new file mode 100644 index 00000000..b32d94c6 --- /dev/null +++ b/server/db/database.py @@ -0,0 +1,17 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, declarative_base +from .config import DATABASE_URL + +# Initialize SQLAlchemy engine and session +engine = create_engine(DATABASE_URL, echo=True) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +Base = declarative_base() + + +# Database dependency for FastAPI +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/server/db/init_db.py b/server/db/init_db.py new file mode 100644 index 00000000..fa20f8e3 --- /dev/null +++ b/server/db/init_db.py @@ -0,0 
+1,57 @@ +"""FastAPI database initialization script.""" + +import os +from sqlalchemy import create_engine, text +from server.db.database import Base, engine as default_engine +from server.db.models import * # Import all models to register them + + +def create_database_tables(): + """Create database tables using FastAPI models - exactly matching Flask setup.""" + # Use environment variables to override the default database connection + db_user = os.environ.get("DB_USER") + db_pwd = os.environ.get("DB_PWD") + db_name = os.environ.get("DB_NAME") + db_host = os.environ.get("DB_HOST") + db_port = os.environ.get("DB_PORT") + + # If environment variables are provided, create a custom engine + if all([db_user, db_pwd, db_name, db_host, db_port]): + database_url = f"postgresql://{db_user}:{db_pwd}@{db_host}:{db_port}/{db_name}" + engine = create_engine(database_url) + else: + # Use the default engine from FastAPI database configuration + engine = default_engine + + # PostGIS setup - work with existing PostGIS extension + with engine.connect() as conn: + # PostGIS is already installed, just ensure we can use geography types + # Set search path to include both public and postgis schemas + conn.execute(text("SET search_path TO public, postgis;")) + conn.commit() + + # Use an engine with PostGIS search path for table creation + from sqlalchemy.pool import StaticPool + + engine_with_postgis = create_engine( + database_url, + connect_args={"options": "-csearch_path=public,postgis"}, + poolclass=StaticPool, + ) + + # Create all tables using FastAPI models (equivalent to Flask's db.create_all()) + Base.metadata.create_all(bind=engine_with_postgis) + + # Create additional indexes (exactly matching Flask setup) + with engine.connect() as conn: + conn.execute( + text("CREATE INDEX idx_pointing_status_id ON public.pointing(status, id);") + ) + conn.commit() + + print("FastAPI database schema created successfully") + print(f"Created tables: {list(Base.metadata.tables.keys())}") + + +if __name__ == "__main__": + create_database_tables() diff --git a/server/db/models/__init__.py b/server/db/models/__init__.py new file mode 100644 index 00000000..c43b7312 --- /dev/null +++ b/server/db/models/__init__.py @@ -0,0 +1,10 @@ +from .users import Users, UserGroups, Groups, UserActions +from .instrument import Instrument, FootprintCCD +from .pointing import Pointing +from .pointing_event import PointingEvent +from .gw_alert import GWAlert +from .gw_galaxy import GWGalaxy, EventGalaxy, GWGalaxyScore, GWGalaxyList, GWGalaxyEntry +from .glade import Glade2P3 +from .icecube import IceCubeNotice, IceCubeNoticeCoincEvent +from .candidate import GWCandidate +from .doi_author import DOIAuthorGroup, DOIAuthor diff --git a/server/db/models/candidate.py b/server/db/models/candidate.py new file mode 100644 index 00000000..6b7f2a22 --- /dev/null +++ b/server/db/models/candidate.py @@ -0,0 +1,69 @@ +from sqlalchemy import Column, Integer, Float, String, DateTime, ForeignKey, Enum +from server.core.enums.depthunit import DepthUnit as depth_unit_enum +from geoalchemy2 import Geography +from sqlalchemy.ext.hybrid import hybrid_property +from ..database import Base +import shapely.wkb +from datetime import datetime +from typing import Dict, Any, Optional, List + + +class ValidationResult: + """Helper class for validation results""" + + def __init__(self): + self.valid = True + self.errors = [] + self.warnings = [] + + +class GWCandidate(Base): + __tablename__ = "gw_candidate" + __table_args__ = {"schema": "public"} + + id = 
Column(Integer, primary_key=True) + graceid = Column(String(50), nullable=False) + submitterid = Column(Integer, nullable=False) + candidate_name = Column(String(100), nullable=False) + datecreated = Column(DateTime, default=datetime.now) + tns_name = Column(String(100)) + tns_url = Column(String(500)) + position = Column(Geography("POINT", srid=4326), nullable=False) + discovery_date = Column(DateTime) + discovery_magnitude = Column(Float) + magnitude_central_wave = Column(Float) + magnitude_bandwidth = Column(Float) + magnitude_unit = Column(Enum(depth_unit_enum, name="depthunit"), nullable=False) + magnitude_bandpass = Column(String(50)) + associated_galaxy = Column(String(100)) + associated_galaxy_redshift = Column(Float) + associated_galaxy_distance = Column(Float) + + @hybrid_property + def ra(self) -> Optional[float]: + """Get RA coordinate from position""" + try: + position_geom = shapely.wkb.loads(bytes(self.position.data)) + coords = str(position_geom).replace("POINT (", "").replace(")", "").split() + return float(coords[0]) + except (AttributeError, Exception): + return None + + @hybrid_property + def dec(self) -> Optional[float]: + """Get Dec coordinate from position""" + try: + position_geom = shapely.wkb.loads(bytes(self.position.data)) + coords = str(position_geom).replace("POINT (", "").replace(")", "").split() + return float(coords[1]) + except (AttributeError, Exception): + return None + + @hybrid_property + def position_wkt(self) -> Optional[str]: + """Get position as WKT string""" + try: + position_geom = shapely.wkb.loads(bytes(self.position.data)) + return str(position_geom) + except (AttributeError, Exception): + return None diff --git a/server/db/models/doi_author.py b/server/db/models/doi_author.py new file mode 100644 index 00000000..366cf7a6 --- /dev/null +++ b/server/db/models/doi_author.py @@ -0,0 +1,134 @@ +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey +from sqlalchemy.orm import Session +from typing import List, Dict, Tuple, Optional, Any, Union + + +from ..database import Base +from server.utils.function import isInt + + +class DOIAuthorGroup(Base): + """ + DOI Author Group model. + Represents a group of authors for DOI creation. + """ + + __tablename__ = "doi_author_group" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + userid = Column(Integer) + name = Column(String) + + +class DOIAuthor(Base): + """ + DOI Author model. + Represents an author associated with a DOI. + """ + + __tablename__ = "doi_author" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + name = Column(String) + affiliation = Column(String) + orcid = Column(String) + gnd = Column(String) + pos_order = Column(Integer) + author_groupid = Column(Integer) + + @staticmethod + def construct_creators( + doi_group_id: Union[int, str], userid: int, db: Session + ) -> Tuple[bool, List[Dict[str, str]]]: + """ + Construct a list of creators from a DOI author group. 
+ + Args: + doi_group_id: ID or name of the DOI author group + userid: User ID of the requesting user + db: Database session + + Returns: + Tuple of (success, creators_list) + """ + from sqlalchemy import and_ + + if isInt(doi_group_id): + # Filter by group ID + authors = ( + db.query(DOIAuthor) + .filter( + and_( + DOIAuthor.author_groupid == int(doi_group_id), + DOIAuthor.author_groupid == DOIAuthorGroup.id, + DOIAuthorGroup.userid == userid, + ) + ) + .order_by(DOIAuthor.id) + .all() + ) + else: + # Filter by group name + authors = ( + db.query(DOIAuthor) + .filter( + and_( + DOIAuthor.author_groupid == DOIAuthorGroup.id, + DOIAuthorGroup.name == doi_group_id, + DOIAuthorGroup.userid == userid, + ) + ) + .order_by(DOIAuthor.id) + .all() + ) + + if len(authors) == 0: + return False, [] + + creators = [] + for a in authors: + a_dict = {"name": a.name, "affiliation": a.affiliation} + if a.orcid: + a_dict["orcid"] = a.orcid + if a.gnd: + a_dict["gnd"] = a.gnd + creators.append(a_dict) + + return True, creators + + @staticmethod + def authors_from_page(form_data): + """ + Create author objects from form data. + + Args: + form_data: Dictionary of form fields + + Returns: + List of DOIAuthor objects + """ + authors = [] + + # Extract authors data from form + author_ids = form_data.getlist("author_id") + author_names = form_data.getlist("author_name") + affiliations = form_data.getlist("affiliation") + orcids = form_data.getlist("orcid") + gnds = form_data.getlist("gnd") + + # Create authors + for aid, an, aff, orc, gnd in zip( + author_ids, author_names, affiliations, orcids, gnds + ): + if str(aid) == "" or str(aid) == "None": + # New author + authors.append(DOIAuthor(name=an, affiliation=aff, orcid=orc, gnd=gnd)) + else: + # Existing author + authors.append( + DOIAuthor(id=int(aid), name=an, affiliation=aff, orcid=orc, gnd=gnd) + ) + + return authors diff --git a/server/db/models/glade.py b/server/db/models/glade.py new file mode 100644 index 00000000..e5b5be46 --- /dev/null +++ b/server/db/models/glade.py @@ -0,0 +1,22 @@ +# `server/db/models/glade.py` +from sqlalchemy import Column, Integer, Float, String +from geoalchemy2 import Geography +from ..database import Base + + +class Glade2P3(Base): + __tablename__ = "glade_2p3" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + pgc_number = Column(Integer) + distance = Column(Float) + distance_error = Column(Float) + redshift = Column(Float) + bmag = Column(Float) + bmag_err = Column(Float) + position = Column(Geography("POINT", srid=4326)) + _2mass_name = Column(String) + gwgc_name = Column(String) + hyperleda_name = Column(String) + sdssdr12_name = Column(String) diff --git a/server/db/models/gw_alert.py b/server/db/models/gw_alert.py new file mode 100644 index 00000000..1bac7aae --- /dev/null +++ b/server/db/models/gw_alert.py @@ -0,0 +1,129 @@ +from sqlalchemy import Column, Integer, Float, String, DateTime, ForeignKey +from ..database import Base +from datetime import datetime +from typing import Dict, Any, Optional + + +class GWAlert(Base): + __tablename__ = "gw_alert" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + datecreated = Column(DateTime, default=datetime.now) + graceid = Column(String(50), nullable=False, index=True) + alternateid = Column(String(50), index=True) + role = Column(String(20), default="observation", index=True) # observation or test + timesent = Column(DateTime) + time_of_signal = Column(DateTime, index=True) + packet_type = Column(Integer) + 
alert_type = Column(String(50)) + detectors = Column(String(100)) + description = Column(String(500)) + far = Column(Float, default=0, index=True) + skymap_fits_url = Column(String(500)) + distance = Column(Float) + distance_error = Column(Float) + prob_bns = Column(Float) + prob_nsbh = Column(Float) + prob_gap = Column(Float) + prob_bbh = Column(Float) + prob_terrestrial = Column(Float) + prob_hasns = Column(Float) + prob_hasremenant = Column(Float) + group = Column(String(50)) + centralfreq = Column(Float) + duration = Column(Float) + avgra = Column(Float) + avgdec = Column(Float, index=True) + observing_run = Column(String(20)) + pipeline = Column(String(50)) + search = Column(String(50)) + area_50 = Column(Float) + area_90 = Column(Float) + gcn_notice_id = Column(Integer) + ivorn = Column(String(100)) + ext_coinc_observatory = Column(String(50)) + ext_coinc_search = Column(String(50)) + time_difference = Column(Float) + time_coincidence_far = Column(Float) + time_sky_position_coincidence_far = Column(Float) + + def getClassification(self) -> str: + """Get classification based on probabilities.""" + if self.group == "Burst": + return "None (detected as burst)" + + probs = [ + {"prob": self.prob_bns if self.prob_bns else 0.0, "class": "BNS"}, + {"prob": self.prob_nsbh if self.prob_nsbh else 0.0, "class": "NSBH"}, + {"prob": self.prob_bbh if self.prob_bbh else 0.0, "class": "BBH"}, + { + "prob": self.prob_terrestrial if self.prob_terrestrial else 0.0, + "class": "Terrestrial", + }, + {"prob": self.prob_gap if self.prob_gap else 0.0, "class": "Mass Gap"}, + ] + + sorted_probs = sorted( + [x for x in probs if x["prob"] > 0.01], + key=lambda i: i["prob"], + reverse=True, + ) + + classification = "" + for p in sorted_probs: + classification += ( + p["class"] + ": (" + str(round(100 * p["prob"], 1)) + "%) " + ) + + return classification + + @staticmethod + def graceidfromalternate(graceid: str) -> str: + """ + Convert alternate GraceIDs to standard format. + Some GraceIDs might be provided in alternative formats like 'S190425z' instead of 'S190425z'. + This method normalizes them. + + Args: + graceid: The GraceID to normalize + + Returns: + Normalized GraceID + """ + # Map of known aliases (to be expanded as needed) + alias_map = { + # Add specific mappings as discovered + } + + # Check if the graceid is in the alias map + if graceid in alias_map: + return alias_map[graceid] + + # Remove any common prefixes/suffixes + # Here we're just returning the original ID as there's no specific + # transformation logic implemented yet + return graceid + + @staticmethod + def alternatefromgraceid(graceid: str) -> str: + """ + Convert standard GraceIDs to alternate format for specific uses. 
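+
+        Note: both alias maps in this module are currently empty, so these
+        helpers behave as identity transforms until specific GraceID aliases
+        are registered.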
+ + Args: + graceid: The standard GraceID + + Returns: + Alternate format GraceID + """ + # Map of standard to alternate formats (to be expanded as needed) + reverse_alias_map = { + # Add specific mappings as discovered + } + + # Check if the graceid is in the reverse alias map + if graceid in reverse_alias_map: + return reverse_alias_map[graceid] + + # By default, return the original graceid + return graceid diff --git a/server/db/models/gw_galaxy.py b/server/db/models/gw_galaxy.py new file mode 100644 index 00000000..fcd257ff --- /dev/null +++ b/server/db/models/gw_galaxy.py @@ -0,0 +1,88 @@ +from sqlalchemy import Column, Integer, Float, String, DateTime, JSON, Enum +from geoalchemy2 import Geography +from ..database import Base +import datetime +from server.core.enums.gwgalaxyscoretype import GwGalaxyScoreType + + +class GWGalaxy(Base): + """ + Gravitational Wave Galaxy mapping. + Maps gravitational wave events to specific galaxies from catalogs. + """ + + __tablename__ = "gw_galaxy" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + graceid = Column(String) + galaxy_catalog = Column(Integer) + galaxy_catalogid = Column(Integer) + reference = Column(String) + + +class EventGalaxy(Base): + """ + Event to Galaxy mapping. + Maps event IDs to galaxies in catalogs. + """ + + __tablename__ = "event_galaxy" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + graceid = Column(String) + galaxy_catalog = Column(Integer) + galaxy_catalogid = Column(Integer) + + +class GWGalaxyScore(Base): + """ + Gravitational Wave Galaxy Score. + Stores scores for galaxies associated with GW events. + """ + + __tablename__ = "gw_galaxy_score" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + gw_galaxyid = Column(Integer) + score_type = Column(Enum(GwGalaxyScoreType, name="gwgalaxyscoretype")) + score = Column(Float) + + +class GWGalaxyList(Base): + """ + Gravitational Wave Galaxy List. + Represents a list of galaxies associated with a GW event. + """ + + __tablename__ = "gw_galaxy_list" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + graceid = Column(String) + groupname = Column(String) + submitterid = Column(Integer) + reference = Column(String) + alertid = Column(String) + doi_url = Column(String(100)) + doi_id = Column(Integer) + + +class GWGalaxyEntry(Base): + """ + Gravitational Wave Galaxy Entry. + Individual galaxy entries within a galaxy list. 
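+
+    ``position`` is stored as a PostGIS POINT geography (SRID 4326) and
+    ``info`` holds arbitrary per-entry JSON metadata.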
+ """ + + __tablename__ = "gw_galaxy_entry" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + listid = Column(Integer) + name = Column(String) + score = Column(Float) + position = Column(Geography("POINT", srid=4326)) + rank = Column(Integer) + info = Column(JSON) diff --git a/server/db/models/icecube.py b/server/db/models/icecube.py new file mode 100644 index 00000000..8fa67ca0 --- /dev/null +++ b/server/db/models/icecube.py @@ -0,0 +1,40 @@ +from sqlalchemy import Column, Integer, String, DateTime, Float +from ..database import Base + + +class IceCubeNotice(Base): + __tablename__ = "icecube_notice" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + ref_id = Column(String) + graceid = Column(String) + alert_datetime = Column(DateTime) + datecreated = Column(DateTime) + observation_start = Column(DateTime) + observation_stop = Column(DateTime) + pval_generic = Column(Float) + pval_bayesian = Column(Float) + most_probable_direction_ra = Column(Float) + most_probable_direction_dec = Column(Float) + flux_sens_low = Column(Float) + flux_sens_high = Column(Float) + sens_energy_range_low = Column(Float) + sens_energy_range_high = Column(Float) + + +class IceCubeNoticeCoincEvent(Base): + __tablename__ = "icecube_notice_coinc_event" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + icecube_notice_id = Column(Integer) + datecreated = Column(DateTime) + event_dt = Column(Float) + ra = Column(Float) + dec = Column(Float) + containment_probability = Column(Float) + event_pval_generic = Column(Float) + event_pval_bayesian = Column(Float) + ra_uncertainty = Column(Float) + uncertainty_shape = Column(String) diff --git a/server/db/models/instrument.py b/server/db/models/instrument.py new file mode 100644 index 00000000..1cdbfe6e --- /dev/null +++ b/server/db/models/instrument.py @@ -0,0 +1,28 @@ +from sqlalchemy import Column, Integer, String, DateTime, Enum, func +from geoalchemy2 import Geography +import shapely.wkb +from sqlalchemy.ext.hybrid import hybrid_property +from ..database import Base +from server.core.enums.instrumenttype import InstrumentType +import datetime + + +class Instrument(Base): + __tablename__ = "instrument" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + instrument_name = Column(String(64)) + nickname = Column(String(25)) + instrument_type = Column(Enum(InstrumentType, name="instrumenttype")) + datecreated = Column(DateTime) + submitterid = Column(Integer) + + +class FootprintCCD(Base): + __tablename__ = "footprint_ccd" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + instrumentid = Column(Integer) + footprint = Column(Geography("POLYGON", srid=4326)) diff --git a/server/db/models/pointing.py b/server/db/models/pointing.py new file mode 100644 index 00000000..0e0271b7 --- /dev/null +++ b/server/db/models/pointing.py @@ -0,0 +1,87 @@ +from sqlalchemy import Column, Integer, Float, DateTime, Enum, String, and_ +from sqlalchemy.ext.hybrid import hybrid_method +from geoalchemy2 import Geography +from ..database import Base +from server.core.enums.bandpass import Bandpass +from server.core.enums.depthunit import DepthUnit +from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum +from server.utils.spectral import SpectralRangeHandler + + +class Pointing(Base): + __tablename__ = "pointing" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + status = 
Column(Enum(pointing_status_enum, name="pointing_status")) + position = Column(Geography("POINT", srid=4326)) + galaxy_catalog = Column(Integer) + galaxy_catalogid = Column(Integer) + instrumentid = Column(Integer) + depth = Column(Float) + depth_err = Column(Float) + depth_unit = Column(Enum(DepthUnit, name="depth_unit")) + time = Column(DateTime) + datecreated = Column(DateTime) + dateupdated = Column(DateTime) + submitterid = Column(Integer) + pos_angle = Column(Float) + band = Column(Enum(Bandpass, name="bandpass")) + doi_url = Column(String(100)) + doi_id = Column(Integer) + central_wave = Column(Float) + bandwidth = Column(Float) + + @hybrid_method + def inSpectralRange(self, spectral_min, spectral_max, spectral_type): + """ + Function to determine if a pointing is within a given range for spectral types: + wavelength (Angstroms) + energy (eV) + frequency (Hz) + + It inputs the range of the spectral type (minimum and maximum values for given type) and + determines if the pointing's observation is in that range. The boolean logic is all + encompassing; whether the endpoints are confined entirely within the provided range + """ + if spectral_type == SpectralRangeHandler.spectralrangetype.wavelength: + thismin, thismax = SpectralRangeHandler.wavetoWaveRange( + self.central_wave, self.bandwidth + ) + elif spectral_type == SpectralRangeHandler.spectralrangetype.energy: + thismin, thismax = SpectralRangeHandler.wavetoEnergy( + self.central_wave, self.bandwidth + ) + elif spectral_type == SpectralRangeHandler.spectralrangetype.frequency: + thismin, thismax = SpectralRangeHandler.wavetoFrequency( + self.central_wave, self.bandwidth + ) + else: + return False + + if thismin >= spectral_min and thismax <= spectral_max: + return True + + return False + + @inSpectralRange.expression + def inSpectralRange(cls, spectral_min, spectral_max, spectral_type): + """ + SQLAlchemy expression version of inSpectralRange for database queries + """ + if spectral_type == SpectralRangeHandler.spectralrangetype.wavelength: + thismin, thismax = SpectralRangeHandler.wavetoWaveRange( + cls.central_wave, cls.bandwidth + ) + elif spectral_type == SpectralRangeHandler.spectralrangetype.energy: + thismin, thismax = SpectralRangeHandler.wavetoEnergy( + cls.central_wave, cls.bandwidth + ) + elif spectral_type == SpectralRangeHandler.spectralrangetype.frequency: + thismin, thismax = SpectralRangeHandler.wavetoFrequency( + cls.central_wave, cls.bandwidth + ) + else: + return False + + return and_(thismin >= spectral_min, thismax <= spectral_max) diff --git a/server/db/models/pointing_event.py b/server/db/models/pointing_event.py new file mode 100644 index 00000000..abb84825 --- /dev/null +++ b/server/db/models/pointing_event.py @@ -0,0 +1,11 @@ +from sqlalchemy import Column, Integer, String +from ..database import Base + + +class PointingEvent(Base): + __tablename__ = "pointing_event" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + pointingid = Column(Integer, index=True) + graceid = Column(String, index=True) diff --git a/server/db/models/users.py b/server/db/models/users.py new file mode 100644 index 00000000..feaa8d30 --- /dev/null +++ b/server/db/models/users.py @@ -0,0 +1,57 @@ +from sqlalchemy import Column, Integer, String, DateTime, Boolean, JSON +from werkzeug.security import generate_password_hash, check_password_hash +from ..database import Base + + +class Users(Base): + __tablename__ = "users" + __table_args__ = {"schema": "public"} + + id = Column(Integer, 
primary_key=True) + username = Column(String(25), index=True, unique=True) + firstname = Column(String(25)) + lastname = Column(String(25)) + password_hash = Column(String(128)) + datecreated = Column(DateTime) + email = Column(String(100)) + api_token = Column(String(128)) + verification_key = Column(String(128)) + verified = Column(Boolean) + + def set_password(self, password): + self.password_hash = generate_password_hash(password) + + def check_password(self, password): + return check_password_hash(self.password_hash, password) + + +class UserGroups(Base): + __tablename__ = "usergroups" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + userid = Column(Integer) + groupid = Column(Integer) + role = Column(String(25)) + + +class Groups(Base): + __tablename__ = "groups" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + name = Column(String(25)) + datecreated = Column(DateTime) + + +class UserActions(Base): + __tablename__ = "useractions" + __table_args__ = {"schema": "public"} + + id = Column(Integer, primary_key=True) + userid = Column(Integer) + ipaddress = Column(String(50)) + url = Column(String()) + time = Column(DateTime) + jsonvals = Column(JSON) + method = Column(String(24)) diff --git a/server/db/utils.py b/server/db/utils.py new file mode 100644 index 00000000..f19c4116 --- /dev/null +++ b/server/db/utils.py @@ -0,0 +1,10 @@ +from .database import SessionLocal + + +# Dependency to get a database session +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/server/main.py b/server/main.py new file mode 100644 index 00000000..dd20a2b8 --- /dev/null +++ b/server/main.py @@ -0,0 +1,262 @@ +import os + +from fastapi import FastAPI, Request, status, Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.orm import Session + +import datetime +import uvicorn +import logging +import redis + +from server.config import settings +from server.db.database import get_db + +from server.routes.pointing.router import router as pointing_router +from server.routes.instrument.router import router as instrument_router +from server.routes.admin.router import router as admin_router +from server.routes.candidate.router import router as candidate_router +from server.routes.doi.router import router as doi_router +from server.routes.gw_alert.router import router as gw_alert_router +from server.routes.gw_galaxy.router import router as galaxy_router +from server.routes.icecube.router import router as icecube_router +from server.routes.event.router import router as event +from server.routes.ui.router import router as ui_router + +from contextlib import asynccontextmanager +from server.utils.error_handling import ErrorDetail + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +app = FastAPI( + title=settings.APP_NAME, + description="Gravitational-Wave Treasure Map API", + version="1.0.0", + debug=settings.DEBUG, +) + +# Define API version prefix +API_V1_PREFIX = "/api/v1" + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=settings.CORS_ORIGINS, + allow_credentials=True, + allow_methods=settings.CORS_METHODS, + allow_headers=settings.CORS_HEADERS, +) + + +async def lifespan_middleware(request: Request, call_next): + try: + # Initialize response here to 
avoid UnboundLocalError + response = None + # Your middleware logic + response = await call_next(request) + return response + except Exception as e: + # Error handling logic + # Make sure response is defined even in exception paths + if response is None: + # Create a default error response + response = JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Internal server error"}, + ) + return response + + +@asynccontextmanager +async def lifespan_context(): + logger.info("Application is starting up...") + try: + yield + except Exception as e: + logger.error(f"An error occurred: {e}") + finally: + logger.info("Application is shutting down...") + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """Handle Pydantic validation errors""" + errors = [] + for error in exc.errors(): + errors.append( + ErrorDetail( + message=error["msg"], + code="validation_error", + params={ + "field": ( + ".".join(str(x) for x in error["loc"]) if error["loc"] else None + ), + "type": error["type"], + }, + ).to_dict() + ) + + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"message": "Request validation error", "errors": errors}, + ) + + +@app.exception_handler(SQLAlchemyError) +async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError): + """Handle SQLAlchemy errors""" + # Log the exception details for debugging + logger.error(f"Database error: {str(exc)}") + + # Don't expose internal details to the client + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"message": "A database error occurred"}, + ) + + +@app.exception_handler(IntegrityError) +async def integrity_exception_handler(request: Request, exc: IntegrityError): + """Handle database integrity errors""" + logger.error(f"Integrity error: {str(exc)}") + + return JSONResponse( + status_code=status.HTTP_409_CONFLICT, + content={"message": "The request conflicts with database constraints"}, + ) + + +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + """Custom handler for HTTPException to ensure consistent format""" + content = exc.detail + + # Ensure consistent format if detail is just a string + if isinstance(content, str): + content = {"message": content} + + return JSONResponse( + status_code=exc.status_code, headers=exc.headers, content=content + ) + + +@app.exception_handler(Exception) +async def general_exception_handler(request: Request, exc: Exception): + """Catch-all handler for unhandled exceptions""" + # Log the full exception details for debugging + logger.error(f"Unhandled exception: {str(exc)}", exc_info=True) + + # Don't expose internal details to the client + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"message": "An unexpected error occurred"}, + ) + + +# API health check +@app.get("/health") +async def health(): + return { + "status": "ok", + "time": datetime.datetime.now(datetime.timezone.utc).isoformat(), + } + + +@app.get("/service-status") +async def service_status(db: Session = Depends(get_db)): + """ + Detailed service status endpoint that checks database and Redis connections. 
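+
+    Illustrative response shape (values depend on the running deployment):
+
+        {
+            "database_status": "connected",
+            "redis_status": "connected",
+            "details": {"database": {...}, "redis": {...}},
+        }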
+
+    Returns:
+        Dict with status of database and Redis connections, plus detailed info
+    """
+    status = {
+        "database_status": "unknown",
+        "redis_status": "unknown",
+        "details": {"database": {}, "redis": {}},
+    }
+
+    # Check database connection with detailed info
+    try:
+        # Get connection parameters from settings
+        db_host = settings.DB_HOST
+        db_port = settings.DB_PORT
+        db_name = settings.DB_NAME
+
+        # Store connection info
+        status["details"]["database"] = {
+            "host": db_host,
+            "port": db_port,
+            "name": db_name,
+        }
+
+        # Test actual connection (textual SQL must be wrapped in text() under
+        # SQLAlchemy 2.x)
+        from sqlalchemy import text
+
+        result = db.execute(text("SELECT 1")).first()
+        if result and result[0] == 1:
+            status["database_status"] = "connected"
+        else:
+            status["database_status"] = "disconnected"
+    except Exception as e:
+        status["database_status"] = "disconnected"
+        status["details"]["database"]["error"] = str(e)
+
+    # Check Redis connection with detailed info
+    try:
+        # Get Redis connection parameters
+        redis_url = os.environ.get("REDIS_URL", "redis://redis:6379/0")
+
+        # Parse the URL for debug info
+        if redis_url.startswith("redis://"):
+            redis_host = redis_url.split("redis://")[1].split(":")[0]
+            redis_port = redis_url.split(":")[-1].split("/")[0]
+        else:
+            redis_host = "unknown"
+            redis_port = "unknown"
+
+        # Store connection info
+        status["details"]["redis"] = {
+            "host": redis_host,
+            "port": redis_port,
+            "url": redis_url,
+        }
+
+        # Test actual connection
+        try:
+            redis_client = redis.from_url(redis_url)
+            if redis_client.ping():
+                status["redis_status"] = "connected"
+            else:
+                status["redis_status"] = "disconnected"
+        except redis.exceptions.ConnectionError:
+            status["redis_status"] = "disconnected"
+            status["details"]["redis"]["error"] = "Connection refused"
+    except Exception as e:
+        status["redis_status"] = "disconnected"
+        status["details"]["redis"]["error"] = str(e)
+
+    return status
+
+
+# Include routers with the API prefix
+app.include_router(pointing_router, prefix=API_V1_PREFIX)
+app.include_router(gw_alert_router, prefix=API_V1_PREFIX)
+app.include_router(candidate_router, prefix=API_V1_PREFIX)
+app.include_router(instrument_router, prefix=API_V1_PREFIX)
+app.include_router(galaxy_router, prefix=API_V1_PREFIX)
+app.include_router(icecube_router, prefix=API_V1_PREFIX)
+app.include_router(doi_router, prefix=API_V1_PREFIX)
+app.include_router(event, prefix=API_V1_PREFIX)
+
+# Include admin router without API prefix (matches original endpoint)
+app.include_router(admin_router)
+
+# Include UI-specific routes without the API prefix
+app.include_router(ui_router)
+
+if __name__ == "__main__":
+    uvicorn.run("server.main:app", host="0.0.0.0", port=8000, reload=True)
diff --git a/server/requirements.txt b/server/requirements.txt
new file mode 100644
index 00000000..7de0e195
--- /dev/null
+++ b/server/requirements.txt
@@ -0,0 +1,39 @@
+fastapi>=0.103.0
+uvicorn[standard]>=0.23.0
+pydantic>=2.0.0
+pydantic-settings>=2.0.0
+email-validator>=2.0.0
+sqlalchemy>=2.0.0
+geoalchemy2>=0.13.0
+alembic>=1.12.0
+psycopg2-binary>=2.9.0
+python-jose[cryptography]>=3.3.0
+passlib[bcrypt]>=1.7.4
+python-multipart>=0.0.6
+jinja2>=3.1.2
+astropy>=5.3.0
+healpy>=1.16.0
+numpy>=1.24.0
+pandas>=2.0.0
+scipy>=1.10.0
+shapely>=2.0.0
+pytz>=2023.3
+requests>=2.31.0
+validators>=0.20.0
+boto3>=1.26.90
+redis>=4.6.0
+mocpy>=0.12.0
+pygcn>=0.1.8
+# Fixed version compatibility for storage libraries
+fsspec>=2021.10.1,<2023.0.0
+adlfs>=0.7.9,<2023.0.0
+s3fs==0.6.0
+python-dotenv>=1.0.0
+aiofiles>=0.8.0
+pytest>=7.3.1
+httpx>=0.24.1
+Werkzeug==2.0.3
+ephem
+plotly>=5.15.0
+kaleido>=0.2.1
+ diff --git a/server/routes/__init__.py b/server/routes/__init__.py new file mode 100644 index 00000000..8a913564 --- /dev/null +++ b/server/routes/__init__.py @@ -0,0 +1 @@ +"""Routes package for the GWTM FastAPI application.""" diff --git a/server/routes/admin/__init__.py b/server/routes/admin/__init__.py new file mode 100644 index 00000000..27163dc2 --- /dev/null +++ b/server/routes/admin/__init__.py @@ -0,0 +1 @@ +"""Admin-related routes.""" diff --git a/server/routes/admin/fixdata.py b/server/routes/admin/fixdata.py new file mode 100644 index 00000000..9768e3f9 --- /dev/null +++ b/server/routes/admin/fixdata.py @@ -0,0 +1,24 @@ +"""Fix data endpoint for admin users.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.auth.auth import verify_admin + +router = APIRouter(tags=["admin"]) + + +@router.get("/fixdata") +@router.post("/fixdata") +async def fixdata( + db: Session = Depends(get_db), + user=Depends(verify_admin), # Only admin can use this endpoint +): + """ + Fix data issues (admin only). + + This endpoint is for administrative purposes. + For backward compatibility, it just verifies admin access and returns success. + """ + return {"message": "success"} diff --git a/server/routes/admin/router.py b/server/routes/admin/router.py new file mode 100644 index 00000000..43aa7061 --- /dev/null +++ b/server/routes/admin/router.py @@ -0,0 +1,12 @@ +"""Consolidated router for all admin endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .fixdata import router as fixdata_router + +# Create the main router that includes all admin routes +router = APIRouter(tags=["admin"]) + +# Include all the individual routers +router.include_router(fixdata_router) diff --git a/server/routes/candidate/__init__.py b/server/routes/candidate/__init__.py new file mode 100644 index 00000000..4db6b86f --- /dev/null +++ b/server/routes/candidate/__init__.py @@ -0,0 +1 @@ +"""Candidate-related routes.""" diff --git a/server/routes/candidate/create_candidates.py b/server/routes/candidate/create_candidates.py new file mode 100644 index 00000000..e6c502d0 --- /dev/null +++ b/server/routes/candidate/create_candidates.py @@ -0,0 +1,86 @@ +"""Create candidates endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +import shapely.geometry + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert +from server.db.models.candidate import GWCandidate +from server.schemas.candidate import PostCandidateRequest, CandidateResponse +from server.auth.auth import get_current_user +from server.utils.error_handling import validation_exception + +router = APIRouter(tags=["candidates"]) + + +@router.post("/candidate", response_model=CandidateResponse) +async def post_gw_candidates( + post_request: PostCandidateRequest, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Post new candidate(s) for a GW event. + + This endpoint accepts either a single candidate or multiple candidates + for a gravitational wave event. + """ + + # Validate that the graceid exists + valid_alerts = ( + db.query(GWAlert).filter(GWAlert.graceid == post_request.graceid).all() + ) + if len(valid_alerts) == 0: + raise validation_exception( + "Invalid 'graceid'. 
Visit https://treasuremap.space/alert_select for valid alerts" + ) + + errors = [] + warnings = [] + valid_candidates = [] + candidate_ids = [] + + # Process single candidate + if post_request.candidate: + valid_candidates.append(post_request.candidate) + + # Process multiple candidates + elif post_request.candidates: + for candidate in post_request.candidates: + valid_candidates.append(candidate) + + for candidate in valid_candidates: + + # Validate the candidate + new_candidate = GWCandidate( + graceid=post_request.graceid, + submitterid=user.id, + candidate_name=candidate.candidate_name, + tns_name=candidate.tns_name, + tns_url=candidate.tns_url, + position=( + shapely.geometry.Point(candidate.ra, candidate.dec).wkt + if candidate.ra is not None and candidate.dec is not None + else candidate.position + ), + discovery_date=candidate.discovery_date, + discovery_magnitude=candidate.discovery_magnitude, + magnitude_central_wave=candidate.magnitude_central_wave, + magnitude_bandwidth=candidate.magnitude_bandwidth, + magnitude_unit=candidate.magnitude_unit, + magnitude_bandpass=candidate.magnitude_bandpass, + associated_galaxy=candidate.associated_galaxy, + associated_galaxy_redshift=candidate.associated_galaxy_redshift, + associated_galaxy_distance=candidate.associated_galaxy_distance, + ) + + db.add(new_candidate) + db.flush() + candidate_ids.append(new_candidate.id) + + db.commit() + + return CandidateResponse( + candidate_ids=candidate_ids, ERRORS=errors, WARNINGS=warnings + ) diff --git a/server/routes/candidate/delete_candidates.py b/server/routes/candidate/delete_candidates.py new file mode 100644 index 00000000..b581a1e6 --- /dev/null +++ b/server/routes/candidate/delete_candidates.py @@ -0,0 +1,89 @@ +"""Delete candidates endpoint.""" + +from fastapi import APIRouter, Depends, Body +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.candidate import GWCandidate +from server.schemas.candidate import DeleteCandidateParams, DeleteCandidateResponse +from server.auth.auth import get_current_user +from server.utils.error_handling import ( + not_found_exception, + permission_exception, + validation_exception, +) + +router = APIRouter(tags=["candidates"]) + + +@router.delete("/candidate", response_model=DeleteCandidateResponse) +async def delete_candidates( + delete_params: DeleteCandidateParams = Body(..., description="Fields to delete"), + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Delete candidate(s). + + Provide either: + - A single candidate ID to delete + - A list of candidate IDs to delete + + Only the owner of a candidate can delete it. + Returns information about deleted candidates and any warnings. + """ + warnings = [] + candidates_to_delete = [] + + # Handle single ID + if delete_params.id is not None: + candidate = ( + db.query(GWCandidate).filter(GWCandidate.id == delete_params.id).first() + ) + if not candidate: + raise not_found_exception( + f"No candidate found with 'id': {delete_params.id}" + ) + + if candidate.submitterid != user.id: + raise permission_exception( + "Error: Unauthorized. 
Unable to alter other user's records"
+            )
+
+        candidates_to_delete.append(candidate)
+
+    # Handle multiple IDs
+    elif delete_params.ids is not None:
+        query_ids = delete_params.ids
+        candidates = db.query(GWCandidate).filter(GWCandidate.id.in_(query_ids)).all()
+
+        if len(candidates) == 0:
+            raise not_found_exception("No candidates found with provided 'ids'")
+
+        # Filter candidates the user is allowed to delete
+        candidates_to_delete.extend([x for x in candidates if x.submitterid == user.id])
+        if len(candidates_to_delete) < len(candidates):
+            warnings.append(
+                "Some entries were not deleted. You cannot delete candidates you didn't submit"
+            )
+
+    else:
+        raise validation_exception(
+            message="Missing required parameter",
+            errors=["Either 'id' or 'ids' parameter is required"],
+        )
+
+    # Delete the candidates; del_ids must exist even if nothing is deletable
+    del_ids = []
+    if len(candidates_to_delete):
+        for ctd in candidates_to_delete:
+            del_ids.append(ctd.id)
+            db.delete(ctd)
+
+        db.commit()
+
+    return DeleteCandidateResponse(
+        message=f"Successfully deleted {len(candidates_to_delete)} candidate(s)",
+        deleted_ids=del_ids,
+        warnings=warnings,
+    )
diff --git a/server/routes/candidate/get_candidates.py b/server/routes/candidate/get_candidates.py
new file mode 100644
index 00000000..453ec8d1
--- /dev/null
+++ b/server/routes/candidate/get_candidates.py
@@ -0,0 +1,126 @@
+"""Get candidates endpoint."""
+
+from fastapi import APIRouter, Depends
+from sqlalchemy.orm import Session
+from typing import List
+from dateutil.parser import parse as date_parse
+import shapely.wkb
+
+from server.db.database import get_db
+from server.db.models.gw_alert import GWAlert
+from server.db.models.candidate import GWCandidate
+from server.schemas.candidate import CandidateSchema, GetCandidateQueryParams
+from server.auth.auth import get_current_user
+
+router = APIRouter(tags=["candidates"])
+
+
+@router.get("/candidate", response_model=List[CandidateSchema])
+async def get_candidates(
+    query_params: GetCandidateQueryParams = Depends(),
+    db: Session = Depends(get_db),
+    user=Depends(get_current_user),
+):
+    """
+    Get candidates with optional filters.
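+
+    Supported filters (all optional, combined with AND): id, ids, graceid,
+    userid, submitted_date_after/before, discovery_magnitude_gt/lt,
+    discovery_date_after/before, associated_galaxy_name,
+    associated_galaxy_redshift_gt/lt and associated_galaxy_distance_gt/lt.
+    Positions in the response are returned as WKT strings.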
+ """ + filter_conditions = [] + + if query_params.id: + filter_conditions.append(GWCandidate.id == query_params.id) + + if query_params.ids: + try: + ids_list = None + if isinstance(query_params.ids, str): + ids_list = query_params.ids.split("[")[1].split("]")[0].split(",") + elif isinstance(query_params.ids, list): + ids_list = query_params.ids + if ids_list: + filter_conditions.append(GWCandidate.id.in_(ids_list)) + except: + pass + + if query_params.graceid: + graceid = GWAlert.graceidfromalternate(query_params.graceid) + filter_conditions.append(GWCandidate.graceid == graceid) + + if query_params.userid: + filter_conditions.append(GWCandidate.submitterid == query_params.userid) + + if query_params.submitted_date_after: + try: + parsed_date_after = date_parse(query_params.submitted_date_after) + filter_conditions.append(GWCandidate.datecreated >= parsed_date_after) + except: + pass + + if query_params.submitted_date_before: + try: + parsed_date_before = date_parse(query_params.submitted_date_before) + filter_conditions.append(GWCandidate.datecreated <= parsed_date_before) + except: + pass + + if query_params.discovery_magnitude_gt is not None: + filter_conditions.append( + GWCandidate.discovery_magnitude >= query_params.discovery_magnitude_gt + ) + + if query_params.discovery_magnitude_lt is not None: + filter_conditions.append( + GWCandidate.discovery_magnitude <= query_params.discovery_magnitude_lt + ) + + if query_params.discovery_date_after: + try: + parsed_date_after = date_parse(query_params.discovery_date_after) + filter_conditions.append(GWCandidate.discovery_date >= parsed_date_after) + except: + pass + + if query_params.discovery_date_before: + try: + parsed_date_before = date_parse(query_params.discovery_date_before) + filter_conditions.append(GWCandidate.discovery_date <= parsed_date_before) + except: + pass + + if query_params.associated_galaxy_name: + filter_conditions.append( + GWCandidate.associated_galaxy.contains(query_params.associated_galaxy_name) + ) + + if query_params.associated_galaxy_redshift_gt is not None: + filter_conditions.append( + GWCandidate.associated_galaxy_redshift + >= query_params.associated_galaxy_redshift_gt + ) + + if query_params.associated_galaxy_redshift_lt is not None: + filter_conditions.append( + GWCandidate.associated_galaxy_redshift + <= query_params.associated_galaxy_redshift_lt + ) + + if query_params.associated_galaxy_distance_gt is not None: + filter_conditions.append( + GWCandidate.associated_galaxy_distance + >= query_params.associated_galaxy_distance_gt + ) + + if query_params.associated_galaxy_distance_lt is not None: + filter_conditions.append( + GWCandidate.associated_galaxy_distance + <= query_params.associated_galaxy_distance_lt + ) + + candidates = db.query(GWCandidate).filter(*filter_conditions).all() + + for candidate in candidates: + # Convert position from WKB to WKT + if candidate.position: + position = shapely.wkb.loads(bytes(candidate.position.data)) + candidate.position = str(position) + + return candidates diff --git a/server/routes/candidate/router.py b/server/routes/candidate/router.py new file mode 100644 index 00000000..9b4dd80b --- /dev/null +++ b/server/routes/candidate/router.py @@ -0,0 +1,18 @@ +"""Consolidated router for all candidate endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .get_candidates import router as get_candidates_router +from .create_candidates import router as create_candidates_router +from .update_candidate import router as 
update_candidate_router
+from .delete_candidates import router as delete_candidates_router
+
+# Create the main router that includes all candidate routes
+router = APIRouter(tags=["candidates"])
+
+# Include all the individual routers
+router.include_router(get_candidates_router)
+router.include_router(create_candidates_router)
+router.include_router(update_candidate_router)
+router.include_router(delete_candidates_router)
diff --git a/server/routes/candidate/update_candidate.py b/server/routes/candidate/update_candidate.py
new file mode 100644
index 00000000..5a9163bf
--- /dev/null
+++ b/server/routes/candidate/update_candidate.py
@@ -0,0 +1,94 @@
+"""Update candidate endpoint."""
+
+from fastapi import APIRouter, Depends, Body
+from sqlalchemy.orm import Session
+import shapely.geometry
+import shapely.wkb
+
+from server.db.database import get_db
+from server.db.models.candidate import GWCandidate
+from server.schemas.candidate import PutCandidateRequest, CandidateUpdateField
+from server.auth.auth import get_current_user
+from server.utils.error_handling import not_found_exception, permission_exception
+
+router = APIRouter(tags=["candidates"])
+
+
+@router.put("/candidate", response_model=PutCandidateRequest)
+async def update_candidate(
+    request: PutCandidateRequest = Body(..., description="Fields to update"),
+    db: Session = Depends(get_db),
+    user=Depends(get_current_user),
+):
+    """
+    Update an existing candidate.
+
+    Only the owner of the candidate can update it.
+    Returns either a success response with the updated candidate or a failure response with errors.
+    """
+    # Find the candidate
+    candidate = db.query(GWCandidate).filter(GWCandidate.id == request.id).first()
+
+    if not candidate:
+        raise not_found_exception(f"No candidate found with id: {request.id}")
+
+    # Check permissions
+    if candidate.submitterid != user.id:
+        raise permission_exception("Unable to alter other user's candidate records")
+
+    update = request.candidate.dict(exclude_unset=True)
+    # Copy values from the Pydantic schema to the SQLAlchemy model. ra and dec
+    # are read-only hybrid properties derived from the position geography, so
+    # they are skipped here and folded into position below.
+    for key, value in update.items():
+        if key in ("ra", "dec"):
+            continue
+        if hasattr(candidate, key):
+            setattr(candidate, key, value)
+
+    position = candidate.position
+
+    # Update ra and/or dec in the wkt string position; use .get() because with
+    # exclude_unset=True either key may be absent from the payload
+    new_ra = update.get("ra")
+    new_dec = update.get("dec")
+    if new_ra is not None and new_dec is not None:
+        position = shapely.geometry.Point(new_ra, new_dec).wkt
+        candidate.position = position
+    elif new_ra is not None:
+        position = shapely.geometry.Point(new_ra, candidate.dec).wkt
+        candidate.position = position
+    elif new_dec is not None:
+        position = shapely.geometry.Point(candidate.ra, new_dec).wkt
+        candidate.position = position
+
+    db.commit()
+    db.refresh(candidate)
+
+    # copy the updated candidate to CandidateUpdateField object
+    candidate_dict = {
+        "graceid": candidate.graceid,
+        "submitterid": candidate.submitterid,
+        "candidate_name": candidate.candidate_name,
+        "datecreated": candidate.datecreated,
+        "tns_name": candidate.tns_name,
+        "tns_url": candidate.tns_url,
+        # convert position from wkb to wkt and then to string
+        "position": str(shapely.wkb.loads(bytes(candidate.position.data))),
+        # convert discovery_date to string
+        "discovery_date": (
+            candidate.discovery_date.isoformat() if candidate.discovery_date else None
+        ),
+        "discovery_magnitude": candidate.discovery_magnitude,
+        "magnitude_central_wave": candidate.magnitude_central_wave,
+        "magnitude_bandwidth": candidate.magnitude_bandwidth,
+        "magnitude_unit": candidate.magnitude_unit,
+ "magnitude_bandpass": candidate.magnitude_bandpass, + "associated_galaxy": candidate.associated_galaxy, + "associated_galaxy_redshift": candidate.associated_galaxy_redshift, + "associated_galaxy_distance": candidate.associated_galaxy_distance, + } + # Convert to CandidateUpdateField instance + updated_candidate = CandidateUpdateField(**candidate_dict) + + return PutCandidateRequest(id=request.id, candidate=updated_candidate) diff --git a/server/routes/doi/__init__.py b/server/routes/doi/__init__.py new file mode 100644 index 00000000..004a28b9 --- /dev/null +++ b/server/routes/doi/__init__.py @@ -0,0 +1 @@ +"""DOI-related routes.""" diff --git a/server/routes/doi/get_author_groups.py b/server/routes/doi/get_author_groups.py new file mode 100644 index 00000000..85638623 --- /dev/null +++ b/server/routes/doi/get_author_groups.py @@ -0,0 +1,26 @@ +"""Get DOI author groups endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from typing import List + +from server.db.database import get_db +from server.db.models.doi_author import DOIAuthorGroup +from server.auth.auth import get_current_user +from server.schemas.doi import DOIAuthorGroupSchema + +router = APIRouter(tags=["DOI"]) + + +@router.get("/doi_author_groups", response_model=List[DOIAuthorGroupSchema]) +async def get_doi_author_groups( + db: Session = Depends(get_db), user=Depends(get_current_user) +): + """ + Get all DOI author groups for the current user. + + Returns: + - List of DOI author groups + """ + groups = db.query(DOIAuthorGroup).filter(DOIAuthorGroup.userid == user.id).all() + return groups diff --git a/server/routes/doi/get_authors.py b/server/routes/doi/get_authors.py new file mode 100644 index 00000000..d15136fc --- /dev/null +++ b/server/routes/doi/get_authors.py @@ -0,0 +1,42 @@ +"""Get DOI authors endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from typing import List + +from server.db.database import get_db +from server.db.models.doi_author import DOIAuthor, DOIAuthorGroup +from server.auth.auth import get_current_user +from server.schemas.doi import DOIAuthorSchema +from server.utils.error_handling import permission_exception + +router = APIRouter(tags=["DOI"]) + + +@router.get("/doi_authors/{group_id}", response_model=List[DOIAuthorSchema]) +async def get_doi_authors( + group_id: int, db: Session = Depends(get_db), user=Depends(get_current_user) +): + """ + Get all DOI authors for a specific group. 
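+
+    Authors are only returned when the requested group belongs to the
+    authenticated user; otherwise a permission error is raised.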
+ + Parameters: + - group_id: DOI author group ID + + Returns: + - List of DOI authors + """ + # First check if the group belongs to the user + group = ( + db.query(DOIAuthorGroup) + .filter(DOIAuthorGroup.id == group_id, DOIAuthorGroup.userid == user.id) + .first() + ) + + if not group: + raise permission_exception( + "You don't have permission to access this DOI author group" + ) + + authors = db.query(DOIAuthor).filter(DOIAuthor.author_groupid == group_id).all() + return authors diff --git a/server/routes/doi/get_doi_pointings.py b/server/routes/doi/get_doi_pointings.py new file mode 100644 index 00000000..c33f9f8a --- /dev/null +++ b/server/routes/doi/get_doi_pointings.py @@ -0,0 +1,73 @@ +"""Get DOI pointings endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.db.models.instrument import Instrument +from server.db.models.pointing_event import PointingEvent +from server.auth.auth import get_current_user +from server.schemas.doi import DOIPointingInfo, DOIPointingsResponse + +router = APIRouter(tags=["DOI"]) + + +@router.get("/doi_pointings", response_model=DOIPointingsResponse) +async def get_doi_pointings( + db: Session = Depends(get_db), user=Depends(get_current_user) +): + """ + Get all pointings with DOIs requested by the current user. + + Returns: + - List of pointings with DOI information + """ + # Query pointings with DOIs, ensuring we only get pointings that actually have DOI information + pointings = ( + db.query(Pointing) + .filter( + Pointing.submitterid == user.id, + Pointing.doi_id.isnot( + None + ), # Changed from != None to .isnot(None) for proper SQLAlchemy syntax + Pointing.doi_url.isnot(None), # Also check that doi_url is not None + ) + .all() + ) + + result = [] + for pointing in pointings: + # Get event information - need to join with PointingEvent to get graceid + pointing_events = ( + db.query(PointingEvent) + .filter(PointingEvent.pointingid == pointing.id) + .all() + ) + graceid = pointing_events[0].graceid if pointing_events else "Unknown" + + # Get instrument information + instrument = ( + db.query(Instrument).filter(Instrument.id == pointing.instrumentid).first() + ) + instrument_name = instrument.instrument_name if instrument else "Unknown" + + # Convert status enum to string if needed + status_str = ( + pointing.status.name + if hasattr(pointing.status, "name") + else str(pointing.status) + ) + + result.append( + DOIPointingInfo( + id=pointing.id, + graceid=graceid, + instrument_name=instrument_name, + status=status_str, + doi_url=pointing.doi_url, + doi_id=pointing.doi_id, + ) + ) + + return DOIPointingsResponse(pointings=result) diff --git a/server/routes/doi/router.py b/server/routes/doi/router.py new file mode 100644 index 00000000..30e3e2ef --- /dev/null +++ b/server/routes/doi/router.py @@ -0,0 +1,16 @@ +"""Consolidated router for all DOI endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .get_doi_pointings import router as get_doi_pointings_router +from .get_author_groups import router as get_author_groups_router +from .get_authors import router as get_authors_router + +# Create the main router that includes all DOI routes +router = APIRouter(tags=["DOI"]) + +# Include all the individual routers +router.include_router(get_doi_pointings_router) +router.include_router(get_author_groups_router) +router.include_router(get_authors_router) diff --git 
a/server/routes/event/__init__.py b/server/routes/event/__init__.py new file mode 100644 index 00000000..229c16dc --- /dev/null +++ b/server/routes/event/__init__.py @@ -0,0 +1 @@ +"""Event route module.""" diff --git a/server/routes/event/create_candidate_event.py b/server/routes/event/create_candidate_event.py new file mode 100644 index 00000000..2cf56129 --- /dev/null +++ b/server/routes/event/create_candidate_event.py @@ -0,0 +1,39 @@ +"""Create candidate event endpoint.""" + +from datetime import datetime +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.candidate import GWCandidate +from server.schemas.candidate import GWCandidateCreate +from server.core.enums.depthunit import DepthUnit as depth_unit_enum +from server.auth.auth import get_current_user + +router = APIRouter(tags=["Events"]) + + +@router.post("/candidate/event") +async def create_candidate_event( + candidate: GWCandidateCreate, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """Create a new candidate event.""" + # Create a POINT WKT string for position + position = f"POINT({candidate.ra} {candidate.dec})" + + new_candidate = GWCandidate( + candidate_name=candidate.candidate_name, + graceid=candidate.graceid, # Required field for GWCandidate + position=position, # Set position from ra and dec + submitterid=current_user.id, + magnitude_unit=depth_unit_enum.ab_mag, # Default required field + datecreated=datetime.now(), + ) + + db.add(new_candidate) + db.commit() + db.refresh(new_candidate) + + return {"message": "Candidate created successfully", "id": new_candidate.id} diff --git a/server/routes/event/delete_candidate_event.py b/server/routes/event/delete_candidate_event.py new file mode 100644 index 00000000..d9b94182 --- /dev/null +++ b/server/routes/event/delete_candidate_event.py @@ -0,0 +1,34 @@ +"""Delete candidate event endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.candidate import GWCandidate +from server.utils.error_handling import not_found_exception, permission_exception +from server.auth.auth import get_current_user +from .utils import is_admin + +router = APIRouter(tags=["Events"]) + + +@router.delete("/candidate/event/{candidate_id}") +async def delete_candidate_event( + candidate_id: int, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """Delete a candidate event.""" + db_candidate = db.query(GWCandidate).filter(GWCandidate.id == candidate_id).first() + + if not db_candidate: + raise not_found_exception("Candidate not found") + + # Check if user is the owner or an admin + if db_candidate.submitterid != current_user.id and not is_admin(current_user, db): + raise permission_exception("Not authorized to delete this candidate") + + db.delete(db_candidate) + db.commit() + + return {"message": "Candidate deleted successfully"} diff --git a/server/routes/event/get_candidate_events.py b/server/routes/event/get_candidate_events.py new file mode 100644 index 00000000..beb1605b --- /dev/null +++ b/server/routes/event/get_candidate_events.py @@ -0,0 +1,32 @@ +"""Get candidate events endpoint.""" + +from typing import List, Optional +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.candidate import GWCandidate +from server.schemas.candidate import GWCandidateSchema 
+from server.auth.auth import get_current_user + +router = APIRouter(tags=["Events"]) + + +@router.get("/candidate/event", response_model=List[GWCandidateSchema]) +async def get_candidate_events( + id: Optional[int] = Query(None, description="Filter by candidate ID"), + user_id: Optional[int] = Query(None, description="Filter by user ID"), + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """Get list of candidate events, optionally filtered by user or ID.""" + query = db.query(GWCandidate) + + if id: + query = query.filter(GWCandidate.id == id) + + if user_id: + query = query.filter(GWCandidate.submitterid == user_id) + + candidates = query.all() + return candidates diff --git a/server/routes/event/router.py b/server/routes/event/router.py new file mode 100644 index 00000000..09aa4269 --- /dev/null +++ b/server/routes/event/router.py @@ -0,0 +1,18 @@ +"""Consolidated router for all event endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .get_candidate_events import router as get_candidate_events_router +from .create_candidate_event import router as create_candidate_event_router +from .update_candidate_event import router as update_candidate_event_router +from .delete_candidate_event import router as delete_candidate_event_router + +# Create the main router that includes all event routes +router = APIRouter(tags=["Events"]) + +# Include all the individual routers +router.include_router(get_candidate_events_router) +router.include_router(create_candidate_event_router) +router.include_router(update_candidate_event_router) +router.include_router(delete_candidate_event_router) diff --git a/server/routes/event/update_candidate_event.py b/server/routes/event/update_candidate_event.py new file mode 100644 index 00000000..aba0cd7f --- /dev/null +++ b/server/routes/event/update_candidate_event.py @@ -0,0 +1,44 @@ +"""Update candidate event endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.candidate import GWCandidate +from server.schemas.candidate import GWCandidateSchema +from server.utils.error_handling import not_found_exception, permission_exception +from server.auth.auth import get_current_user +from .utils import is_admin + +router = APIRouter(tags=["Events"]) + + +@router.put("/candidate/event/{candidate_id}") +async def update_candidate_event( + candidate_id: int, + candidate: GWCandidateSchema, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """Update an existing candidate event.""" + db_candidate = db.query(GWCandidate).filter(GWCandidate.id == candidate_id).first() + + if not db_candidate: + raise not_found_exception("Candidate not found") + + # Check if user is the owner or an admin + if db_candidate.submitterid != current_user.id and not is_admin(current_user, db): + raise permission_exception("Not authorized to update this candidate") + + # Update fields + db_candidate.candidate_name = candidate.candidate_name + + # Update position from ra and dec if provided + if candidate.ra is not None and candidate.dec is not None: + position = f"POINT({candidate.ra} {candidate.dec})" + db_candidate.position = position + + db.commit() + db.refresh(db_candidate) + + return {"message": "Candidate updated successfully"} diff --git a/server/routes/event/utils.py b/server/routes/event/utils.py new file mode 100644 index 00000000..51c88ebd --- /dev/null +++ b/server/routes/event/utils.py @@ -0,0 
+1,19 @@ +"""Utility functions for event routes.""" + +from sqlalchemy.orm import Session +from server.db.models.users import UserGroups, Groups + + +def is_admin(user, db: Session) -> bool: + """Check if the user is an admin.""" + admin_group = db.query(Groups).filter(Groups.name == "admin").first() + if not admin_group: + return False + + user_group = ( + db.query(UserGroups) + .filter(UserGroups.userid == user.id, UserGroups.groupid == admin_group.id) + .first() + ) + + return user_group is not None diff --git a/server/routes/gw_alert/__init__.py b/server/routes/gw_alert/__init__.py new file mode 100644 index 00000000..2129d4e7 --- /dev/null +++ b/server/routes/gw_alert/__init__.py @@ -0,0 +1 @@ +"""GW alert route module.""" diff --git a/server/routes/gw_alert/delete_test_alerts.py b/server/routes/gw_alert/delete_test_alerts.py new file mode 100644 index 00000000..549b5591 --- /dev/null +++ b/server/routes/gw_alert/delete_test_alerts.py @@ -0,0 +1,139 @@ +"""Delete test alerts endpoint.""" + +from datetime import datetime, timedelta +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.auth.auth import verify_admin +from server.utils.gwtm_io import list_gwtm_bucket, delete_gwtm_files +from server.config import Settings as settings +from server.utils.function import by_chunk +from server.db.models.gw_alert import GWAlert +from server.db.models.pointing import Pointing +from server.db.models.pointing_event import PointingEvent +from server.db.models.gw_galaxy import GWGalaxyEntry +from server.db.models.candidate import GWCandidate + +router = APIRouter(tags=["gw_alerts"]) + + +@router.post("/del_test_alerts") +async def del_test_alerts( + db: Session = Depends(get_db), + user=Depends(verify_admin), # Only admin can delete test alerts +): + """ + Delete test alerts (admin only). + + This endpoint removes test alerts from the database and related storage. 
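+
+    Only alerts with role == "test" are removed, excluding the canonical
+    example alert MS181101ab and graceids matching the date-based MS<yymmdd>
+    IDs for yesterday, today and tomorrow. Associated pointings, pointing
+    events, galaxy lists and entries, candidates and storage-bucket files are
+    deleted along with the alerts.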
+ """ + # Set up filter conditions + filter = [] + testids = [] + alert_to_keep = "MS181101ab" + filter.append(~GWAlert.graceid.contains(alert_to_keep)) + + # Add date-based test IDs to exclusion list + for td in [-1, 0, 1]: + dd = datetime.now() + timedelta(days=td) + yy = str(dd.year)[2:4] + mm = dd.month if dd.month >= 10 else f"0{dd.month}" + dd = dd.day if dd.day >= 10 else f"0{dd.day}" + graceidlike = f"MS{yy}{mm}{dd}" + + testids.append(graceidlike) + filter.append(~GWAlert.graceid.contains(graceidlike)) + + # Add the alert to keep to both the filter and the testids list + filter.append(~GWAlert.graceid.contains(alert_to_keep)) + testids.append(alert_to_keep) + + # Only delete test alerts + filter.append(GWAlert.role == "test") + + # Query for all test alerts that aren't like the ones we want to keep + gwalerts = db.query(GWAlert).filter(*filter).all() + gids_to_rm = [x.graceid for x in gwalerts] + + # Query for pointings and pointing events from graceids + pointing_events = ( + db.query(PointingEvent).filter(PointingEvent.graceid.in_(gids_to_rm)).all() + ) + pointing_ids = [x.pointingid for x in pointing_events] + pointings = db.query(Pointing).filter(Pointing.id.in_(pointing_ids)).all() + + # Query for galaxy lists and galaxy list entries from graceids + try: + from server.db.models.gw_alert import GWGalaxyList + + galaxylists = ( + db.query(GWGalaxyList).filter(GWGalaxyList.graceid.in_(gids_to_rm)).all() + ) + galaxylist_ids = [x.id for x in galaxylists] + galaxyentries = ( + db.query(GWGalaxyEntry) + .filter(GWGalaxyEntry.listid.in_(galaxylist_ids)) + .all() + ) + except ImportError: + # If the model isn't available, create empty lists + galaxylists = [] + galaxyentries = [] + + # Query for candidates to delete + candidates = db.query(GWCandidate).filter(GWCandidate.graceid.in_(gids_to_rm)).all() + + # Delete in order (to avoid foreign key constraints) + if len(candidates) > 0: + for c in candidates: + db.delete(c) + + if len(galaxyentries) > 0: + for ge in galaxyentries: + db.delete(ge) + + if len(galaxylists) > 0: + for gl in galaxylists: + db.delete(gl) + + if len(pointings) > 0: + for p in pointings: + db.delete(p) + + if len(pointing_events) > 0: + for pe in pointing_events: + db.delete(pe) + + if len(gwalerts) > 0: + for ga in gwalerts: + db.delete(ga) + + # Delete files from storage + try: + objects = list_gwtm_bucket( + container="test", source=settings.STORAGE_BUCKET_SOURCE, config=settings + ) + objects_to_delete = [ + o + for o in objects + if not any(t in o for t in testids) + and "alert.json" not in o + and o != "test/" + ] + + if len(objects_to_delete): + total = 0 + for items in by_chunk(objects_to_delete, 1000): + total += len(items) + delete_gwtm_files( + keys=items, source=settings.STORAGE_BUCKET_SOURCE, config=settings + ) + except Exception as e: + # Log the error but continue with the database changes + print(f"Error deleting files: {str(e)}") + + # Commit all changes + db.commit() + + return {"message": "Successfully deleted test alerts and associated data"} diff --git a/server/routes/gw_alert/get_contour.py b/server/routes/gw_alert/get_contour.py new file mode 100644 index 00000000..3a1f0ae3 --- /dev/null +++ b/server/routes/gw_alert/get_contour.py @@ -0,0 +1,64 @@ +"""Get GW contour endpoint.""" + +from fastapi import APIRouter, Depends, Query +from fastapi.openapi.models import Response +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert +from server.auth.auth import get_current_user 
+from server.utils.error_handling import not_found_exception
+from server.utils.gwtm_io import download_gwtm_file
+from server.config import Settings as settings
+
+router = APIRouter(tags=["gw_alerts"])
+
+
+@router.get("/gw_contour")
+async def get_gw_contour(
+    graceid: str = Query(..., description="Grace ID of the GW event"),
+    db: Session = Depends(get_db),
+    user=Depends(get_current_user),
+):
+    """
+    Get the contour for a GW alert.
+
+    Parameters:
+    - graceid: The Grace ID of the GW event
+
+    Returns the contour JSON file
+    """
+    # Normalize the graceid
+    graceid = GWAlert.graceidfromalternate(graceid)
+
+    # Get the latest alert for this graceid
+    alerts = (
+        db.query(GWAlert)
+        .filter(GWAlert.graceid == graceid)
+        .order_by(GWAlert.datecreated.desc())
+        .all()
+    )
+
+    if not alerts:
+        raise not_found_exception(f"No alert found with graceid: {graceid}")
+
+    # Extract alert info
+    alert = alerts[0]
+    alert_types = [x.alert_type for x in alerts]
+    latest_alert_type = alert.alert_type
+    num = len([x for x in alert_types if x == latest_alert_type]) - 1
+    alert_type = latest_alert_type if num < 1 else latest_alert_type + str(num)
+
+    # Build path info
+    path_info = f"{graceid}-{alert_type}"
+    contour_path = f"fit/{path_info}-contours-smooth.json"
+
+    try:
+        file_content = download_gwtm_file(
+            filename=contour_path,
+            source=settings.STORAGE_BUCKET_SOURCE,
+            config=settings,
+        )
+        return Response(content=file_content, media_type="application/json")
+    except Exception as e:
+        raise not_found_exception(f"Error in retrieving Contour file: {contour_path}")
diff --git a/server/routes/gw_alert/get_grb_moc.py b/server/routes/gw_alert/get_grb_moc.py
new file mode 100644
index 00000000..32f1bc7d
--- /dev/null
+++ b/server/routes/gw_alert/get_grb_moc.py
@@ -0,0 +1,57 @@
+"""Get GRB MOC file endpoint."""
+
+from fastapi import APIRouter, Depends, Query
+from fastapi.responses import Response
+from sqlalchemy.orm import Session
+
+from server.db.database import get_db
+from server.db.models.gw_alert import GWAlert
+from server.auth.auth import get_current_user
+from server.utils.error_handling import not_found_exception, validation_exception
+from server.utils.gwtm_io import download_gwtm_file
+from server.config import Settings as settings
+
+router = APIRouter(tags=["gw_alerts"])
+
+
+@router.get("/grb_moc_file")
+async def get_grbmoc(
+    graceid: str = Query(..., description="Grace ID of the GW event"),
+    instrument: str = Query(..., description="Instrument name (gbm, lat, or bat)"),
+    db: Session = Depends(get_db),
+    user=Depends(get_current_user),
+):
+    """
+    Get the GRB MOC file for a GW alert.
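+
+    Example (illustrative sketch; the base URL and auth header are assumptions,
+    not taken from this module):
+
+        import requests
+        resp = requests.get(
+            "http://localhost:8000/grb_moc_file",
+            params={"graceid": "MS181101ab", "instrument": "gbm"},
+            headers={"Authorization": "Bearer <token>"},  # placeholder credential
+        )
+        moc = resp.json()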
+ + Parameters: + - graceid: The Grace ID of the GW event + - instrument: Instrument name (gbm, lat, or bat) + + Returns the MOC file + """ + # Normalize the graceid + graceid = GWAlert.graceidfromalternate(graceid) + + # Validate instrument + instrument = instrument.lower() + if instrument not in ["gbm", "lat", "bat"]: + raise validation_exception("Valid instruments are in ['gbm', 'lat', 'bat']") + + # Map instrument names to their full names + instrument_dictionary = {"gbm": "Fermi", "lat": "LAT", "bat": "BAT"} + + # Build path + moc_filepath = f"fit/{graceid}-{instrument_dictionary[instrument]}.json" + + try: + file_content = download_gwtm_file( + filename=moc_filepath, + source=settings.STORAGE_BUCKET_SOURCE, + config=settings, + ) + return Response(content=file_content, media_type="application/json") + except Exception as e: + raise not_found_exception( + f"MOC file for GW-Alert: '{graceid}' and instrument: '{instrument}' does not exist!" + ) diff --git a/server/routes/gw_alert/get_skymap.py b/server/routes/gw_alert/get_skymap.py new file mode 100644 index 00000000..bef81b98 --- /dev/null +++ b/server/routes/gw_alert/get_skymap.py @@ -0,0 +1,88 @@ +"""Get GW skymap endpoint.""" + +import io +from fastapi import APIRouter, Depends, Query +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert +from server.auth.auth import get_current_user +from server.utils.error_handling import not_found_exception +from server.utils.gwtm_io import download_gwtm_file +from server.config import Settings as settings + +router = APIRouter(tags=["gw_alerts"]) + + +@router.get( + "/gw_skymap", + response_description="FITS file containing the gravitational wave skymap", + responses={ + 200: { + "content": {"application/fits": {}}, + "description": "The skymap FITS file for the specified gravitational wave event", + }, + 404: {"description": "Skymap not found for the specified event"}, + }, +) +async def get_gw_skymap( + graceid: str = Query(..., description="Grace ID of the GW event"), + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Get the skymap FITS file for a gravitational wave alert. 
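+
+    Example (illustrative sketch; the base URL and auth header are assumptions):
+
+        import requests
+        resp = requests.get(
+            "http://localhost:8000/gw_skymap",
+            params={"graceid": "MS181101ab"},
+            headers={"Authorization": "Bearer <token>"},  # placeholder credential
+        )
+        with open("MS181101ab_skymap.fits.gz", "wb") as f:
+            f.write(resp.content)  # raw FITS bytes from the StreamingResponse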
+ + Parameters: + - graceid: The Grace ID of the GW event + + Returns: + - A binary response containing the FITS file with the skymap data + """ + # Normalize the graceid + graceid = GWAlert.graceidfromalternate(graceid) + + # Get the latest alert for this graceid + alerts = ( + db.query(GWAlert) + .filter(GWAlert.graceid == graceid) + .order_by(GWAlert.datecreated.desc()) + .all() + ) + + if not alerts: + raise not_found_exception(f"No alert found with graceid: {graceid}") + + # Extract alert info + alert = alerts[0] + alert_types = [x.alert_type for x in alerts] + latest_alert_type = alert.alert_type + num = len([x for x in alert_types if x == latest_alert_type]) - 1 + alert_type = latest_alert_type if num < 1 else latest_alert_type + str(num) + + # Build path info + path_info = f"{graceid}-{alert_type}" + skymap_path = f"fit/{path_info}.fits.gz" + + # Download and return the file + try: + file_content = download_gwtm_file( + filename=skymap_path, + source=settings.STORAGE_BUCKET_SOURCE, + config=settings, + decode=False, + ) + + # Create a streaming response with the binary content + filename = f"{graceid}_skymap.fits.gz" + return StreamingResponse( + io.BytesIO(file_content), + media_type="application/fits", + headers={ + "Content-Disposition": f"attachment; filename={filename}", + "Content-Type": "application/fits", + }, + ) + except Exception as e: + raise not_found_exception(f"Error in retrieving skymap file: {skymap_path}") diff --git a/server/routes/gw_alert/post_alert.py b/server/routes/gw_alert/post_alert.py new file mode 100644 index 00000000..af805c68 --- /dev/null +++ b/server/routes/gw_alert/post_alert.py @@ -0,0 +1,33 @@ +"""Post GW alert endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert +from server.schemas.gw_alert import GWAlertSchema +from server.auth.auth import verify_admin + +router = APIRouter(tags=["gw_alerts"]) + + +@router.post("/post_alert", response_model=GWAlertSchema) +async def post_alert( + alert_data: GWAlertSchema, + db: Session = Depends(get_db), + user=Depends(verify_admin), # Only admin can post alerts +): + """ + Post a new GW alert (admin only). + + Parameters: + - Alert data in the request body + + Returns the created GW Alert object + """ + alert_instance = GWAlert(**alert_data.dict()) + db.add(alert_instance) + db.commit() + db.refresh(alert_instance) + + return alert_instance diff --git a/server/routes/gw_alert/query_alerts.py b/server/routes/gw_alert/query_alerts.py new file mode 100644 index 00000000..9ed52bf4 --- /dev/null +++ b/server/routes/gw_alert/query_alerts.py @@ -0,0 +1,52 @@ +"""Query GW alerts endpoint.""" + +from typing import List, Optional +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert +from server.schemas.gw_alert import GWAlertSchema +from server.auth.auth import get_current_user + +router = APIRouter(tags=["gw_alerts"]) + + +@router.get("/query_alerts", response_model=List[GWAlertSchema]) +async def query_alerts( + graceid: Optional[str] = None, + alert_type: Optional[str] = None, + role: Optional[str] = None, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Query GW alerts with optional filters. 
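+
+    Example (illustrative sketch; the base URL and auth header are assumptions):
+
+        import requests
+        resp = requests.get(
+            "http://localhost:8000/query_alerts",
+            params={"graceid": "MS181101ab", "role": "test"},
+            headers={"Authorization": "Bearer <token>"},  # placeholder credential
+        )
+        for alert in resp.json():
+            print(alert["graceid"], alert["alert_type"])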
+ + Parameters: + - graceid: Filter by Grace ID + - alert_type: Filter by alert type + + Returns a list of GW Alert objects + """ + filter_conditions = [] + + if graceid: + # Handle alternative GraceID format if needed + # Implementation will depend on the graceidfromalternate function + filter_conditions.append(GWAlert.graceid == graceid) + + if alert_type: + filter_conditions.append(GWAlert.alert_type == alert_type) + + if role: + filter_conditions.append(GWAlert.role == role) + + alerts = ( + db.query(GWAlert) + .filter(*filter_conditions) + .order_by(GWAlert.datecreated.desc()) + .all() + ) + + return alerts diff --git a/server/routes/gw_alert/router.py b/server/routes/gw_alert/router.py new file mode 100644 index 00000000..4cdb2363 --- /dev/null +++ b/server/routes/gw_alert/router.py @@ -0,0 +1,22 @@ +"""Consolidated router for all GW alert endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .query_alerts import router as query_alerts_router +from .post_alert import router as post_alert_router +from .get_skymap import router as get_skymap_router +from .get_contour import router as get_contour_router +from .get_grb_moc import router as get_grb_moc_router +from .delete_test_alerts import router as delete_test_alerts_router + +# Create the main router that includes all GW alert routes +router = APIRouter(tags=["gw_alerts"]) + +# Include all the individual routers +router.include_router(query_alerts_router) +router.include_router(post_alert_router) +router.include_router(get_skymap_router) +router.include_router(get_contour_router) +router.include_router(get_grb_moc_router) +router.include_router(delete_test_alerts_router) diff --git a/server/routes/gw_galaxy/__init__.py b/server/routes/gw_galaxy/__init__.py new file mode 100644 index 00000000..fc0fdc80 --- /dev/null +++ b/server/routes/gw_galaxy/__init__.py @@ -0,0 +1 @@ +"""GW galaxy route module.""" diff --git a/server/routes/gw_galaxy/get_event_galaxies.py b/server/routes/gw_galaxy/get_event_galaxies.py new file mode 100644 index 00000000..5c7bb784 --- /dev/null +++ b/server/routes/gw_galaxy/get_event_galaxies.py @@ -0,0 +1,107 @@ +"""Get event galaxies endpoint.""" + +from typing import List, Optional +import datetime +from dateutil.parser import parse as date_parse +from fastapi import APIRouter, Depends, Query +from geoalchemy2.shape import to_shape +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert +from server.db.models.gw_galaxy import GWGalaxyEntry, GWGalaxyList +from server.auth.auth import get_current_user +from server.schemas.gw_galaxy import GWGalaxyEntrySchema +from server.utils.error_handling import validation_exception + +router = APIRouter(tags=["galaxies"]) + + +@router.get("/event_galaxies", response_model=List[GWGalaxyEntrySchema]) +async def get_event_galaxies( + graceid: str = Query(..., description="Grace ID of the GW event"), + timesent_stamp: Optional[str] = None, + listid: Optional[int] = None, + groupname: Optional[str] = None, + score_gt: Optional[float] = None, + score_lt: Optional[float] = None, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Get galaxies associated with a GW event. 
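+
+    Example (illustrative sketch; the base URL and auth header are assumptions):
+
+        import requests
+        resp = requests.get(
+            "http://localhost:8000/event_galaxies",
+            params={"graceid": "MS181101ab", "score_gt": 0.5},
+            headers={"Authorization": "Bearer <token>"},  # placeholder credential
+        )
+        galaxies = resp.json()  # list of GWGalaxyEntrySchema dicts with WKT positions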
+ """ + filter_conditions = [GWGalaxyEntry.listid == GWGalaxyList.id] + + # Normalize the graceid + graceid = GWAlert.graceidfromalternate(graceid) + filter_conditions.append(GWGalaxyList.graceid == graceid) + + if timesent_stamp: + try: + time = date_parse(timesent_stamp) + except ValueError: + raise validation_exception( + message="Error parsing date", + errors=[ + f"Timestamp should be in %Y-%m-%dT%H:%M:%S.%f format. e.g. 2019-05-01T12:00:00.00" + ], + ) + + # Find the alert with the given time and graceid + alert = ( + db.query(GWAlert) + .filter( + GWAlert.timesent < time + datetime.timedelta(seconds=15), + GWAlert.timesent > time - datetime.timedelta(seconds=15), + GWAlert.graceid == graceid, + ) + .first() + ) + + if not alert: + raise validation_exception( + message=f"Invalid 'timesent_stamp' for event {graceid}", + errors=[ + f"Please visit http://treasuremap.space/alerts?graceids={graceid} for valid timesent stamps for this event" + ], + ) + + filter_conditions.append(GWGalaxyList.alertid == str(alert.id)) + + if listid: + filter_conditions.append(GWGalaxyList.id == listid) + if groupname: + filter_conditions.append(GWGalaxyList.groupname == groupname) + if score_gt is not None: + filter_conditions.append(GWGalaxyEntry.score >= score_gt) + if score_lt is not None: + filter_conditions.append(GWGalaxyEntry.score <= score_lt) + + galaxy_entries = ( + db.query(GWGalaxyEntry) + .join(GWGalaxyList, GWGalaxyList.id == GWGalaxyEntry.listid) + .filter(*filter_conditions) + .all() + ) + + # Convert GeoAlchemy2 Geography to a string for Pydantic + result_entries = [] + for entry in galaxy_entries: + entry_dict = { + "id": entry.id, + "listid": entry.listid, + "name": entry.name, + "score": entry.score, + "rank": entry.rank, + "info": entry.info, + } + + # Convert position to WKT string + if entry.position: + shape = to_shape(entry.position) + entry_dict["position"] = str(shape) + + result_entries.append(GWGalaxyEntrySchema(**entry_dict)) + + return result_entries diff --git a/server/routes/gw_galaxy/get_glade.py b/server/routes/gw_galaxy/get_glade.py new file mode 100644 index 00000000..f1e34e59 --- /dev/null +++ b/server/routes/gw_galaxy/get_glade.py @@ -0,0 +1,78 @@ +"""Get GLADE galaxies endpoint.""" + +from typing import Optional +from fastapi import APIRouter, Depends, Query +from geoalchemy2.shape import to_shape +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.auth.auth import get_current_user + +router = APIRouter(tags=["galaxies"]) + + +@router.get("/glade") +async def get_galaxies( + ra: Optional[float] = Query(None, description="Right ascension"), + dec: Optional[float] = Query(None, description="Declination"), + name: Optional[str] = Query(None, description="Galaxy name to search for"), + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Get galaxies from the GLADE catalog. 
+ """ + from server.utils.function import isFloat + from server.db.models.glade import Glade2P3 + + filter_conditions = [] + base_filter = [ + Glade2P3.pgc_number != -1, + Glade2P3.distance > 0, + Glade2P3.distance < 100, + ] + + # Create base query + query = db.query(Glade2P3).filter(*base_filter) + + # Handle orderby for positioning + orderby = [] + + # Handle ra and dec + if ra is not None and dec is not None and isFloat(ra) and isFloat(dec): + from sqlalchemy import func + + geom = f"SRID=4326;POINT({ra} {dec})" + orderby.append(func.ST_Distance(Glade2P3.position, geom)) + + # Handle name search + if name: + from sqlalchemy import or_ + + or_conditions = [ + Glade2P3._2mass_name.contains(name.strip()), + Glade2P3.gwgc_name.contains(name.strip()), + Glade2P3.hyperleda_name.contains(name.strip()), + Glade2P3.sdssdr12_name.contains(name.strip()), + ] + filter_conditions.append(or_(*or_conditions)) + + # Execute query + galaxies = query.filter(*filter_conditions).order_by(*orderby).limit(15).all() + + # Parse galaxies to dict format + result = [] + for galaxy in galaxies: + # Convert to dict + galaxy_dict = { + c.name: getattr(galaxy, c.name) for c in galaxy.__table__.columns + } + + # Convert position to WKT string if it exists + if galaxy.position: + shape = to_shape(galaxy.position) + galaxy_dict["position"] = str(shape) + + result.append(galaxy_dict) + + return result diff --git a/server/routes/gw_galaxy/post_event_galaxies.py b/server/routes/gw_galaxy/post_event_galaxies.py new file mode 100644 index 00000000..06efda58 --- /dev/null +++ b/server/routes/gw_galaxy/post_event_galaxies.py @@ -0,0 +1,196 @@ +"""Post event galaxies endpoint.""" + +import datetime +from dateutil.parser import parse as date_parse +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert +from server.db.models.gw_galaxy import GWGalaxyList, GWGalaxyEntry +from server.auth.auth import get_current_user +from server.schemas.gw_galaxy import PostEventGalaxiesRequest, PostEventGalaxiesResponse +from server.utils.error_handling import validation_exception + +router = APIRouter(tags=["galaxies"]) + + +@router.post("/event_galaxies", response_model=PostEventGalaxiesResponse) +async def post_event_galaxies( + request: PostEventGalaxiesRequest, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Post galaxies associated with a GW event. + """ + # Normalize the graceid + graceid = GWAlert.graceidfromalternate(request.graceid) + + # Parse timesent_stamp + try: + print(f"Parsing timesent_stamp: {request.timesent_stamp}") + time = date_parse(request.timesent_stamp) + except ValueError: + raise validation_exception( + message="Error parsing date", + errors=[ + "Timestamp should be in %Y-%m-%dT%H:%M:%S.%f format. e.g. 2019-05-01T12:00:00.00" + ], + ) + + # Find the alert + alert = ( + db.query(GWAlert) + .filter( + GWAlert.timesent < time + datetime.timedelta(seconds=15), + GWAlert.timesent > time - datetime.timedelta(seconds=15), + GWAlert.graceid == graceid, + ) + .first() + ) + + if not alert: + raise validation_exception( + message=f"Invalid 'timesent_stamp' for event {graceid}", + errors=[ + f"Please visit http://treasuremap.space/alerts?graceids={graceid} for valid timesent stamps for this event" + ], + ) + + # Handle groupname - default to username if not provided + groupname = request.groupname or user.username + + # Handle DOI creators + post_doi = request.request_doi + doi_string = ". 
" + creators = None + + if post_doi: + if request.creators: + creators = [] + for c in request.creators: + # Since request.creators is List[DOICreator] from Pydantic, + # c is a DOICreator object, not a dict + if not c.name or not c.affiliation: + raise validation_exception( + message="Invalid DOI creator information", + errors=[ + "name and affiliation are required for each creator in the list" + ], + ) + creator_dict = {"name": c.name, "affiliation": c.affiliation} + if c.orcid: + creator_dict["orcid"] = c.orcid + if c.gnd: + creator_dict["gnd"] = c.gnd + creators.append(creator_dict) + elif request.doi_group_id: + from server.db.models.doi_author import DOIAuthor + + valid, creators_list = DOIAuthor.construct_creators( + request.doi_group_id, user.id, db + ) + if not valid: + raise validation_exception( + message="Invalid DOI group ID", + errors=["Make sure you are the User associated with the DOI group"], + ) + creators = creators_list + else: + creators = [ + {"name": f"{user.firstname} {user.lastname}", "affiliation": ""} + ] + + # Create galaxy list + gw_galaxy_list = GWGalaxyList( + submitterid=user.id, + graceid=graceid, + alertid=str(alert.id), + groupname=groupname, + reference=request.reference, + ) + db.add(gw_galaxy_list) + db.flush() + + # Process galaxies + valid_galaxies = [] + errors = [] + warnings = [] + + for galaxy_entry in request.galaxies: + try: + # Create the galaxy entry from the Pydantic model + gw_galaxy_entry = GWGalaxyEntry( + listid=gw_galaxy_list.id, + name=galaxy_entry.name, + score=galaxy_entry.score, + rank=galaxy_entry.rank, + info=galaxy_entry.info, + ) + + # Handle position data - use position string if provided, otherwise build from ra/dec + if galaxy_entry.position: + if ( + all(x in galaxy_entry.position for x in ["POINT", "(", ")", " "]) + and "," not in galaxy_entry.position + ): + gw_galaxy_entry.position = galaxy_entry.position + else: + errors.append( + [ + f"Object: {galaxy_entry.dict()}", + [ + 'Invalid position argument. Must be geometry type "POINT(RA DEC)"' + ], + ] + ) + continue + elif galaxy_entry.ra is not None and galaxy_entry.dec is not None: + gw_galaxy_entry.position = ( + f"POINT({galaxy_entry.ra} {galaxy_entry.dec})" + ) + else: + errors.append( + [ + f"Object: {galaxy_entry.dict()}", + [ + "Position data is required. Provide either position or ra/dec coordinates." + ], + ] + ) + continue + + # All validation passed, add to database + db.add(gw_galaxy_entry) + valid_galaxies.append(gw_galaxy_entry) + + except Exception as e: + errors.append([f"Object: {galaxy_entry.dict()}", [str(e)]]) + + db.flush() + + # Handle DOI if requested + if post_doi and valid_galaxies: + from server.utils.function import create_galaxy_score_doi + + doi_id, url = create_galaxy_score_doi( + valid_galaxies, creators, request.reference, graceid, alert.alert_type + ) + + if url is None and doi_id is not None: + errors.append( + "There was an error with the DOI request. Please ensure that author group's ORIC/GND values are accurate" + ) + else: + gw_galaxy_list.doi_id = doi_id + gw_galaxy_list.doi_url = url + doi_string = f". DOI url: {url}." 
+ + db.commit() + + return PostEventGalaxiesResponse( + message=f"Successful adding of {len(valid_galaxies)} galaxies for event {graceid}{doi_string} List ID: {gw_galaxy_list.id}", + errors=errors, + warnings=warnings, + ) diff --git a/server/routes/gw_galaxy/remove_event_galaxies.py b/server/routes/gw_galaxy/remove_event_galaxies.py new file mode 100644 index 00000000..d8013f34 --- /dev/null +++ b/server/routes/gw_galaxy/remove_event_galaxies.py @@ -0,0 +1,46 @@ +"""Remove event galaxies endpoint.""" + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_galaxy import GWGalaxyList, GWGalaxyEntry +from server.auth.auth import get_current_user +from server.utils.error_handling import not_found_exception, permission_exception + +router = APIRouter(tags=["galaxies"]) + + +@router.delete("/remove_event_galaxies") +async def remove_event_galaxies( + listid: int = Query(..., description="ID of the galaxy list to remove"), + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Remove galaxies associated with a GW event. + """ + # Find galaxy list + galaxy_list = db.query(GWGalaxyList).filter(GWGalaxyList.id == listid).first() + + if not galaxy_list: + raise not_found_exception("No galaxies found with that list ID") + + # Check permissions + if user.id != galaxy_list.submitterid: + raise permission_exception( + "You can only delete information related to your API token" + ) + + # Find and delete galaxy entries + galaxy_entries = ( + db.query(GWGalaxyEntry).filter(GWGalaxyEntry.listid == listid).all() + ) + + for entry in galaxy_entries: + db.delete(entry) + + db.delete(galaxy_list) + db.commit() + + return {"message": "Successfully deleted your galaxy list"} diff --git a/server/routes/gw_galaxy/router.py b/server/routes/gw_galaxy/router.py new file mode 100644 index 00000000..36551e72 --- /dev/null +++ b/server/routes/gw_galaxy/router.py @@ -0,0 +1,18 @@ +"""Consolidated router for all GW galaxy endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .get_event_galaxies import router as get_event_galaxies_router +from .post_event_galaxies import router as post_event_galaxies_router +from .remove_event_galaxies import router as remove_event_galaxies_router +from .get_glade import router as get_glade_router + +# Create the main router that includes all GW galaxy routes +router = APIRouter(tags=["galaxies"]) + +# Include all the individual routers +router.include_router(get_event_galaxies_router) +router.include_router(post_event_galaxies_router) +router.include_router(remove_event_galaxies_router) +router.include_router(get_glade_router) diff --git a/server/routes/icecube/__init__.py b/server/routes/icecube/__init__.py new file mode 100644 index 00000000..91d439b7 --- /dev/null +++ b/server/routes/icecube/__init__.py @@ -0,0 +1 @@ +"""IceCube route module.""" diff --git a/server/routes/icecube/post_icecube_notice.py b/server/routes/icecube/post_icecube_notice.py new file mode 100644 index 00000000..0d35892c --- /dev/null +++ b/server/routes/icecube/post_icecube_notice.py @@ -0,0 +1,83 @@ +"""Post IceCube notice endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from typing import Dict, Any +from datetime import datetime + +from server.db.database import get_db +from server.db.models.icecube import IceCubeNotice, IceCubeNoticeCoincEvent +from server.schemas.icecube import ( + IceCubeNoticeSchema, + 
IceCubeNoticeCoincEventSchema, + IceCubeNoticeRequestSchema, +) +from server.auth.auth import verify_admin + +router = APIRouter(tags=["icecube"]) + + +@router.post("/post_icecube_notice", response_model=Dict[str, Any]) +async def post_icecube_notice( + request: IceCubeNoticeRequestSchema, + db: Session = Depends(get_db), + user=Depends(verify_admin), # Only admin can post IceCube notices +): + """ + Post an IceCube neutrino notice (admin only). + + Parameters: + - request: IceCube notice request containing notice_data and events_data + + Returns the created notice and events + """ + # Extract data from request + notice_data = request.notice_data + events_data = request.events_data + + # Check if notice already exists + existing_notice = ( + db.query(IceCubeNotice) + .filter(IceCubeNotice.ref_id == notice_data.ref_id) + .first() + ) + + if existing_notice: + return { + "icecube_notice": {"message": "event already exists"}, + "icecube_notice_events": [], + } + + # Set required fields that might not be in the input data + notice_dict = notice_data.model_dump() + notice_dict["datecreated"] = datetime.now() + + # Create the notice object + notice = IceCubeNotice(**notice_dict) + db.add(notice) + db.flush() # Flush to get the generated ID + + # Process events + events = [] + for event_data in events_data: + # Get the event data + event_dict = event_data.model_dump() + + # Set the notice ID and creation date + event_dict["icecube_notice_id"] = notice.id + event_dict["datecreated"] = datetime.now() + + # Create the event object + event = IceCubeNoticeCoincEvent(**event_dict) + db.add(event) + events.append(event) + + db.commit() + + # Convert SQLAlchemy models to Pydantic schemas for serialization + notice_schema = IceCubeNoticeSchema.model_validate(notice) + events_schemas = [ + IceCubeNoticeCoincEventSchema.model_validate(event) for event in events + ] + + return {"icecube_notice": notice_schema, "icecube_notice_events": events_schemas} diff --git a/server/routes/icecube/router.py b/server/routes/icecube/router.py new file mode 100644 index 00000000..03c05850 --- /dev/null +++ b/server/routes/icecube/router.py @@ -0,0 +1,12 @@ +"""Consolidated router for all IceCube endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .post_icecube_notice import router as post_icecube_notice_router + +# Create the main router that includes all IceCube routes +router = APIRouter(tags=["icecube"]) + +# Include all the individual routers +router.include_router(post_icecube_notice_router) diff --git a/server/routes/instrument/__init__.py b/server/routes/instrument/__init__.py new file mode 100644 index 00000000..b605dff3 --- /dev/null +++ b/server/routes/instrument/__init__.py @@ -0,0 +1 @@ +"""Instrument-related routes.""" diff --git a/server/routes/instrument/create_footprint.py b/server/routes/instrument/create_footprint.py new file mode 100644 index 00000000..93423581 --- /dev/null +++ b/server/routes/instrument/create_footprint.py @@ -0,0 +1,59 @@ +"""Create footprint endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from geoalchemy2 import WKBElement +from shapely.wkb import loads as wkb_loads + +from server.db.database import get_db +from server.db.models.instrument import Instrument, FootprintCCD +from server.schemas.instrument import FootprintCCDCreate, FootprintCCDSchema +from server.auth.auth import get_current_user +from server.utils.error_handling import not_found_exception, permission_exception + +router = 
APIRouter(tags=["instruments"]) + + +@router.post("/footprints", response_model=FootprintCCDSchema) +async def create_footprint( + footprint: FootprintCCDCreate, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Create a new footprint for an instrument. + + Parameters: + - footprint: Footprint data + + Returns the created footprint + """ + # Check if the instrument exists + instrument = ( + db.query(Instrument).filter(Instrument.id == footprint.instrumentid).first() + ) + if not instrument: + raise not_found_exception( + f"Instrument with ID {footprint.instrumentid} not found" + ) + + # Check permissions (only the instrument submitter can add footprints) + if instrument.submitterid != user.id: + raise permission_exception( + "You don't have permission to add footprints to this instrument" + ) + + # Create a new footprint + new_footprint = FootprintCCD( + instrumentid=footprint.instrumentid, footprint=footprint.footprint # WKT format + ) + + db.add(new_footprint) + db.commit() + db.refresh(new_footprint) + + # Convert the footprint from WKB to WKT for the response + if isinstance(new_footprint.footprint, WKBElement): + new_footprint.footprint = str(wkb_loads(bytes(new_footprint.footprint.data))) + + return new_footprint diff --git a/server/routes/instrument/create_instrument.py b/server/routes/instrument/create_instrument.py new file mode 100644 index 00000000..7bdaf7d6 --- /dev/null +++ b/server/routes/instrument/create_instrument.py @@ -0,0 +1,42 @@ +"""Create instrument endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +import datetime + +from server.db.database import get_db +from server.db.models.instrument import Instrument +from server.schemas.instrument import InstrumentCreate, InstrumentSchema +from server.auth.auth import get_current_user + +router = APIRouter(tags=["instruments"]) + + +@router.post("/instruments", response_model=InstrumentSchema) +async def create_instrument( + instrument: InstrumentCreate, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Create a new instrument. 
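+
+    Example (illustrative sketch; the base URL, auth header, and the
+    instrument_type value shown are assumptions):
+
+        import requests
+        resp = requests.post(
+            "http://localhost:8000/instruments",
+            json={
+                "instrument_name": "Example Imager",
+                "nickname": "EXIM",
+                "instrument_type": "photometric",
+            },
+            headers={"Authorization": "Bearer <token>"},  # placeholder credential
+        )
+        print(resp.json())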
+ + Parameters: + - instrument: Instrument data + + Returns the created instrument + """ + # Create a new instrument + new_instrument = Instrument( + instrument_name=instrument.instrument_name, + nickname=instrument.nickname, + instrument_type=instrument.instrument_type, + submitterid=user.id, + datecreated=datetime.datetime.now(), + ) + + db.add(new_instrument) + db.commit() + db.refresh(new_instrument) + + return new_instrument diff --git a/server/routes/instrument/get_footprints.py b/server/routes/instrument/get_footprints.py new file mode 100644 index 00000000..7a558bce --- /dev/null +++ b/server/routes/instrument/get_footprints.py @@ -0,0 +1,64 @@ +"""Get footprints endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from sqlalchemy import or_ +from typing import List, Optional +from geoalchemy2 import WKBElement +from shapely.wkb import loads as wkb_loads + +from server.db.database import get_db +from server.db.models.instrument import Instrument, FootprintCCD +from server.schemas.instrument import FootprintCCDSchema +from server.auth.auth import get_current_user + +router = APIRouter(tags=["instruments"]) + + +@router.get("/footprints", response_model=List[FootprintCCDSchema]) +async def get_footprints( + id: Optional[int] = None, + name: Optional[str] = None, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Get instrument footprints with optional filters. + + Parameters: + - id: Filter by instrument ID + - name: Filter by instrument name (fuzzy match) + + Returns a list of footprint objects + """ + filter_conditions = [] + + if id: + filter_conditions.append(FootprintCCD.instrumentid == id) + + if name: + filter_conditions.append(FootprintCCD.instrumentid == Instrument.id) + + or_conditions = [] + or_conditions.append(Instrument.instrument_name.contains(name.strip())) + or_conditions.append(Instrument.nickname.contains(name.strip())) + + filter_conditions.append(or_(*or_conditions)) + + # When filtering by name, we need to join with Instrument table + footprints = ( + db.query(FootprintCCD) + .join(Instrument, FootprintCCD.instrumentid == Instrument.id) + .filter(*filter_conditions) + .all() + ) + else: + footprints = db.query(FootprintCCD).filter(*filter_conditions).all() + + # Convert WKB to WKT for the `footprint` field + for footprint in footprints: + if isinstance(footprint.footprint, WKBElement): + footprint.footprint = str(wkb_loads(bytes(footprint.footprint.data))) + + # FastAPI will automatically convert SQLAlchemy models to Pydantic models + return footprints diff --git a/server/routes/instrument/get_instruments.py b/server/routes/instrument/get_instruments.py new file mode 100644 index 00000000..2a139802 --- /dev/null +++ b/server/routes/instrument/get_instruments.py @@ -0,0 +1,82 @@ +"""Get instruments endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from sqlalchemy import or_ +from typing import List, Optional +import json + +from server.db.database import get_db +from server.db.models.instrument import Instrument +from server.db.models.pointing import Pointing +from server.schemas.instrument import InstrumentSchema +from server.auth.auth import get_current_user +from server.utils.error_handling import validation_exception +from server.core.enums.instrumenttype import InstrumentType + +router = APIRouter(tags=["instruments"]) + + +@router.get("/instruments", response_model=List[InstrumentSchema]) +async def get_instruments( + id: Optional[int] = None, + ids: 
Optional[str] = None, + name: Optional[str] = None, + names: Optional[str] = None, + type: Optional[InstrumentType] = None, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Get instruments with optional filters. + + Parameters: + - id: Filter by instrument ID + - ids: Filter by list of instrument IDs + - name: Filter by instrument name (fuzzy match) + - names: Filter by list of instrument names (fuzzy match) + - type: Filter by instrument type + + Returns a list of instrument objects + """ + filter_conditions = [] + + if id: + filter_conditions.append(Instrument.id == id) + + if ids: + try: + if isinstance(ids, str): + ids_list = json.loads(ids) + else: + ids_list = ids + filter_conditions.append(Instrument.id.in_(ids_list)) + except: + raise validation_exception("Invalid ids format. Must be a JSON array.") + + if name: + filter_conditions.append(Instrument.instrument_name.contains(name)) + + if names: + try: + if isinstance(names, str): + insts = json.loads(names) + else: + insts = names + + or_conditions = [] + for i in insts: + or_conditions.append(Instrument.instrument_name.contains(i.strip())) + + filter_conditions.append(or_(*or_conditions)) + filter_conditions.append(Instrument.id == Pointing.instrumentid) + except: + raise validation_exception("Invalid names format. Must be a JSON array.") + + if type: + filter_conditions.append(Instrument.instrument_type == type) + + instruments = db.query(Instrument).filter(*filter_conditions).all() + + # FastAPI will automatically convert SQLAlchemy models to Pydantic models + return instruments diff --git a/server/routes/instrument/router.py b/server/routes/instrument/router.py new file mode 100644 index 00000000..42f90536 --- /dev/null +++ b/server/routes/instrument/router.py @@ -0,0 +1,18 @@ +"""Consolidated router for all instrument endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .get_instruments import router as get_instruments_router +from .get_footprints import router as get_footprints_router +from .create_instrument import router as create_instrument_router +from .create_footprint import router as create_footprint_router + +# Create the main router that includes all instrument routes +router = APIRouter(tags=["instruments"]) + +# Include all the individual routers +router.include_router(get_instruments_router) +router.include_router(get_footprints_router) +router.include_router(create_instrument_router) +router.include_router(create_footprint_router) diff --git a/server/routes/pointing/__init__.py b/server/routes/pointing/__init__.py new file mode 100644 index 00000000..2ab0b353 --- /dev/null +++ b/server/routes/pointing/__init__.py @@ -0,0 +1 @@ +"""Pointing-related routes.""" diff --git a/server/routes/pointing/cancel_all.py b/server/routes/pointing/cancel_all.py new file mode 100644 index 00000000..fda10a34 --- /dev/null +++ b/server/routes/pointing/cancel_all.py @@ -0,0 +1,54 @@ +"""Cancel all pointings endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from datetime import datetime + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.db.models.pointing_event import PointingEvent +from server.db.models.gw_alert import GWAlert +from server.schemas.pointing import CancelAllRequest +from server.auth.auth import get_current_user +from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum +from server.utils import pointing as pointing_utils + +router 
= APIRouter(tags=["pointings"]) + + +@router.post("/cancel_all") +async def cancel_all( + request: CancelAllRequest, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Cancel all planned pointings for a specific GW event and instrument. + """ + # Validate instrument exists + pointing_utils.validate_instrument(request.instrumentid, db) + + # Validate graceid exists + normalized_graceid = GWAlert.graceidfromalternate(request.graceid) + + # Build the filter + filter_conditions = [ + Pointing.status == pointing_status_enum.planned, + Pointing.submitterid == user.id, + Pointing.instrumentid == request.instrumentid, + Pointing.id == PointingEvent.pointingid, + PointingEvent.graceid == normalized_graceid, + ] + + # Query the pointings + pointings = db.query(Pointing).filter(*filter_conditions) + pointing_count = pointings.count() + + # Update the status + for pointing in pointings: + pointing.status = pointing_status_enum.cancelled + pointing.dateupdated = datetime.now() + + db.commit() + + return {"message": f"Updated {pointing_count} Pointings successfully"} diff --git a/server/routes/pointing/create_pointings.py b/server/routes/pointing/create_pointings.py new file mode 100644 index 00000000..75b1e6e0 --- /dev/null +++ b/server/routes/pointing/create_pointings.py @@ -0,0 +1,134 @@ +"""Create pointings endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from typing import List + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.db.models.pointing_event import PointingEvent +from server.schemas.pointing import PointingCreateRequest, PointingResponse +from server.auth.auth import get_current_user +from server.utils import pointing as pointing_utils + +router = APIRouter(tags=["pointings"]) + + +@router.post("/pointings", response_model=PointingResponse) +async def add_pointings( + request: PointingCreateRequest, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Add new pointings to the database. 
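+
+    Example (illustrative sketch; the base URL, auth header, and the exact field
+    names accepted by PointingCreateRequest are assumptions -- check
+    server/schemas/pointing.py before relying on them):
+
+        import requests
+        payload = {
+            "graceid": "MS181101ab",
+            "pointings": [
+                {
+                    "ra": 197.45, "dec": -23.38,
+                    "instrumentid": 1, "band": "r",
+                    "depth": 21.5, "depth_unit": "ab_mag",
+                    "time": "2019-05-01T12:00:00.00", "status": "planned",
+                }
+            ],
+        }
+        resp = requests.post(
+            "http://localhost:8000/pointings",
+            json=payload,
+            headers={"Authorization": "Bearer <token>"},  # placeholder credential
+        )
+        print(resp.json()["pointing_ids"])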
+ """ + # Initialize variables + points = [] + errors = [] + warnings = [] + + # Validate graceid exists + pointing_utils.validate_graceid(request.graceid, db) + + # Prepare DOI creators if DOI is requested + creators = None + if request.request_doi: + creators = pointing_utils.prepare_doi_creators( + request.creators, request.doi_group_id, user, db + ) + + # Get instruments for validation + instruments_dict = pointing_utils.get_instruments_dict(db) + + # Get existing pointings for duplicate check + existing_pointings = ( + db.query(Pointing) + .filter( + Pointing.id == PointingEvent.pointingid, + PointingEvent.graceid == request.graceid, + ) + .all() + ) + + # Process pointings (either single or multiple) + pointings_to_process = [] + if request.pointing: + pointings_to_process = [request.pointing] + elif request.pointings: + pointings_to_process = request.pointings + + for pointing_data in pointings_to_process: + try: + # Check if this is an update to a planned pointing + if hasattr(pointing_data, "id") and pointing_data.id: + # Handle planned pointing update + pointing_obj = pointing_utils.handle_planned_pointing_update( + pointing_data, user.id, db + ) + else: + # Validate and resolve instrument reference + instrument_id = pointing_utils.validate_instrument_reference( + pointing_data, instruments_dict + ) + + # Create new pointing object + pointing_obj = pointing_utils.create_pointing_from_schema( + pointing_data, user.id, instrument_id + ) + + # Check for duplicates + if pointing_utils.check_duplicate_pointing( + pointing_obj, existing_pointings + ): + errors.append( + [ + f"Object: {pointing_data.dict()}", + ["Pointing already submitted"], + ] + ) + continue + + points.append(pointing_obj) + db.add(pointing_obj) + + except Exception as e: + errors.append([f"Object: {pointing_data.model_dump()}", [str(e)]]) + + # Flush to get pointing IDs + db.flush() + + # Create pointing events (this should always happen when we have valid points and graceid) + if points: # Only create pointing events if we have valid points + for p in points: + pointing_event = PointingEvent(pointingid=p.id, graceid=request.graceid) + db.add(pointing_event) + + db.flush() + db.commit() + + # Handle DOI creation if requested + doi_url = None + if request.request_doi and points: + if request.doi_url: + doi_id, doi_url = 0, request.doi_url + else: + doi_id, doi_url = pointing_utils.create_doi_for_pointings( + points, request.graceid, creators, db + ) + + if doi_id is not None: + for p in points: + p.doi_url = doi_url + p.doi_id = doi_id + + db.flush() + db.commit() + + # Return response + return PointingResponse( + pointing_ids=[p.id for p in points], + ERRORS=errors, + WARNINGS=warnings, + DOI=doi_url, + ) diff --git a/server/routes/pointing/get_pointings.py b/server/routes/pointing/get_pointings.py new file mode 100644 index 00000000..fbf5cdc8 --- /dev/null +++ b/server/routes/pointing/get_pointings.py @@ -0,0 +1,617 @@ +"""Get pointings endpoint with comprehensive filtering.""" + +from fastapi import APIRouter, Depends, Query, HTTPException +from sqlalchemy.orm import Session +from sqlalchemy import func, or_ +from datetime import datetime +from typing import List, Optional +import json + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.db.models.instrument import Instrument +from server.db.models.pointing_event import PointingEvent +from server.db.models.gw_alert import GWAlert +from server.db.models.users import Users +from server.schemas.pointing import 
PointingSchema +from server.auth.auth import get_current_user +from server.utils.error_handling import validation_exception +from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum +from server.core.enums.depthunit import DepthUnit as depth_unit_enum +from server.core.enums.bandpass import Bandpass +from server.core.enums.wavelengthunits import WavelengthUnits +from server.core.enums.frequencyunits import FrequencyUnits as frequency_units +from server.core.enums.energyunits import EnergyUnits as energy_units +from server.utils.function import isInt, isFloat + +router = APIRouter(tags=["pointings"]) + + +@router.get("/pointings", response_model=List[PointingSchema]) +def get_pointings( + # Basic filters + graceid: Optional[str] = Query(None, description="Grace ID of the GW event"), + graceids: Optional[str] = Query( + None, description="Comma-separated list or JSON array of Grace IDs" + ), + id: Optional[int] = Query(None, description="Filter by pointing ID"), + ids: Optional[str] = Query( + None, description="Comma-separated list or JSON array of pointing IDs" + ), + # Status filters + status: Optional[str] = Query( + None, description="Filter by status (planned, completed, cancelled)" + ), + statuses: Optional[str] = Query( + None, description="Comma-separated list or JSON array of statuses" + ), + # Time filters + completed_after: Optional[datetime] = Query( + None, description="Filter for pointings completed after this time (ISO format)" + ), + completed_before: Optional[datetime] = Query( + None, description="Filter for pointings completed before this time (ISO format)" + ), + planned_after: Optional[datetime] = Query( + None, description="Filter for pointings planned after this time (ISO format)" + ), + planned_before: Optional[datetime] = Query( + None, description="Filter for pointings planned before this time (ISO format)" + ), + # User filters + user: Optional[str] = Query( + None, description="Filter by username, first name, or last name" + ), + users: Optional[str] = Query( + None, description="Comma-separated list or JSON array of usernames" + ), + # Instrument filters + instrument: Optional[str] = Query( + None, description="Filter by instrument ID or name" + ), + instruments: Optional[str] = Query( + None, + description="Comma-separated list or JSON array of instrument IDs or names", + ), + # Band filters + band: Optional[str] = Query(None, description="Filter by band"), + bands: Optional[str] = Query( + None, description="Comma-separated list or JSON array of bands" + ), + # Spectral filters + wavelength_regime: Optional[str] = Query( + None, description="Filter by wavelength regime [min, max]" + ), + wavelength_unit: Optional[str] = Query( + None, description="Wavelength unit (angstrom, nanometer, micron)" + ), + frequency_regime: Optional[str] = Query( + None, description="Filter by frequency regime [min, max]" + ), + frequency_unit: Optional[str] = Query( + None, description="Frequency unit (Hz, kHz, MHz, GHz, THz)" + ), + energy_regime: Optional[str] = Query( + None, description="Filter by energy regime [min, max]" + ), + energy_unit: Optional[str] = Query( + None, description="Energy unit (eV, keV, MeV, GeV, TeV)" + ), + # Depth filters + depth_gt: Optional[float] = Query( + None, description="Filter by depth greater than this value" + ), + depth_lt: Optional[float] = Query( + None, description="Filter by depth less than this value" + ), + depth_unit: Optional[str] = Query( + None, description="Depth unit (ab_mag, vega_mag, flux_erg, 
flux_jy)" + ), + # DB access + db: Session = Depends(get_db), + user_auth=Depends(get_current_user), +): + """ + Retrieve pointings from the database with optional filters. + """ + try: + # Build the filter conditions + filter_conditions = [] + + # Handle graceid + if graceid: + # Normalize the graceid + graceid = GWAlert.graceidfromalternate(graceid) + filter_conditions.append(PointingEvent.graceid == graceid) + filter_conditions.append(PointingEvent.pointingid == Pointing.id) + + # Handle graceids + if graceids: + gids = [] + try: + if isinstance(graceids, str): + if "[" in graceids and "]" in graceids: + # Parse as JSON array + gids = json.loads(graceids) + else: + # Parse as comma-separated list + gids = [g.strip() for g in graceids.split(",")] + else: + gids = graceids # Already a list + + normalized_gids = [GWAlert.graceidfromalternate(gid) for gid in gids] + filter_conditions.append(PointingEvent.graceid.in_(normalized_gids)) + filter_conditions.append(PointingEvent.pointingid == Pointing.id) + except Exception as e: + raise validation_exception( + message="Error parsing 'graceids'", + errors=[ + f"Required format is a list: '[graceid1, graceid2...]'", + str(e), + ], + ) + + # Handle ID filters + if id: + if isInt(id): + filter_conditions.append(Pointing.id == int(id)) + else: + raise validation_exception( + message="Invalid ID format", errors=["ID must be an integer"] + ) + + if ids: + try: + id_list = [] + if isinstance(ids, str): + if "[" in ids and "]" in ids: + # Parse as JSON array + id_list = json.loads(ids) + else: + # Parse as comma-separated list + id_list = [ + int(i.strip()) for i in ids.split(",") if isInt(i.strip()) + ] + else: + id_list = ids # Already a list + + filter_conditions.append(Pointing.id.in_(id_list)) + except Exception as e: + raise validation_exception( + message="Error parsing 'ids'", + errors=[f"Required format is a list: '[id1, id2...]'", str(e)], + ) + + # Handle band filters + if band: + for b in Bandpass: + if b.name == band: + filter_conditions.append(Pointing.band == b) + break + else: + raise validation_exception( + message="Invalid band", errors=[f"The band '{band}' is not valid"] + ) + + if bands: + try: + band_list = [] + if isinstance(bands, str): + if "[" in bands and "]" in bands: + # Parse as JSON array + band_list = json.loads(bands) + else: + # Parse as comma-separated list + band_list = [b.strip() for b in bands.split(",")] + else: + band_list = bands # Already a list + + valid_bands = [] + for b in Bandpass: + if b.name in band_list: + valid_bands.append(b) + + if valid_bands: + filter_conditions.append(Pointing.band.in_(valid_bands)) + else: + raise validation_exception( + message="No valid bands", + errors=["No valid bands were specified"], + ) + except Exception as e: + raise validation_exception( + message="Error parsing bands", + errors=[ + f"Invalid format for 'bands' parameter. 
Required format is a list: '[band1, band2...]'", + str(e), + ], + ) + + # Handle status filters + if status: + if status == "planned": + filter_conditions.append( + Pointing.status == pointing_status_enum.planned + ) + elif status == "completed": + filter_conditions.append( + Pointing.status == pointing_status_enum.completed + ) + elif status == "cancelled": + filter_conditions.append( + Pointing.status == pointing_status_enum.cancelled + ) + else: + raise validation_exception( + message=f"Invalid status: {status}", + errors=["Only 'completed', 'planned', and 'cancelled' are valid."], + ) + + if statuses: + try: + status_list = [] + if isinstance(statuses, str): + if "[" in statuses and "]" in statuses: + # Parse as JSON array + status_list = json.loads(statuses) + else: + # Parse as comma-separated list + status_list = [s.strip() for s in statuses.split(",")] + else: + status_list = statuses # Already a list + + valid_statuses = [] + if "planned" in status_list: + valid_statuses.append(pointing_status_enum.planned) + if "completed" in status_list: + valid_statuses.append(pointing_status_enum.completed) + if "cancelled" in status_list: + valid_statuses.append(pointing_status_enum.cancelled) + + if valid_statuses: + filter_conditions.append(Pointing.status.in_(valid_statuses)) + else: + raise validation_exception( + message="No valid statuses", + errors=["No valid status values were specified"], + ) + except Exception as e: + raise validation_exception( + message="Error parsing statuses", + errors=[ + f"Invalid format for 'statuses' parameter. Required format is a list: '[status1, status2...]'", + str(e), + ], + ) + + # Handle time filters + if completed_after: + try: + filter_conditions.append( + Pointing.status == pointing_status_enum.completed + ) + filter_conditions.append(Pointing.time >= completed_after) + except ValueError: + raise validation_exception( + message="Error parsing date", + errors=["Should be ISO format, e.g. 2019-05-01T12:00:00.00"], + ) + + if completed_before: + try: + filter_conditions.append( + Pointing.status == pointing_status_enum.completed + ) + filter_conditions.append(Pointing.time <= completed_before) + except ValueError: + raise validation_exception( + message="Error parsing date", + errors=["Should be ISO format, e.g. 2019-05-01T12:00:00.00"], + ) + + if planned_after: + try: + filter_conditions.append( + Pointing.status == pointing_status_enum.planned + ) + filter_conditions.append(Pointing.time >= planned_after) + except ValueError: + raise validation_exception( + message="Error parsing date", + errors=["Should be ISO format, e.g. 2019-05-01T12:00:00.00"], + ) + + if planned_before: + try: + filter_conditions.append( + Pointing.status == pointing_status_enum.planned + ) + filter_conditions.append(Pointing.time <= planned_before) + except ValueError: + raise validation_exception( + message="Error parsing date", + errors=["Should be ISO format, e.g. 
2019-05-01T12:00:00.00"], + ) + + # Handle user filters + if user: + if isInt(user): + filter_conditions.append(Pointing.submitterid == int(user)) + else: + filter_conditions.append( + or_( + Users.username.contains(user), + Users.firstname.contains(user), + Users.lastname.contains(user), + ) + ) + filter_conditions.append(Users.id == Pointing.submitterid) + + if users: + try: + user_list = [] + if isinstance(users, str): + if "[" in users and "]" in users: + # Parse as JSON array + user_list = json.loads(users) + else: + # Parse as comma-separated list + user_list = [u.strip() for u in users.split(",")] + else: + user_list = users # Already a list + + or_conditions = [] + for u in user_list: + or_conditions.append(Users.username.contains(str(u).strip())) + or_conditions.append(Users.firstname.contains(str(u).strip())) + or_conditions.append(Users.lastname.contains(str(u).strip())) + if isInt(u): + or_conditions.append(Pointing.submitterid == int(u)) + + filter_conditions.append(or_(*or_conditions)) + filter_conditions.append(Users.id == Pointing.submitterid) + except Exception as e: + raise validation_exception( + message="Error parsing 'users'", + errors=[f"Required format is a list: '[user1, user2...]'", str(e)], + ) + + # Handle instrument filters + if instrument: + if isInt(instrument): + filter_conditions.append(Pointing.instrumentid == int(instrument)) + else: + filter_conditions.append( + Instrument.instrument_name.contains(instrument) + ) + filter_conditions.append(Pointing.instrumentid == Instrument.id) + + if instruments: + try: + inst_list = [] + if isinstance(instruments, str): + if "[" in instruments and "]" in instruments: + # Parse as JSON array + inst_list = json.loads(instruments) + else: + # Parse as comma-separated list + inst_list = [i.strip() for i in instruments.split(",")] + else: + inst_list = instruments # Already a list + + or_conditions = [] + for i in inst_list: + or_conditions.append( + Instrument.instrument_name.contains(str(i).strip()) + ) + or_conditions.append(Instrument.nickname.contains(str(i).strip())) + if isInt(i): + or_conditions.append(Pointing.instrumentid == int(i)) + + filter_conditions.append(or_(*or_conditions)) + filter_conditions.append(Instrument.id == Pointing.instrumentid) + except Exception as e: + raise validation_exception( + message="Error parsing 'instruments'", + errors=[f"Required format is a list: '[inst1, inst2...]'", str(e)], + ) + + # Handle spectral filters + if wavelength_regime and wavelength_unit: + try: + if isinstance(wavelength_regime, str): + if "[" in wavelength_regime and "]" in wavelength_regime: + # Parse range from string + wavelength_range = json.loads( + wavelength_regime.replace("(", "[").replace(")", "]") + ) + specmin, specmax = float(wavelength_range[0]), float( + wavelength_range[1] + ) + else: + raise ValueError("Invalid wavelength_regime format") + elif isinstance(wavelength_regime, list): + specmin, specmax = float(wavelength_regime[0]), float( + wavelength_regime[1] + ) + else: + raise ValueError("Invalid wavelength_regime type") + + # Get unit and scale + unit_value = wavelength_unit + try: + unit = [ + w + for w in WavelengthUnits + if int(w) == unit_value or str(w.name) == unit_value + ][0] + scale = WavelengthUnits.get_scale(unit) + specmin = specmin * scale + specmax = specmax * scale + + # Import the spectral handler + from server.utils.spectral import SpectralRangeHandler + + filter_conditions.append( + Pointing.inSpectralRange( + specmin, + specmax, + 
SpectralRangeHandler.spectralrangetype.wavelength, + ) + ) + except (IndexError, ValueError): + raise validation_exception( + message="Invalid wavelength_unit", + errors=[ + "Valid units are 'angstrom', 'nanometer', and 'micron'" + ], + ) + except Exception as e: + raise validation_exception( + message="Error parsing 'wavelength_regime'", + errors=[f"Required format is a list: '[low, high]'", str(e)], + ) + + if frequency_regime and frequency_unit: + try: + if isinstance(frequency_regime, str): + if "[" in frequency_regime and "]" in frequency_regime: + # Parse range from string + frequency_range = json.loads( + frequency_regime.replace("(", "[").replace(")", "]") + ) + specmin, specmax = float(frequency_range[0]), float( + frequency_range[1] + ) + else: + raise ValueError("Invalid frequency_regime format") + elif isinstance(frequency_regime, list): + specmin, specmax = float(frequency_regime[0]), float( + frequency_regime[1] + ) + else: + raise ValueError("Invalid frequency_regime type") + + # Get unit and scale + unit_value = frequency_unit + try: + unit = [ + f + for f in frequency_units + if int(f) == unit_value or str(f.name) == unit_value + ][0] + scale = frequency_units.get_scale(unit) + specmin = specmin * scale + specmax = specmax * scale + + # Import the spectral handler + from server.utils.spectral import SpectralRangeHandler + + filter_conditions.append( + Pointing.inSpectralRange( + specmin, + specmax, + SpectralRangeHandler.spectralrangetype.frequency, + ) + ) + except (IndexError, ValueError): + raise validation_exception( + message="Invalid frequency_unit", + errors=["Valid units are 'Hz', 'kHz', 'MHz', 'GHz', and 'THz'"], + ) + except Exception as e: + raise validation_exception( + message="Error parsing 'frequency_regime'", + errors=[f"Required format is a list: '[low, high]'", str(e)], + ) + + if energy_regime and energy_unit: + try: + if isinstance(energy_regime, str): + if "[" in energy_regime and "]" in energy_regime: + # Parse range from string + energy_range = json.loads( + energy_regime.replace("(", "[").replace(")", "]") + ) + specmin, specmax = float(energy_range[0]), float( + energy_range[1] + ) + else: + raise ValueError("Invalid energy_regime format") + elif isinstance(energy_regime, list): + specmin, specmax = float(energy_regime[0]), float(energy_regime[1]) + else: + raise ValueError("Invalid energy_regime type") + + # Get unit and scale + unit_value = energy_unit + try: + unit = [ + e + for e in energy_units + if int(e) == unit_value or str(e.name) == unit_value + ][0] + scale = energy_units.get_scale(unit) + specmin = specmin * scale + specmax = specmax * scale + + # Import the spectral handler + from server.utils.spectral import SpectralRangeHandler + + filter_conditions.append( + Pointing.inSpectralRange( + specmin, + specmax, + SpectralRangeHandler.spectralrangetype.energy, + ) + ) + except (IndexError, ValueError): + raise validation_exception( + message="Invalid energy_unit", + errors=["Valid units are 'eV', 'keV', 'MeV', 'GeV', and 'TeV'"], + ) + except Exception as e: + raise validation_exception( + message="Error parsing 'energy_regime'", + errors=[f"Required format is a list: '[low, high]'", str(e)], + ) + + # Handle depth filters + if depth_gt is not None or depth_lt is not None: + # Determine depth unit + depth_unit_value = ( + depth_unit or "ab_mag" + ) # Default to ab_mag if not specified + try: + depth_unit_enum_val = [ + d for d in depth_unit_enum if str(d.name) == depth_unit_value + ][0] + except (IndexError, ValueError): + 
depth_unit_enum_val = depth_unit_enum.ab_mag # Default + + # Handle depth_gt (query for brighter things) + if depth_gt is not None and isFloat(depth_gt): + if "mag" in depth_unit_enum_val.name: + # For magnitudes, lower values are brighter + filter_conditions.append(Pointing.depth <= float(depth_gt)) + elif "flux" in depth_unit_enum_val.name: + # For flux, higher values are brighter + filter_conditions.append(Pointing.depth >= float(depth_gt)) + + # Handle depth_lt (query for dimmer things) + if depth_lt is not None and isFloat(depth_lt): + if "mag" in depth_unit_enum_val.name: + # For magnitudes, higher values are dimmer + filter_conditions.append(Pointing.depth >= float(depth_lt)) + elif "flux" in depth_unit_enum_val.name: + # For flux, lower values are dimmer + filter_conditions.append(Pointing.depth <= float(depth_lt)) + + # Query the database + pointings = db.query(Pointing).filter(*filter_conditions).all() + + # Let Pydantic handle the conversion of SQLAlchemy models to JSON + # The field_serializer methods in PointingSchema will take care of enum translations + return pointings + except Exception as e: + raise validation_exception(message="Invalid request", errors=[str(e)]) diff --git a/server/routes/pointing/request_doi.py b/server/routes/pointing/request_doi.py new file mode 100644 index 00000000..348be39e --- /dev/null +++ b/server/routes/pointing/request_doi.py @@ -0,0 +1,109 @@ +"""Request DOI for pointings endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.db.models.pointing_event import PointingEvent +from server.db.models.gw_alert import GWAlert +from server.schemas.pointing import DOIRequest +from server.schemas.doi import DOIRequestResponse +from server.auth.auth import get_current_user +from server.utils.error_handling import validation_exception +from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum +from server.utils import pointing as pointing_utils + +router = APIRouter(tags=["pointings"]) + + +@router.post("/request_doi", response_model=DOIRequestResponse) +async def request_doi( + request: DOIRequest, db: Session = Depends(get_db), user=Depends(get_current_user) +): + """ + Request a DOI for completed pointings. 
+ """ + # Build the filter for pointings + filter_conditions = [Pointing.submitterid == user.id] + + # Handle id or ids (these don't require PointingEvent join) + if request.id: + filter_conditions.append(Pointing.id == request.id) + elif request.ids: + filter_conditions.append(Pointing.id.in_(request.ids)) + + # Only join with PointingEvent if graceid is specified + if request.graceid: + normalized_graceid = GWAlert.graceidfromalternate(request.graceid) + # Query the pointings with explicit join + points = ( + db.query(Pointing) + .join(PointingEvent, Pointing.id == PointingEvent.pointingid) + .filter(*filter_conditions, PointingEvent.graceid == normalized_graceid) + .all() + ) + else: + # Query without join when only using ID filters + points = db.query(Pointing).filter(*filter_conditions).all() + + # Validate and prepare for DOI request + warnings = [] + doi_points = [] + + for p in points: + # Check if pointing is completed and doesn't already have a DOI + if p.status == pointing_status_enum.completed and p.doi_id is None: + doi_points.append(p) + else: + warning_msg = f"Invalid doi request for pointing: {p.id}" + if p.status != pointing_status_enum.completed: + warning_msg += f" (status: {p.status})" + if p.doi_id is not None: + warning_msg += " (already has DOI)" + warnings.append(warning_msg) + + if len(doi_points) == 0: + raise validation_exception( + message="No valid pointings found for DOI request", + errors=["All pointings must be completed and not already have a DOI"], + ) + + # Get the GW event IDs from the pointings + pointing_events = ( + db.query(PointingEvent) + .filter(PointingEvent.pointingid.in_([x.id for x in doi_points])) + .all() + ) + gids = list(set([pe.graceid for pe in pointing_events])) + + if len(gids) > 1: + raise validation_exception( + message="Multiple events detected", + errors=["Pointings must be only for a single GW event for a DOI request"], + ) + + gid = gids[0] + + # Prepare DOI creators + creators = pointing_utils.prepare_doi_creators( + request.creators, request.doi_group_id, user, db + ) + + # Create or use provided DOI + if request.doi_url: + doi_id, doi_url = 0, request.doi_url + else: + doi_id, doi_url = pointing_utils.create_doi_for_pointings( + doi_points, gid, creators, db + ) + + # Update pointing records with DOI information + if doi_id is not None: + for p in doi_points: + p.doi_url = doi_url + p.doi_id = doi_id + + db.commit() + + return DOIRequestResponse(DOI_URL=doi_url, WARNINGS=warnings) diff --git a/server/routes/pointing/router.py b/server/routes/pointing/router.py new file mode 100644 index 00000000..ce57d32f --- /dev/null +++ b/server/routes/pointing/router.py @@ -0,0 +1,22 @@ +"""Consolidated router for all pointing endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .create_pointings import router as create_pointings_router +from .get_pointings import router as get_pointings_router +from .update_pointings import router as update_pointings_router +from .cancel_all import router as cancel_all_router +from .request_doi import router as request_doi_router +from .test_refactoring import router as test_refactoring_router + +# Create the main router that includes all pointing routes +router = APIRouter(tags=["pointings"]) + +# Include all the individual routers +router.include_router(create_pointings_router) +router.include_router(get_pointings_router) +router.include_router(update_pointings_router) +router.include_router(cancel_all_router) +router.include_router(request_doi_router) 
+router.include_router(test_refactoring_router) diff --git a/server/routes/pointing/test_refactoring.py b/server/routes/pointing/test_refactoring.py new file mode 100644 index 00000000..5b151a76 --- /dev/null +++ b/server/routes/pointing/test_refactoring.py @@ -0,0 +1,11 @@ +"""Test endpoint for refactored pointing routes.""" + +from fastapi import APIRouter + +router = APIRouter(tags=["pointings"]) + + +@router.get("/test_refactoring") +async def test_refactoring(): + """Test endpoint to verify refactored code is active.""" + return {"message": "Refactored pointing routes are active"} diff --git a/server/routes/pointing/update_pointings.py b/server/routes/pointing/update_pointings.py new file mode 100644 index 00000000..61db5b64 --- /dev/null +++ b/server/routes/pointing/update_pointings.py @@ -0,0 +1,54 @@ +"""Update pointings endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from datetime import datetime + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.schemas.pointing import PointingUpdate +from server.auth.auth import get_current_user +from server.utils.error_handling import validation_exception +from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum + +router = APIRouter(tags=["pointings"]) + + +@router.post("/update_pointings") +async def update_pointings( + update_pointing: PointingUpdate, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + """ + Update the status of planned pointings. + + Parameters: + - status: The new status for the pointings (only "cancelled" is currently supported) + - ids: List of pointing IDs to update + + Returns: + - Message with the number of updated pointings + """ + try: + # Add a filter to ensure user can only update their own pointings + pointings = ( + db.query(Pointing) + .filter( + Pointing.id.in_(update_pointing.ids), + Pointing.submitterid == user.id, + Pointing.status + == pointing_status_enum.planned, # Only planned pointings can be cancelled + ) + .all() + ) + + for pointing in pointings: + pointing.status = update_pointing.status + pointing.dateupdated = datetime.now() + + db.commit() + return {"message": f"Updated {len(pointings)} pointings successfully."} + except Exception as e: + db.rollback() + raise validation_exception(message="Invalid request", errors=[str(e)]) diff --git a/server/routes/ui/__init__.py b/server/routes/ui/__init__.py new file mode 100644 index 00000000..4ecf6f96 --- /dev/null +++ b/server/routes/ui/__init__.py @@ -0,0 +1 @@ +"""UI route module.""" diff --git a/server/routes/ui/alert_instruments_footprints.py b/server/routes/ui/alert_instruments_footprints.py new file mode 100644 index 00000000..77b0c1d0 --- /dev/null +++ b/server/routes/ui/alert_instruments_footprints.py @@ -0,0 +1,186 @@ +"""Get alert instruments footprints endpoint.""" + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from sqlalchemy import func, or_ + +from server.db.database import get_db +from server.db.models.instrument import Instrument +from server.db.models.pointing import Pointing +from server.db.models.pointing_event import PointingEvent +from server.db.models.gw_alert import GWAlert + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_alertinstruments_footprints") +async def get_alert_instruments_footprints( + graceid: str = None, + pointing_status: str = None, + tos_mjd: float = None, + db: Session = Depends(get_db), +): + """Get footprints of 
instruments that observed a specific alert.""" + from server.utils.function import ( + sanatize_pointing, + project_footprint, + sanatize_footprint_ccds, + ) + from server.db.models.instrument import FootprintCCD + import json + import hashlib + from server.utils.gwtm_io import get_cached_file, set_cached_file + from server.config import settings + + # First find the alert by graceid + alert = db.query(GWAlert).filter(GWAlert.graceid == graceid).first() + if not alert: + raise HTTPException(status_code=404, detail="Alert not found") + + # Set default status if none provided + if pointing_status is None: + pointing_status = "completed" + + # Build pointing filter - need to join with PointingEvent to get alert association + from server.db.models.pointing_event import PointingEvent + + pointing_filter = [] + pointing_filter.append(PointingEvent.graceid == graceid) + pointing_filter.append(PointingEvent.pointingid == Pointing.id) + + # Status filtering + if pointing_status == "pandc": + pointing_filter.append( + or_(Pointing.status == "completed", Pointing.status == "planned") + ) + elif pointing_status not in ["all", ""]: + from server.core.enums.pointingstatus import ( + PointingStatus as pointing_status_enum, + ) + + if pointing_status == "completed": + pointing_filter.append(Pointing.status == pointing_status_enum.completed) + elif pointing_status == "planned": + pointing_filter.append(Pointing.status == pointing_status_enum.planned) + elif pointing_status == "cancelled": + pointing_filter.append(Pointing.status == pointing_status_enum.cancelled) + + # Get pointing info + pointing_info = ( + db.query( + Pointing.id, + Pointing.instrumentid, + Pointing.pos_angle, + Pointing.time, + func.ST_AsText(Pointing.position).label("position"), + Pointing.band, + Pointing.depth, + Pointing.depth_unit, + Pointing.status, + ) + .join(PointingEvent, PointingEvent.pointingid == Pointing.id) + .filter(*pointing_filter) + .all() + ) + + # Cache key based on pointing IDs + pointing_ids = [p.id for p in pointing_info] + hash_pointing_ids = hashlib.sha1(json.dumps(pointing_ids).encode()).hexdigest() + cache_key = f"cache/footprint_{graceid}_{pointing_status}_{hash_pointing_ids}" + + # Try to get from cache first + cached_overlays = get_cached_file(cache_key, settings) + + if cached_overlays: + return json.loads(cached_overlays) + + # Not in cache, generate fresh data + instrument_ids = [p.instrumentid for p in pointing_info] + + # Get instrument info + instrumentinfo = ( + db.query(Instrument.instrument_name, Instrument.nickname, Instrument.id) + .filter(Instrument.id.in_(instrument_ids)) + .all() + ) + + # Get footprint info + from server.db.models.instrument import FootprintCCD + + footprintinfo = ( + db.query( + func.ST_AsText(FootprintCCD.footprint).label("footprint"), + FootprintCCD.instrumentid, + ) + .filter(FootprintCCD.instrumentid.in_(instrument_ids)) + .all() + ) + + # Prepare colors + colorlist = [ + "#ffe119", + "#4363d8", + "#f58231", + "#42d4f4", + "#f032e6", + "#fabebe", + "#469990", + "#e6beff", + "#9A6324", + "#fffac8", + "#800000", + "#aaffc3", + "#000075", + "#a9a9a9", + ] + + # Generate overlays + inst_overlays = [] + + for i, inst in enumerate([x for x in instrumentinfo if x.id != 49]): + name = ( + inst.nickname + if inst.nickname and inst.nickname != "None" + else inst.instrument_name + ) + + try: + color = colorlist[i] + except IndexError: + color = "#" + format(inst.id % 0xFFFFFF, "06x") + + footprint_ccds = [ + x.footprint for x in footprintinfo if x.instrumentid == inst.id + ] + 
sanatized_ccds = sanatize_footprint_ccds(footprint_ccds) + inst_pointings = [x for x in pointing_info if x.instrumentid == inst.id] + pointing_geometries = [] + + for p in inst_pointings: + import astropy.time + + t = astropy.time.Time([p.time]) + ra, dec = sanatize_pointing(p.position) + + for ccd in sanatized_ccds: + pointing_footprint = project_footprint(ccd, ra, dec, p.pos_angle) + pointing_geometries.append( + { + "polygon": pointing_footprint, + "time": round(t.mjd[0] - tos_mjd, 3) if tos_mjd else 0, + } + ) + + inst_overlays.append( + { + "display": True, + "name": name, + "color": color, + "contours": pointing_geometries, + } + ) + + # Cache the result + set_cached_file(cache_key, inst_overlays, settings) + + return inst_overlays diff --git a/server/routes/ui/alert_type.py b/server/routes/ui/alert_type.py new file mode 100644 index 00000000..23372eab --- /dev/null +++ b/server/routes/ui/alert_type.py @@ -0,0 +1,158 @@ +"""Alert type endpoint.""" + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_alerttype") +async def ajax_get_eventcontour(urlid: str, db: Session = Depends(get_db)): + """Get event contour and alert information.""" + from server.utils.function import get_farrate_farunit, polygons2footprints + from server.utils.gwtm_io import download_gwtm_file + from server.config import settings + import pandas as pd + + # Parse the URL ID to get alert ID and alert type + url_parts = urlid.split("_") + alert_id = url_parts[0] + alert_type = url_parts[1] + if len(url_parts) > 2: + alert_type += url_parts[2] + + # Get the alert + alert = db.query(GWAlert).filter(GWAlert.id == int(alert_id)).first() + if not alert: + raise HTTPException(status_code=404, detail="Alert not found") + + # Determine storage path + s3path = "fit" if alert.role == "observation" else "test" + + # Format FAR (False Alarm Rate) for human readability + human_far = "" + if alert.far != 0: + far_rate, far_unit = get_farrate_farunit(alert.far) + human_far = f"once per {round(far_rate, 2)} {far_unit}" + + # Format time coincidence FAR + human_time_coinc_far = "" + if alert.time_coincidence_far != 0 and alert.time_coincidence_far is not None: + time_coinc_farrate, time_coinc_farunit = get_farrate_farunit( + alert.time_coincidence_far + ) + time_coinc_farrate = round(time_coinc_farrate, 2) + human_time_coinc_far = ( + f"once per {round(time_coinc_farrate, 2)} {time_coinc_farunit}" + ) + + # Format time-sky position coincidence FAR + human_time_skypos_coinc_far = "" + if ( + alert.time_sky_position_coincidence_far != 0 + and alert.time_sky_position_coincidence_far is not None + ): + time_skypos_coinc_farrate, time_skypos_coinc_farunit = get_farrate_farunit( + alert.time_sky_position_coincidence_far + ) + time_skypos_coinc_farrate = round(time_skypos_coinc_farrate, 2) + human_time_skypos_coinc_far = f"once per {round(time_skypos_coinc_farrate, 2)} {time_skypos_coinc_farunit}" + + # Format time difference + if alert.time_difference is not None: + alert.time_difference = round(alert.time_difference, 3) + + # Format distance and error + distance_with_error = "" + if alert.distance is not None: + alert.distance = round(alert.distance, 3) + if alert.distance_error is not None: + alert.distance_error = round(alert.distance_error, 3) + distance_with_error = f"{alert.distance} ± {alert.distance_error} Mpc" + + # Format areas + if 
alert.area_50 is not None: + alert.area_50 = f"{round(alert.area_50, 3)} deg2" + if alert.area_90 is not None: + alert.area_90 = f"{round(alert.area_90, 3)} deg2" + + # Round probability values + if alert.prob_bns is not None: + alert.prob_bns = round(alert.prob_bns, 5) + if alert.prob_nsbh is not None: + alert.prob_nsbh = round(alert.prob_nsbh, 5) + if alert.prob_gap is not None: + alert.prob_gap = round(alert.prob_gap, 5) + if alert.prob_bbh is not None: + alert.prob_bbh = round(alert.prob_bbh, 5) + if alert.prob_terrestrial is not None: + alert.prob_terrestrial = round(alert.prob_terrestrial, 5) + if alert.prob_hasns is not None: + alert.prob_hasns = round(alert.prob_hasns, 5) + if alert.prob_hasremenant is not None: + alert.prob_hasremenant = round(alert.prob_hasremenant, 5) + + # Prepare detection overlays + detection_overlays = [] + path_info = alert.graceid + "-" + alert_type + + # Try to download contours + contour_path = f"{s3path}/{path_info}-contours-smooth.json" + try: + contours_data = download_gwtm_file( + contour_path, source=settings.STORAGE_BUCKET_SOURCE, config=settings + ) + contours_df = pd.read_json(contours_data) + + contour_geometry = [] + for contour in contours_df["features"]: + contour_geometry.extend(contour["geometry"]["coordinates"]) + + detection_overlays.append( + { + "display": True, + "name": "GW Contour", + "color": "#e6194B", + "contours": polygons2footprints(contour_geometry, 0), + } + ) + except Exception as e: + print(f"Error downloading contours: {str(e)}") + + # Prepare payload + payload = { + "hidden_alertid": alert_id, + "detection_overlays": detection_overlays, + "alert_group": alert.group, + "alert_detectors": alert.detectors, + "alert_time_of_signal": alert.time_of_signal, + "alert_timesent": alert.timesent, + "alert_human_far": human_far, + "alert_distance_plus_error": distance_with_error, + "alert_centralfreq": alert.centralfreq, + "alert_duration": alert.duration, + "alert_prob_bns": alert.prob_bns, + "alert_prob_nsbh": alert.prob_nsbh, + "alert_prob_gap": alert.prob_gap, + "alert_prob_bbh": alert.prob_bbh, + "alert_prob_terrestrial": alert.prob_terrestrial, + "alert_prob_hasns": alert.prob_hasns, + "alert_prob_hasremenant": alert.prob_hasremenant, + "alert_area_50": alert.area_50, + "alert_area_90": alert.area_90, + "alert_avgra": alert.avgra, + "alert_avgdec": alert.avgdec, + "alert_gcn_notice_id": alert.gcn_notice_id, + "alert_ivorn": alert.ivorn, + "alert_ext_coinc_observatory": alert.ext_coinc_observatory, + "alert_ext_coinc_search": alert.ext_coinc_search, + "alert_time_difference": alert.time_difference, + "alert_time_coincidence_far": human_time_coinc_far, + "alert_time_sky_position_coincidence_far": human_time_skypos_coinc_far, + "selected_alert_type": alert.alert_type, + } + + return payload diff --git a/server/routes/ui/candidate_fetch.py b/server/routes/ui/candidate_fetch.py new file mode 100644 index 00000000..6e56c33f --- /dev/null +++ b/server/routes/ui/candidate_fetch.py @@ -0,0 +1,48 @@ +"""Candidate fetch endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.candidate import GWCandidate +from server.db.models.gw_alert import GWAlert + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_candidate") +async def ajax_candidate_fetch(graceid: str, db: Session = Depends(get_db)): + """Get candidates associated with a GW event.""" + from server.utils.function import sanatize_pointing, sanatize_candidate_info + import 
shapely.wkb + + # Normalize the graceid - maintain backward compatibility + normalized_graceid = GWAlert.graceidfromalternate(graceid) + + # Get candidates for this event + candidates = ( + db.query(GWCandidate).filter(GWCandidate.graceid == normalized_graceid).all() + ) + + markers = [] + payload = [] + + for c in candidates: + clean_position = shapely.wkb.loads(bytes(c.position.data), hex=True) + position_str = str(clean_position) + ra, dec = sanatize_pointing(position_str) + + markers.append( + { + "name": c.candidate_name, + "ra": ra, + "dec": dec, + "shape": "star", + "info": sanatize_candidate_info(c, ra, dec), + } + ) + + if markers: + payload.append({"name": "Candidates", "color": "", "markers": markers}) + + return payload diff --git a/server/routes/ui/coverage_calculator.py b/server/routes/ui/coverage_calculator.py new file mode 100644 index 00000000..f2c0ed85 --- /dev/null +++ b/server/routes/ui/coverage_calculator.py @@ -0,0 +1,376 @@ +"""Coverage calculator endpoint.""" + +from fastapi import APIRouter, Depends, Request, HTTPException +from sqlalchemy.orm import Session +from sqlalchemy import func, or_ + +from server.db.database import get_db +from server.auth.auth import get_current_user + +router = APIRouter(tags=["UI"]) + + +@router.post("/ajax_coverage_calculator") +async def coverage_calculator( + request: Request, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """Calculate coverage statistics for an alert using real HEALPix implementation.""" + import numpy as np + import healpy as hp + import hashlib + import json + import astropy.coordinates + import plotly + import plotly.graph_objects as go + from plotly.subplots import make_subplots + import tempfile + from datetime import datetime + from sqlalchemy import func, or_ + from server.db.models.pointing_event import PointingEvent + from server.utils.function import ( + sanatize_pointing, + sanatize_footprint_ccds, + project_footprint, + ) + from server.utils.gwtm_io import ( + download_gwtm_file, + get_cached_file, + set_cached_file, + ) + from server.config import settings + + data = await request.json() + + graceid = data.get("graceid") + if not graceid: + raise HTTPException(status_code=400, detail="Missing graceid") + + # Get equivalent params from request data + mappathinfo = data.get("mappathinfo") + inst_cov = data.get("inst_cov", "") + band_cov = data.get("band_cov", "") + depth = data.get("depth_cov") + depth_unit = data.get("depth_unit", "") + approx_cov = data.get("approx_cov", 1) == 1 + spec_range_type = data.get("spec_range_type", "") + spec_range_unit = data.get("spec_range_unit", "") + spec_range_low = data.get("spec_range_low") + spec_range_high = data.get("spec_range_high") + + # Create cache key + cache_params = ( + f"{graceid}_{mappathinfo}_{inst_cov}_{depth}_{depth_unit}_{approx_cov}" + ) + cache_key = f"coverage_calc_{hashlib.sha1(cache_params.encode()).hexdigest()}" + + # Try to get from cache first + cached_result = get_cached_file(cache_key, settings) + if cached_result: + times, probs, areas = ( + cached_result["times"], + cached_result["probs"], + cached_result["areas"], + ) + else: + # Calculate coverage using real HEALPix implementation + times, probs, areas = await calculate_healpix_coverage( + graceid, + mappathinfo, + inst_cov, + band_cov, + depth, + depth_unit, + approx_cov, + spec_range_low, + spec_range_high, + spec_range_type, + cache_key, + db, + ) + + # Generate the plot + fig = make_subplots(specs=[[{"secondary_y": True}]]) + + fig.add_trace( + 
go.Scatter( + x=times, y=[prob * 100 for prob in probs], mode="lines", name="Probability" + ), + secondary_y=False, + ) + + fig.add_trace( + go.Scatter(x=times, y=areas, mode="lines", name="Area"), secondary_y=True + ) + + fig.update_xaxes(title_text="Hours since GW T0") + fig.update_yaxes( + title_text="Percent of GW localization posterior covered", secondary_y=False + ) + fig.update_yaxes(title_text="Area coverage (deg2)", secondary_y=True) + + coverage_div = plotly.offline.plot( + fig, output_type="div", include_plotlyjs=False, show_link=False + ) + + return {"plot_html": coverage_div} + + +async def calculate_healpix_coverage( + graceid, + mappathinfo, + inst_cov, + band_cov, + depth, + depth_unit, + approx_cov, + spec_range_low, + spec_range_high, + spec_range_type, + cache_key, + db, +): + """Calculate real HEALPix-based coverage statistics.""" + import numpy as np + import healpy as hp + import astropy.coordinates + import tempfile + from server.utils.function import ( + sanatize_pointing, + sanatize_footprint_ccds, + project_footprint, + isFloat, + ) + from server.utils.gwtm_io import download_gwtm_file, set_cached_file + from server.config import settings + from server.db.models.pointing_event import PointingEvent + from server.db.models.pointing import Pointing + from server.db.models.gw_alert import GWAlert + from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum + + # Handle instrument approximations for large-scale instruments + approx_dict = {47: 76, 38: 98} # ZTF to ZTF_approx # DECam to DECam_approx + + areas = [] + times = [] + probs = [] + + # Download and read the HEALPix map + try: + with tempfile.NamedTemporaryFile() as f: + tmpdata = download_gwtm_file( + mappathinfo, + source=settings.STORAGE_BUCKET_SOURCE, + config=settings, + decode=False, + ) + f.write(tmpdata) + GWmap = hp.read_map(f.name) + nside = hp.npix2nside(len(GWmap)) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Calculator ERROR: Map not found. 
{str(e)}" + ) + + # Build pointing filter + pointing_filter = [] + pointing_filter.append(PointingEvent.graceid == graceid) + pointing_filter.append(Pointing.status == pointing_status_enum.completed) + pointing_filter.append(PointingEvent.pointingid == Pointing.id) + pointing_filter.append(Pointing.instrumentid != 49) # Exclude instrument 49 + + if inst_cov: + insts_cov = [int(x) for x in inst_cov.split(",")] + pointing_filter.append(Pointing.instrumentid.in_(insts_cov)) + + if depth_unit and depth_unit != "None": + from server.core.enums.depthunit import DepthUnit as depth_unit_enum + + try: + unit_enum = depth_unit_enum[depth_unit] + pointing_filter.append(Pointing.depth_unit == unit_enum) + except KeyError: + pass + + if depth and isFloat(depth): + depth_val = float(depth) + if "mag" in depth_unit: + pointing_filter.append(Pointing.depth >= depth_val) + elif "flux" in depth_unit: + pointing_filter.append(Pointing.depth <= depth_val) + else: + raise HTTPException(status_code=400, detail="Unknown depth unit") + + # Handle spectral range filtering if provided + if spec_range_low and spec_range_high and spec_range_type: + from server.utils.spectral import SpectralRangeHandler + from server.utils.function import isFloat + + try: + slow, shigh = None, None + if isFloat(spec_range_low) and isFloat(spec_range_high): + slow = float(spec_range_low) + shigh = float(spec_range_high) + + # Convert spectral range to common units and apply filter + if spec_range_type == "wavelength": + from server.core.enums.wavelengthunits import WavelengthUnits + + unit = [x for x in WavelengthUnits if spec_range_unit == x.name][0] + scale = WavelengthUnits.get_scale(unit) + slow = slow * scale + shigh = shigh * scale + stype = SpectralRangeHandler.spectralrangetype.wavelength + elif spec_range_type == "energy": + from server.core.enums.energyunits import EnergyUnits + + unit = [x for x in EnergyUnits if spec_range_unit == x.name][0] + scale = EnergyUnits.get_scale(unit) + slow = slow * scale + shigh = shigh * scale + stype = SpectralRangeHandler.spectralrangetype.energy + elif spec_range_type == "frequency": + from server.core.enums.frequencyunits import FrequencyUnits + + unit = [x for x in FrequencyUnits if spec_range_unit == x.name][0] + scale = FrequencyUnits.get_scale(unit) + slow = slow * scale + shigh = shigh * scale + stype = SpectralRangeHandler.spectralrangetype.frequency + else: + stype = None + + if stype is not None: + pointing_filter.append(Pointing.inSpectralRange(slow, shigh, stype)) + except Exception: + # If spectral filtering fails, continue without it + pass + + # Get sorted pointings + pointings_sorted = ( + db.query( + Pointing.id, + Pointing.instrumentid, + Pointing.pos_angle, + func.ST_AsText(Pointing.position).label("position"), + Pointing.band, + Pointing.depth, + Pointing.time, + ) + .join(PointingEvent, PointingEvent.pointingid == Pointing.id) + .filter(*pointing_filter) + .order_by(Pointing.time.asc()) + .all() + ) + + # Get instrument IDs and handle approximations + instrumentids = [p.instrumentid for p in pointings_sorted] + + # Add approximation instruments if needed + if approx_cov: + for apid in approx_dict.keys(): + if apid in instrumentids: + instrumentids.append(approx_dict[apid]) + + # Get footprint information + from server.db.models.instrument import FootprintCCD + + footprintinfo = ( + db.query( + func.ST_AsText(FootprintCCD.footprint).label("footprint"), + FootprintCCD.instrumentid, + ) + .filter(FootprintCCD.instrumentid.in_(instrumentids)) + .all() + ) + + # Get GW T0 
time + time_of_signal = ( + db.query(GWAlert.time_of_signal) + .filter(GWAlert.graceid == graceid, GWAlert.time_of_signal.isnot(None)) + .order_by(GWAlert.datecreated.desc()) + .first() + ) + + if not time_of_signal or not time_of_signal[0]: + raise HTTPException( + status_code=400, detail="ERROR: Please contact administrator" + ) + + time_of_signal = time_of_signal[0] + + # Initialize HEALPix calculation variables + qps = [] + qpsarea = [] + + NSIDE4area = 512 # This gives pixarea of 0.013 deg^2 per pixel + pixarea = hp.nside2pixarea(NSIDE4area, degrees=True) + + # Process each pointing + for p in pointings_sorted: + ra, dec = sanatize_pointing(p.position) + + # Select appropriate footprint based on approximation setting + if approx_cov and p.instrumentid in approx_dict.keys(): + footprint_ccds = [ + x.footprint + for x in footprintinfo + if x.instrumentid == approx_dict[p.instrumentid] + ] + else: + footprint_ccds = [ + x.footprint for x in footprintinfo if x.instrumentid == p.instrumentid + ] + + sanatized_ccds = sanatize_footprint_ccds(footprint_ccds) + + # Process each CCD footprint + for ccd in sanatized_ccds: + pointing_footprint = project_footprint(ccd, ra, dec, p.pos_angle) + + # Extract RA/Dec coordinates from footprint + ras_poly = [x[0] for x in pointing_footprint][:-1] + decs_poly = [x[1] for x in pointing_footprint][:-1] + + # Convert to cartesian coordinates for HEALPix + xyzpoly = astropy.coordinates.spherical_to_cartesian( + 1, np.deg2rad(decs_poly), np.deg2rad(ras_poly) + ) + + # Query HEALPix pixels within the polygon + qp = hp.query_polygon(nside, np.array(xyzpoly).T, inclusive=True) + qps.extend(qp) + + # Separate calculation for area coverage with higher resolution + qparea = hp.query_polygon(NSIDE4area, np.array(xyzpoly).T, inclusive=True) + qpsarea.extend(qparea) + + # Deduplicate indices so pixels aren't double-counted + deduped_indices = list(dict.fromkeys(qps)) + deduped_indices_area = list(dict.fromkeys(qpsarea)) + + # Calculate area coverage + area = pixarea * len(deduped_indices_area) + + # Calculate probability coverage by summing GW map pixel values + prob = 0 + for ind in deduped_indices: + prob += GWmap[ind] + + # Calculate elapsed time since GW trigger + elapsed = (p.time - time_of_signal).total_seconds() / 3600 + + times.append(elapsed) + probs.append(prob) + areas.append(area) + + # Cache the results + cache_file = { + "times": times, + "probs": probs, + "areas": areas, + } + set_cached_file(cache_key, cache_file, settings) + + return times, probs, areas diff --git a/server/routes/ui/event_galaxies.py b/server/routes/ui/event_galaxies.py new file mode 100644 index 00000000..56e712f3 --- /dev/null +++ b/server/routes/ui/event_galaxies.py @@ -0,0 +1,62 @@ +"""Event galaxies endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from sqlalchemy import func + +from server.db.database import get_db +from server.db.models.gw_galaxy import GWGalaxyList, GWGalaxyEntry + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_event_galaxies") +async def ajax_event_galaxies(alertid: str, db: Session = Depends(get_db)): + """Get galaxies associated with an event.""" + from server.utils.function import sanatize_pointing, sanatize_gal_info + + event_galaxies = [] + + # Get galaxy lists for this alert + gal_lists = db.query(GWGalaxyList).filter(GWGalaxyList.alertid == alertid).all() + + if not gal_lists: + return event_galaxies + + gal_list_ids = list(set([x.id for x in gal_lists])) + + # Get galaxy entries for these lists + 
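
The coverage bookkeeping in `calculate_healpix_coverage` above boils down to: query the HEALPix pixels inside each projected footprint polygon, deduplicate them, sum the localization map over those pixels for the covered probability, and multiply the pixel area by the pixel count for the covered area. A minimal self-contained sketch with a toy flat map and a hand-made square footprint (all inputs assumed, not from this diff) is shown below.

```python
# Toy sketch of the probability/area calculation used above (assumed inputs).
import numpy as np
import healpy as hp
import astropy.coordinates

nside = 256
npix = hp.nside2npix(nside)
gw_map = np.full(npix, 1.0 / npix)  # flat toy localization map

# A ~1 deg square footprint centred near (ra, dec) = (150, 2), already "projected".
ras_poly = [149.5, 150.5, 150.5, 149.5]
decs_poly = [1.5, 1.5, 2.5, 2.5]
xyz = astropy.coordinates.spherical_to_cartesian(
    1, np.deg2rad(decs_poly), np.deg2rad(ras_poly)
)
pix = hp.query_polygon(nside, np.array(xyz).T, inclusive=True)

deduped = list(dict.fromkeys(pix))                  # avoid double-counting overlaps
prob = float(np.sum(gw_map[deduped]))               # covered posterior probability
area = hp.nside2pixarea(nside, degrees=True) * len(deduped)  # covered area, deg^2
print(prob, area)
```
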
gal_entries = ( + db.query( + GWGalaxyEntry.name, + func.ST_AsText(GWGalaxyEntry.position).label("position"), + GWGalaxyEntry.score, + GWGalaxyEntry.info, + GWGalaxyEntry.listid, + GWGalaxyEntry.rank, + ) + .filter(GWGalaxyEntry.listid.in_(gal_list_ids)) + .all() + ) + + # Process each list + for glist in gal_lists: + markers = [] + entries = [x for x in gal_entries if x.listid == glist.id] + + for e in entries: + ra, dec = sanatize_pointing(e.position) + markers.append( + { + "name": e.name, + "ra": ra, + "dec": dec, + "info": sanatize_gal_info(e, glist), + } + ) + + event_galaxies.append( + {"name": glist.groupname, "color": "", "markers": markers} + ) + + return event_galaxies diff --git a/server/routes/ui/grade_calculator.py b/server/routes/ui/grade_calculator.py new file mode 100644 index 00000000..7edd6643 --- /dev/null +++ b/server/routes/ui/grade_calculator.py @@ -0,0 +1,119 @@ +"""Grade calculator endpoint.""" + +from fastapi import APIRouter, Depends, Request, HTTPException +from sqlalchemy.orm import Session +from sqlalchemy import func + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.db.models.pointing_event import PointingEvent +from server.db.models.gw_alert import GWAlert +from server.auth.auth import get_current_user + +router = APIRouter(tags=["UI"]) + + +@router.post("/ajax_grade_calculator") +async def grade_calculator( + request: Request, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """Calculate grades for pointings based on various metrics.""" + data = await request.json() + + pointing_ids = data.get("pointing_ids", []) + if not pointing_ids: + raise HTTPException(status_code=400, detail="No pointings specified") + + # Get the pointings + pointings = db.query(Pointing).filter(Pointing.id.in_(pointing_ids)).all() + + # Calculate grades for each pointing based on actual metrics + results = {} + + # Get associated GW alerts for time grading + pointing_events = ( + db.query(PointingEvent).filter(PointingEvent.pointingid.in_(pointing_ids)).all() + ) + + event_map = {pe.pointingid: pe.graceid for pe in pointing_events} + + for pointing in pointings: + # Get the associated GW alert + graceid = event_map.get(pointing.id) + time_grade = 0.5 # Default + position_grade = 0.5 # Default + depth_grade = 0.5 # Default + + if graceid: + alert = db.query(GWAlert).filter(GWAlert.graceid == graceid).first() + if alert and alert.time_of_signal and pointing.time: + # Time grade: earlier observations get higher grades + time_diff_hours = ( + pointing.time - alert.time_of_signal + ).total_seconds() / 3600 + if time_diff_hours <= 1: + time_grade = 1.0 + elif time_diff_hours <= 6: + time_grade = 0.9 + elif time_diff_hours <= 24: + time_grade = 0.7 + elif time_diff_hours <= 72: + time_grade = 0.5 + else: + time_grade = 0.3 + + # Position grade: simplified calculation based on alert coordinates + if graceid and alert and alert.avgra is not None and alert.avgdec is not None: + # Get pointing coordinates + position_result = ( + db.query(func.ST_AsText(Pointing.position)) + .filter(Pointing.id == pointing.id) + .first() + ) + if position_result and position_result[0]: + try: + pos_str = position_result[0] + pointing_ra = float(pos_str.split("POINT(")[1].split(" ")[0]) + pointing_dec = float(pos_str.split(" ")[1].split(")")[0]) + + # Simple angular distance calculation (rough approximation) + ra_diff = abs(pointing_ra - alert.avgra) + dec_diff = abs(pointing_dec - alert.avgdec) + angular_dist = (ra_diff**2 + 
dec_diff**2) ** 0.5 + + # Grade based on distance from alert center + if angular_dist <= 5: + position_grade = 1.0 + elif angular_dist <= 15: + position_grade = 0.8 + elif angular_dist <= 30: + position_grade = 0.6 + else: + position_grade = 0.3 + except: + position_grade = 0.5 + + # Depth grade: deeper observations get higher grades + if pointing.depth is not None: + if pointing.depth >= 23: + depth_grade = 1.0 + elif pointing.depth >= 21: + depth_grade = 0.8 + elif pointing.depth >= 19: + depth_grade = 0.6 + else: + depth_grade = 0.4 + + # Calculate weighted overall grade + overall_grade = time_grade * 0.4 + position_grade * 0.4 + depth_grade * 0.2 + + results[pointing.id] = { + "time_grade": round(time_grade, 2), + "position_grade": round(position_grade, 2), + "depth_grade": round(depth_grade, 2), + "overall_grade": round(overall_grade, 2), + } + + return results diff --git a/server/routes/ui/icecube_notice.py b/server/routes/ui/icecube_notice.py new file mode 100644 index 00000000..fb802619 --- /dev/null +++ b/server/routes/ui/icecube_notice.py @@ -0,0 +1,60 @@ +"""IceCube notice endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.icecube import IceCubeNotice, IceCubeNoticeCoincEvent + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_icecube_notice") +async def ajax_icecube_notice(graceid: str, db: Session = Depends(get_db)): + """Get IceCube notices associated with a GW event.""" + from server.utils.function import sanatize_icecube_event + + return_events = [] + + # Get IceCube notices for this event + icecube_notices = ( + db.query(IceCubeNotice).filter(IceCubeNotice.graceid == graceid).all() + ) + + if not icecube_notices: + return return_events + + icecube_notice_ids = list(set([notice.id for notice in icecube_notices])) + + # Get coincident events for these notices + icecube_notice_events = ( + db.query(IceCubeNoticeCoincEvent) + .filter(IceCubeNoticeCoincEvent.icecube_notice_id.in_(icecube_notice_ids)) + .all() + ) + + # Process each notice + for notice in icecube_notices: + markers = [] + events = [x for x in icecube_notice_events if x.icecube_notice_id == notice.id] + + for i, e in enumerate(events): + markers.append( + { + "name": f"ICN_EVENT_{e.id}", + "ra": e.ra, + "dec": e.dec, + "radius": e.ra_uncertainty, + "info": sanatize_icecube_event(e, notice), + } + ) + + return_events.append( + { + "name": f"ICECUBENotice{notice.id}", + "color": "#324E72", + "markers": markers, + } + ) + + return return_events diff --git a/server/routes/ui/pointing_from_id.py b/server/routes/ui/pointing_from_id.py new file mode 100644 index 00000000..23c078a4 --- /dev/null +++ b/server/routes/ui/pointing_from_id.py @@ -0,0 +1,84 @@ +"""Get pointing from ID endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from sqlalchemy import func + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.db.models.instrument import Instrument +from server.auth.auth import get_current_user + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_pointingfromid") +async def get_pointing_fromID( + id: str, db: Session = Depends(get_db), current_user=Depends(get_current_user) +): + """Get pointing details by ID for the current user's planned pointings.""" + from server.utils.function import isInt + from server.db.models.gw_alert import GWAlert + from server.db.models.pointing_event import PointingEvent + from 
server.core.enums.pointingstatus import PointingStatus as pointing_status_enum + + if not id or not isInt(id): + return {} + + # Convert to integer + pointing_id = int(id) + + # Query pointings with filter conditions + filters = [ + Pointing.submitterid == current_user.id, + Pointing.status == pointing_status_enum.planned, + Pointing.id == pointing_id, + ] + + pointing = db.query(Pointing).filter(*filters).first() + + if not pointing: + return {} + + # Get the alert for this pointing + pointing_event = ( + db.query(PointingEvent).filter(PointingEvent.pointingid == pointing.id).first() + ) + if not pointing_event: + return {} + + alert = db.query(GWAlert).filter(GWAlert.graceid == pointing_event.graceid).first() + if not alert: + return {} + + # Extract position + position_result = ( + db.query(func.ST_AsText(Pointing.position)) + .filter(Pointing.id == pointing_id) + .first() + ) + + if not position_result or not position_result[0]: + return {} + + position = position_result[0] + ra = position.split("POINT(")[1].split(" ")[0] + dec = position.split("POINT(")[1].split(" ")[1].split(")")[0] + + # Get instrument details + instrument = ( + db.query(Instrument).filter(Instrument.id == pointing.instrumentid).first() + ) + + # Prepare response + pointing_json = { + "ra": ra, + "dec": dec, + "graceid": pointing_event.graceid, + "instrument": f"{pointing.instrumentid}_{instrument.InstrumentType.name if instrument else ''}", + "band": pointing.band.name if pointing.band else "", + "depth": pointing.depth, + "depth_err": pointing.depth_err, + } + + return pointing_json diff --git a/server/routes/ui/preview_footprint.py b/server/routes/ui/preview_footprint.py new file mode 100644 index 00000000..7290e8f3 --- /dev/null +++ b/server/routes/ui/preview_footprint.py @@ -0,0 +1,109 @@ +"""Preview footprint endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_preview_footprint") +async def preview_footprint( + ra: float, + dec: float, + radius: float = None, + height: float = None, + width: float = None, + shape: str = "circle", + polygon: str = None, + db: Session = Depends(get_db), +): + """Generate a preview of an instrument footprint.""" + import math + import json + import plotly + import plotly.graph_objects as go + + # This is a UI helper endpoint to visualize a footprint before saving + # It generates the appropriate visualization for the given parameters + vertices = [] + + if shape.lower() == "circle" and radius: + # For circle, generate points around the circumference + circle_points = [] + for i in range(36): # 36 points for a smooth circle + angle = i * 10 * (math.pi / 180) # 10 degrees in radians + # Proper spherical coordinate calculation for circles + # Convert angle to offset in RA/Dec using spherical trigonometry + x = radius * math.cos(math.radians(90 - i * 10)) + y = radius * math.sin(math.radians(90 - i * 10)) + # Adjust for spherical coordinates + if abs(x) < 1e-10: + x = 0.0 + if abs(y) < 1e-10: + y = 0.0 + point_ra = ra + x + point_dec = dec + y + circle_points.append([point_ra, point_dec]) + + # Close the polygon + circle_points.append(circle_points[0]) + vertices.append(circle_points) + + elif shape.lower() == "rectangle" and height and width: + # For rectangle, generate corners + # Convert height/width in degrees to ra/dec coordinates + # Proper calculation accounting for coordinate system + half_width = width / 2 + half_height = height / 2 + + # 
No cos(dec) correction needed for simple rectangular footprints + ra_offset = half_width + + rect_points = [ + [ra - ra_offset, dec - half_height], # bottom left + [ra - ra_offset, dec + half_height], # top left + [ra + ra_offset, dec + half_height], # top right + [ra + ra_offset, dec - half_height], # bottom right + [ra - ra_offset, dec - half_height], # close the polygon + ] + vertices.append(rect_points) + + elif shape.lower() == "polygon" and polygon: + # For custom polygon, parse the points + try: + poly_points = json.loads(polygon) + poly_points.append(poly_points[0]) # Close the polygon + vertices.append(poly_points) + except json.JSONDecodeError: + return {"error": "Invalid polygon format"} + else: + return {"error": "Invalid shape type or missing required parameters"} + + # Create a plotly figure + traces = [] + for vert in vertices: + xs = [v[0] for v in vert] + ys = [v[1] for v in vert] + trace = go.Scatter( + x=xs, y=ys, line_color="blue", fill="tozeroy", fillcolor="violet" + ) + traces.append(trace) + + fig = go.Figure(data=traces) + fig.update_layout( + showlegend=False, + xaxis_title="degrees", + yaxis_title="degrees", + yaxis=dict( + matches="x", + scaleanchor="x", + scaleratio=1, + constrain="domain", + ), + ) + + # Convert to JSON for return + graphJSON = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder) + return graphJSON diff --git a/server/routes/ui/request_doi.py b/server/routes/ui/request_doi.py new file mode 100644 index 00000000..64e937a3 --- /dev/null +++ b/server/routes/ui/request_doi.py @@ -0,0 +1,97 @@ +"""Request DOI endpoint.""" + +from typing import Optional +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.pointing import Pointing +from server.db.models.instrument import Instrument +from server.db.models.users import Users +from server.auth.auth import get_current_user + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_request_doi") +async def ajax_request_doi( + graceid: str, + ids: str = "", + doi_group_id: Optional[str] = None, + doi_url: Optional[str] = None, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """Request a DOI for a set of pointings.""" + from server.utils.function import create_pointing_doi + from server.db.models.gw_alert import GWAlert + from server.db.models.pointing_event import PointingEvent + + # Normalize the graceid - maintain backward compatibility + normalized_graceid = GWAlert.alternatefromgraceid(graceid) + + if not ids: + return "" + + # Convert IDs to list of integers + pointing_ids = [int(x) for x in ids.split(",")] + + # Get all pointings with these IDs that are associated with the graceid + points = ( + db.query(Pointing) + .join(PointingEvent, PointingEvent.pointingid == Pointing.id) + .filter( + Pointing.id.in_(pointing_ids), PointingEvent.graceid == normalized_graceid + ) + .all() + ) + + # Get user information + user = db.query(Users).filter(Users.id == current_user.id).first() + + # Set up creators list + if doi_group_id: + # Get creator list from DOI author group + from server.db.models.doi_author import DOIAuthor + + try: + valid, creators_list = DOIAuthor.construct_creators( + doi_group_id, user.id, db + ) + if valid: + creators = creators_list + else: + # Fall back to current user if group is invalid + creators = [{"name": f"{user.firstname} {user.lastname}"}] + except: + # Fall back to current user if there's an error + creators = [{"name": f"{user.firstname} {user.lastname}"}] + 
else: + creators = [{"name": f"{user.firstname} {user.lastname}"}] + + # Get instrument names + insts = ( + db.query(Instrument) + .filter(Instrument.id.in_([p.instrumentid for p in points])) + .all() + ) + + inst_set = list(set([i.instrument_name for i in insts])) + + # Create DOI or use existing URL + if doi_url: + doi_id, doi_url = 0, doi_url + else: + doi_id, doi_url = create_pointing_doi( + points, normalized_graceid, creators, inst_set + ) + + # Update pointings with DOI information + for p in points: + p.doi_url = doi_url + p.doi_id = doi_id + p.submitterid = current_user.id # Ensure submitter is set + + db.commit() + + return doi_url diff --git a/server/routes/ui/resend_verification_email.py b/server/routes/ui/resend_verification_email.py new file mode 100644 index 00000000..0b80e795 --- /dev/null +++ b/server/routes/ui/resend_verification_email.py @@ -0,0 +1,54 @@ +"""Resend verification email endpoint.""" + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.users import Users +from server.auth.auth import get_current_user + +router = APIRouter(tags=["UI"]) + + +@router.post("/ajax_resend_verification_email") +async def resend_verification_email( + email: str = None, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """Resend the verification email to a user.""" + from server.utils.email import send_account_validation_email + + # If email is provided, find that user (admin function) + # Otherwise use the current user + if email: + user = db.query(Users).filter(Users.email == email).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Only allow admins to send verification emails to other users + # Note: Need to check if user is admin + from server.db.models.users import UserGroups, Groups + + admin_group = db.query(Groups).filter(Groups.name == "admin").first() + if admin_group: + user_group = ( + db.query(UserGroups) + .filter( + UserGroups.userid == current_user.id, + UserGroups.groupid == admin_group.id, + ) + .first() + ) + if not user_group: + raise HTTPException(status_code=403, detail="Not authorized") + else: + user = current_user + + if user.verified: + return {"message": "User is already verified"} + + # Send the verification email + send_account_validation_email(user, db) + + return {"message": "Verification email has been resent"} diff --git a/server/routes/ui/router.py b/server/routes/ui/router.py new file mode 100644 index 00000000..604329d3 --- /dev/null +++ b/server/routes/ui/router.py @@ -0,0 +1,36 @@ +"""Consolidated router for all UI endpoints.""" + +from fastapi import APIRouter + +# Import all individual route modules +from .alert_instruments_footprints import router as alert_instruments_footprints_router +from .preview_footprint import router as preview_footprint_router +from .resend_verification_email import router as resend_verification_email_router +from .coverage_calculator import router as coverage_calculator_router +from .spectral_range_from_bands import router as spectral_range_from_bands_router +from .pointing_from_id import router as pointing_from_id_router +from .grade_calculator import router as grade_calculator_router +from .icecube_notice import router as icecube_notice_router +from .event_galaxies import router as event_galaxies_router +from .scimma_xrt import router as scimma_xrt_router +from .candidate_fetch import router as candidate_fetch_router +from .request_doi 
import router as request_doi_router +from .alert_type import router as alert_type_router + +# Create the main router that includes all UI routes +router = APIRouter(tags=["UI"]) + +# Include all the individual routers +router.include_router(alert_instruments_footprints_router) +router.include_router(preview_footprint_router) +router.include_router(resend_verification_email_router) +router.include_router(coverage_calculator_router) +router.include_router(spectral_range_from_bands_router) +router.include_router(pointing_from_id_router) +router.include_router(grade_calculator_router) +router.include_router(icecube_notice_router) +router.include_router(event_galaxies_router) +router.include_router(scimma_xrt_router) +router.include_router(candidate_fetch_router) +router.include_router(request_doi_router) +router.include_router(alert_type_router) diff --git a/server/routes/ui/scimma_xrt.py b/server/routes/ui/scimma_xrt.py new file mode 100644 index 00000000..52b3cac7 --- /dev/null +++ b/server/routes/ui/scimma_xrt.py @@ -0,0 +1,65 @@ +"""SCIMMA XRT endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db +from server.db.models.gw_alert import GWAlert + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_scimma_xrt") +async def ajax_scimma_xrt(graceid: str, db: Session = Depends(get_db)): + """Get SCIMMA XRT sources associated with a GW event.""" + import requests + import urllib.parse + from server.utils.function import sanatize_XRT_source_info + + # Normalize the graceid - maintain backward compatibility + normalized_graceid = GWAlert.graceidfromalternate(graceid) + + # Special case for S190426 + if "S190426" in normalized_graceid: + normalized_graceid = "S190426" + + # Prepare query parameters + keywords = { + "keyword": "", + "cone_search": "", + "polygon_search": "", + "alert_timestamp_after": "", + "alert_timestamp_before": "", + "role": "", + "event_trigger_number": normalized_graceid, + "ordering": "", + "page_size": 1000, + } + + # Construct URL and make request + base_url = "http://skip.dev.hop.scimma.org/api/alerts/" + url = f"{base_url}?{urllib.parse.urlencode(keywords)}" + + markers = [] + payload = [] + + try: + response = requests.get(url) + if response.status_code == 200: + package = response.json()["results"] + for p in package: + markers.append( + { + "name": p["alert_identifier"], + "ra": p["right_ascension"], + "dec": p["declination"], + "info": sanatize_XRT_source_info(p), + } + ) + except Exception as e: + print(f"Error fetching SCIMMA XRT data: {str(e)}") + + if markers: + payload.append({"name": "SCIMMA XRT Sources", "color": "", "markers": markers}) + + return payload diff --git a/server/routes/ui/spectral_range_from_bands.py b/server/routes/ui/spectral_range_from_bands.py new file mode 100644 index 00000000..4deeb0b5 --- /dev/null +++ b/server/routes/ui/spectral_range_from_bands.py @@ -0,0 +1,79 @@ +"""Update spectral range from selected bands endpoint.""" + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from server.db.database import get_db + +router = APIRouter(tags=["UI"]) + + +@router.get("/ajax_update_spectral_range_from_selected_bands") +async def spectral_range_from_selected_bands( + band_cov: str, spectral_type: str, spectral_unit: str, db: Session = Depends(get_db) +): + """Calculate spectral range based on selected bands.""" + from server.core.enums.wavelengthunits import WavelengthUnits + from server.core.enums.energyunits import EnergyUnits + from 
server.core.enums.frequencyunits import FrequencyUnits + from server.core.enums.bandpass import Bandpass + + if not band_cov or band_cov == "null": + return {"total_min": "", "total_max": ""} + + # Split bands + bands = band_cov.split(",") + mins, maxs = [], [] + + for b in bands: + try: + # Find the bandpass enum value + band_enum = [x for x in Bandpass if b == x.name][0] + band_min, band_max = None, None + + # Handle different spectral types + if spectral_type == "wavelength": + # Get wavelength range for this band + from server.utils.spectral import wavetoWaveRange + + band_min, band_max = wavetoWaveRange(bandpass_enum=band_enum) + # Get the scale factor for the requested unit + # Handle unit name aliases (nm -> nanometer) + unit_name = spectral_unit + if spectral_unit == "nm": + unit_name = "nanometer" + unit = [x for x in WavelengthUnits if unit_name == x.name][0] + scale = WavelengthUnits.get_scale(unit) + + elif spectral_type == "energy": + # Get energy range for this band + from server.utils.spectral import wavetoEnergy + + band_min, band_max = wavetoEnergy(bandpass_enum=band_enum) + # Get the scale factor for the requested unit + unit = [x for x in EnergyUnits if spectral_unit == x.name][0] + scale = EnergyUnits.get_scale(unit) + + elif spectral_type == "frequency": + # Get frequency range for this band + from server.utils.spectral import wavetoFrequency + + band_min, band_max = wavetoFrequency(bandpass_enum=band_enum) + # Get the scale factor for the requested unit + unit = [x for x in FrequencyUnits if spectral_unit == x.name][0] + scale = FrequencyUnits.get_scale(unit) + + # If we got valid values, append them to our lists + if band_min is not None and band_max is not None: + mins.append(band_min / scale) + maxs.append(band_max / scale) + + except (IndexError, ValueError): + # Skip invalid bands + continue + + # Return the overall range + if mins: + return {"total_min": min(mins), "total_max": max(maxs)} + else: + return {"total_min": "", "total_max": ""} diff --git a/server/schemas/candidate.py b/server/schemas/candidate.py new file mode 100644 index 00000000..5256afa5 --- /dev/null +++ b/server/schemas/candidate.py @@ -0,0 +1,352 @@ +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from typing import Optional, List, Any, Literal, Dict, Union +from datetime import datetime +from geoalchemy2.types import WKBElement +from typing_extensions import Annotated +from server.core.enums.depthunit import DepthUnit as depth_unit_enum + + +class CandidateSchema(BaseModel): + id: int + graceid: str + submitterid: int + candidate_name: str + datecreated: Optional[datetime] = None + tns_name: Optional[str] = None + tns_url: Optional[str] = None + position: Annotated[str, WKBElement] = Field( + None, description="WKT representation of the position" + ) + discovery_date: Optional[datetime] = None + discovery_magnitude: Optional[float] = None + magnitude_central_wave: Optional[float] = None + magnitude_bandwidth: Optional[float] = None + magnitude_unit: Optional[str] = None + magnitude_bandpass: Optional[str] = None + associated_galaxy: Optional[str] = None + associated_galaxy_redshift: Optional[float] = None + associated_galaxy_distance: Optional[float] = None + + model_config = ConfigDict(from_attributes=True) + + +class GetCandidateQueryParams(BaseModel): + id: Optional[int] = Field(None, description="Filter by candidate ID") + ids: Optional[List[int]] = Field( + None, description="Filter by a list of candidate IDs" + ) + graceid: Optional[str] = 
Field(None, description="Filter by Grace ID") + userid: Optional[int] = Field(None, description="Filter by user ID") + submitted_date_after: Optional[datetime] = Field( + None, description="Filter by submission date after this timestamp" + ) + submitted_date_before: Optional[datetime] = Field( + None, description="Filter by submission date before this timestamp" + ) + discovery_magnitude_gt: Optional[float] = Field( + None, description="Filter by discovery magnitude greater than this value" + ) + discovery_magnitude_lt: Optional[float] = Field( + None, description="Filter by discovery magnitude less than this value" + ) + discovery_date_after: Optional[datetime] = Field( + None, description="Filter by discovery date after this timestamp" + ) + discovery_date_before: Optional[datetime] = Field( + None, description="Filter by discovery date before this timestamp" + ) + associated_galaxy_name: Optional[str] = Field( + None, description="Filter by associated galaxy name" + ) + associated_galaxy_redshift_gt: Optional[float] = Field( + None, description="Filter by associated galaxy redshift greater than this value" + ) + associated_galaxy_redshift_lt: Optional[float] = Field( + None, description="Filter by associated galaxy redshift less than this value" + ) + associated_galaxy_distance_gt: Optional[float] = Field( + None, description="Filter by associated galaxy distance greater than this value" + ) + associated_galaxy_distance_lt: Optional[float] = Field( + None, description="Filter by associated galaxy distance less than this value" + ) + + +class CandidateRequest(BaseModel): + """Single candidate submission model""" + + candidate_name: str + position: Optional[str] = None + ra: Optional[float] = None + dec: Optional[float] = None + tns_name: Optional[str] = None + tns_url: Optional[str] = None + discovery_date: str + discovery_magnitude: float + magnitude_unit: Union[depth_unit_enum, str, int] + magnitude_bandpass: Optional[str] = None + magnitude_central_wave: Optional[float] = None + magnitude_bandwidth: Optional[float] = None + wavelength_regime: Optional[List[float]] = None + wavelength_unit: Optional[str] = None + frequency_regime: Optional[List[float]] = None + frequency_unit: Optional[str] = None + energy_regime: Optional[List[float]] = None + energy_unit: Optional[str] = None + associated_galaxy: Optional[str] = None + associated_galaxy_redshift: Optional[float] = None + associated_galaxy_distance: Optional[float] = None + + @field_validator("discovery_date") + def validate_discovery_date(cls, value): + try: + datetime.fromisoformat(value) + except ValueError: + raise ValueError( + "Invalid discovery_date format. Must be a valid ISO 8601 datetime string." + ) + return value + + @field_validator("magnitude_unit", mode="before") + @classmethod + def validate_magnitude_unit(cls, value): + """ + Validate magnitude unit, accepting both string and enum values. + + Args: + value: Input magnitude unit (string or enum) + + Returns: + depth_unit_enum: Validated enum value + + Raises: + ValueError if the input is not a valid magnitude unit + """ + if isinstance(value, depth_unit_enum): + return value + + if isinstance(value, str): + try: + # Try converting string to enum by name + return depth_unit_enum[value] + except KeyError: + # If name lookup fails, check if it can be converted from integer + try: + return depth_unit_enum(int(value)) + except (ValueError, TypeError): + raise ValueError( + f"Invalid magnitude unit: {value}. 
" + f"Must be one of {list(depth_unit_enum.__members__.keys())}" + ) + + if isinstance(value, int): + try: + return depth_unit_enum(value) + except ValueError: + raise ValueError(f"Invalid magnitude unit value: {value}") + + raise ValueError(f"Invalid magnitude unit type: {type(value)}") + + @model_validator(mode="after") + def validate_position_data(self): + ra_dec_provided = sum([self.ra is not None, self.dec is not None]) > 0 + position_provided = self.position is not None + if ra_dec_provided or position_provided: + return self + else: + raise ValueError("Either position or both ra and dec must be provided") + + model_config = ConfigDict( + from_attributes=True, + json_encoders={depth_unit_enum: lambda v: v.name if v else None}, + ) + + +class PostCandidateRequest(BaseModel): + """Main request model with either single candidate or multiple candidates""" + + graceid: str + candidate: Optional[CandidateRequest] = None + candidates: Optional[List[CandidateRequest]] = None + + @model_validator(mode="after") + def validate_exactly_one_candidate_field(self): + fields_provided = sum( + [ + self.candidate is not None, + self.candidates is not None and len(self.candidates) > 0, + ] + ) + + if fields_provided == 0: + raise ValueError("Must provide either 'candidate' or 'candidates'") + elif fields_provided > 1: + raise ValueError("Cannot provide both 'candidate' and 'candidates'") + + return self + + +class CandidateResponse(BaseModel): + """Response model matching the Flask API format""" + + candidate_ids: List[int] + ERRORS: List[List[Any]] + WARNINGS: List[List[Any]] + + +class CandidateUpdateField(BaseModel): + """Fields that can be updated for a candidate""" + + graceid: Optional[str] = None + candidate_name: Optional[str] = None + tns_name: Optional[str] = None + tns_url: Optional[str] = None + position: Optional[str] = None + ra: Optional[float] = None + dec: Optional[float] = None + discovery_date: Optional[str] = None + discovery_magnitude: Optional[float] = None + magnitude_central_wave: Optional[float] = None + magnitude_bandwidth: Optional[float] = None + magnitude_unit: Optional[str] = None + magnitude_bandpass: Optional[str] = None + associated_galaxy: Optional[str] = None + associated_galaxy_redshift: Optional[float] = None + associated_galaxy_distance: Optional[float] = None + wavelength_regime: Optional[List[float]] = None + wavelength_unit: Optional[str] = None + frequency_regime: Optional[List[float]] = None + frequency_unit: Optional[str] = None + energy_regime: Optional[List[float]] = None + energy_unit: Optional[str] = None + + @field_validator("discovery_date") + def validate_discovery_date(cls, value): + if value is not None: + try: + datetime.fromisoformat(value) + except ValueError: + raise ValueError( + "discovery_date must be a valid ISO 8601 datetime string" + ) + return value + + +class PutCandidateRequest(BaseModel): + """Request model for updating a candidate""" + + id: int + candidate: CandidateUpdateField + + +class PutCandidateSuccessResponse(BaseModel): + """Success response model""" + + message: Literal["success"] + candidate: Dict[str, Any] + + +class PutCandidateFailureResponse(BaseModel): + """Failure response model""" + + message: Literal["failure"] + errors: List[Any] + + +# Union type for response +PutCandidateResponse = Union[PutCandidateSuccessResponse, PutCandidateFailureResponse] + + +class DeleteCandidateResponse(BaseModel): + """Response model for successful delete operation""" + + message: str + deleted_ids: Optional[List[int]] = [] + warnings: 
Optional[List[str]] = [] + + +# You can add this to your candidate.py schemas file +class DeleteCandidateParams(BaseModel): + """ + Parameters for deleting candidates. + Either id or ids must be provided. + """ + + id: Optional[int] = None + ids: Optional[List[int]] = None + + +# Base schema with common fields +class GWCandidateBase(BaseModel): + """Base schema for GW candidate data.""" + + graceid: str = Field(..., description="Grace ID of the GW event") + candidate_name: str = Field(..., description="Name of the candidate") + submitterid: Optional[int] = Field( + None, description="ID of the user who submitted the candidate" + ) + datecreated: Optional[datetime] = Field( + None, description="Date when the candidate was created" + ) + tns_name: Optional[str] = Field(None, description="TNS name of the candidate") + tns_url: Optional[str] = Field(None, description="TNS URL of the candidate") + discovery_date: Optional[datetime] = Field(None, description="Date of discovery") + discovery_magnitude: Optional[float] = Field( + None, description="Magnitude at discovery" + ) + magnitude_central_wave: Optional[float] = Field( + None, description="Central wavelength for magnitude" + ) + magnitude_bandwidth: Optional[float] = Field( + None, description="Bandwidth for magnitude measurement" + ) + magnitude_unit: Optional[str] = Field( + None, description="Unit of magnitude measurement" + ) + magnitude_bandpass: Optional[str] = Field(None, description="Bandpass filter used") + associated_galaxy: Optional[str] = Field(None, description="Associated galaxy name") + associated_galaxy_redshift: Optional[float] = Field( + None, description="Redshift of associated galaxy" + ) + associated_galaxy_distance: Optional[float] = Field( + None, description="Distance to associated galaxy" + ) + + +# Request schema with validation +class GWCandidateCreate(GWCandidateBase): + """Schema for creating/updating candidates with coordinate validation.""" + + ra: Optional[float] = Field( + None, ge=0.0, le=360.0, description="Right ascension in degrees (0-360)" + ) + dec: Optional[float] = Field( + None, ge=-90.0, le=90.0, description="Declination in degrees (-90 to +90)" + ) + + @field_validator("ra") + @classmethod + def validate_ra(cls, v): + """Validate right ascension is within valid range.""" + if v is not None and (v < 0.0 or v > 360.0): + raise ValueError("Right ascension must be between 0 and 360 degrees") + return v + + @field_validator("dec") + @classmethod + def validate_dec(cls, v): + """Validate declination is within valid range.""" + if v is not None and (v < -90.0 or v > 90.0): + raise ValueError("Declination must be between -90 and +90 degrees") + return v + + +# Response schema without strict validation (for existing data) +class GWCandidateSchema(GWCandidateBase): + """Schema for returning candidates without strict coordinate validation.""" + + id: Optional[int] = Field(None, description="Unique identifier for the candidate") + ra: Optional[float] = Field(None, description="Right ascension in degrees") + dec: Optional[float] = Field(None, description="Declination in degrees") + + model_config = ConfigDict(from_attributes=True) diff --git a/server/schemas/doi.py b/server/schemas/doi.py new file mode 100644 index 00000000..75734373 --- /dev/null +++ b/server/schemas/doi.py @@ -0,0 +1,106 @@ +# server/schemas/doi.py + +from pydantic import BaseModel, ConfigDict, Field +from typing import List, Dict, Any, Optional +from datetime import datetime + + +class DOIAuthorBase(BaseModel): + """Base schema for DOI 
author data.""" + + name: str + affiliation: str + orcid: Optional[str] = None + gnd: Optional[str] = None + pos_order: Optional[int] = None + + +class DOIAuthorCreate(DOIAuthorBase): + """Schema for creating a new DOI author.""" + + author_groupid: int + + +class DOIAuthorSchema(DOIAuthorBase): + """Schema for returning a DOI author.""" + + id: int + author_groupid: int + + model_config = ConfigDict(from_attributes=True) + + +class DOIAuthorGroupBase(BaseModel): + """Base schema for DOI author group data.""" + + name: str + userid: Optional[int] = None + + +class DOIAuthorGroupCreate(DOIAuthorGroupBase): + """Schema for creating a new DOI author group.""" + + pass + + +class DOIAuthorGroupSchema(DOIAuthorGroupBase): + """Schema for returning a DOI author group.""" + + id: int + + model_config = ConfigDict(from_attributes=True) + + +class DOICreator(BaseModel): + """Schema for a DOI creator.""" + + name: str + affiliation: str + orcid: Optional[str] = None + gnd: Optional[str] = None + + +class DOIPointingInfo(BaseModel): + """Schema for DOI pointing information.""" + + id: int + graceid: str + instrument_name: str + status: str + doi_url: Optional[str] = None + doi_id: Optional[int] = None + + +class DOIPointingsResponse(BaseModel): + """Schema for DOI pointings response.""" + + pointings: List[DOIPointingInfo] + + +class DOIRequestResponse(BaseModel): + """Schema for DOI request response.""" + + DOI_URL: Optional[str] = None + WARNINGS: List[Any] = [] + + +class DOIMetadata(BaseModel): + """Schema for DOI metadata.""" + + doi: str + creators: List[DOICreator] + titles: List[Dict[str, str]] + publisher: str + publicationYear: str + resourceType: Dict[str, str] + descriptions: List[Dict[str, str]] + relatedIdentifiers: Optional[List[Dict[str, str]]] = None + + +class DOICreate(BaseModel): + """Schema for creating a new DOI.""" + + points: List[int] + graceid: str + creators: List[DOICreator] + reference: Optional[str] = None diff --git a/server/schemas/glade.py b/server/schemas/glade.py new file mode 100644 index 00000000..adcb8ccc --- /dev/null +++ b/server/schemas/glade.py @@ -0,0 +1,21 @@ +from pydantic import BaseModel, ConfigDict, Field +from typing import Optional + + +class Glade2P3Schema(BaseModel): + pgc_number: int = Field(..., description="PGC number of the galaxy") + distance: float = Field(..., description="Distance of the galaxy") + position: Optional[dict] = Field( + None, + description="Geographical position as a dictionary with 'latitude' and 'longitude'", + ) + _2mass_name: Optional[str] = Field(None, description="2MASS name of the galaxy") + gwgc_name: Optional[str] = Field(None, description="GWGC name of the galaxy") + hyperleda_name: Optional[str] = Field( + None, description="HyperLEDA name of the galaxy" + ) + sdssdr12_name: Optional[str] = Field( + None, description="SDSS DR12 name of the galaxy" + ) + + model_config = ConfigDict(from_attributes=True) diff --git a/server/schemas/gw_alert.py b/server/schemas/gw_alert.py new file mode 100644 index 00000000..6bb07bea --- /dev/null +++ b/server/schemas/gw_alert.py @@ -0,0 +1,91 @@ +from pydantic import BaseModel, ConfigDict, Field, field_validator +from typing import Optional +from datetime import datetime + + +class GWAlertSchema(BaseModel): + id: Optional[int] = None + datecreated: Optional[datetime] = None + graceid: str = Field(..., description="Grace ID of the GW event") + alternateid: Optional[str] = None + role: str = Field(..., description="Role of the alert (observation, test, etc.)") + timesent: 
Optional[datetime] = None + time_of_signal: Optional[datetime] = None + packet_type: Optional[int] = None + alert_type: str = Field( + ..., description="Type of alert (Initial, Update, Retraction, etc.)" + ) + detectors: Optional[str] = None + description: Optional[str] = None + far: Optional[float] = None + skymap_fits_url: Optional[str] = None + distance: Optional[float] = None + distance_error: Optional[float] = None + prob_bns: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Probability of BNS" + ) + prob_nsbh: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Probability of NSBH" + ) + prob_gap: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Probability of mass gap" + ) + prob_bbh: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Probability of BBH" + ) + prob_terrestrial: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Probability of terrestrial" + ) + prob_hasns: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Probability has neutron star" + ) + prob_hasremenant: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Probability has remnant" + ) + group: Optional[str] = None + centralfreq: Optional[float] = None + duration: Optional[float] = None + avgra: Optional[float] = Field( + None, ge=0.0, le=360.0, description="Average right ascension" + ) + avgdec: Optional[float] = Field( + None, ge=-90.0, le=90.0, description="Average declination" + ) + observing_run: Optional[str] = None + pipeline: Optional[str] = None + search: Optional[str] = None + area_50: Optional[float] = Field(None, ge=0.0, description="50% confidence area") + area_90: Optional[float] = Field(None, ge=0.0, description="90% confidence area") + gcn_notice_id: Optional[int] = None + ivorn: Optional[str] = None + ext_coinc_observatory: Optional[str] = None + ext_coinc_search: Optional[str] = None + time_coincidence_far: Optional[float] = None + time_sky_position_coincidence_far: Optional[float] = None + time_difference: Optional[float] = None + + @field_validator("far") + @classmethod + def validate_far(cls, v): + """Validate False Alarm Rate is positive.""" + if v is not None and v < 0: + raise ValueError("False Alarm Rate must be positive") + return v + + @field_validator("distance") + @classmethod + def validate_distance(cls, v): + """Validate distance is positive.""" + if v is not None and v < 0: + raise ValueError("Distance must be positive") + return v + + @field_validator("distance_error") + @classmethod + def validate_distance_error(cls, v): + """Validate distance error is positive.""" + if v is not None and v < 0: + raise ValueError("Distance error must be positive") + return v + + model_config = ConfigDict(from_attributes=True) diff --git a/server/schemas/gw_galaxy.py b/server/schemas/gw_galaxy.py new file mode 100644 index 00000000..e3d95c33 --- /dev/null +++ b/server/schemas/gw_galaxy.py @@ -0,0 +1,160 @@ +from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator +from typing import Optional, List, Dict, Any, Union +from dateutil.parser import parse as date_parse +from server.core.enums.gwgalaxyscoretype import GwGalaxyScoreType + + +class GWGalaxySchema(BaseModel): + """Pydantic schema for GWGalaxy model.""" + + id: Optional[int] = None + graceid: str + galaxy_catalog: Optional[int] = None + galaxy_catalogid: Optional[int] = None + reference: Optional[str] = None + + model_config = ConfigDict(from_attributes=True) + + +class EventGalaxySchema(BaseModel): + """Pydantic schema for EventGalaxy 
model.""" + + id: Optional[int] = None + graceid: str + galaxy_catalog: Optional[int] = None + galaxy_catalogid: Optional[int] = None + + model_config = ConfigDict(from_attributes=True) + + +class GWGalaxyScoreSchema(BaseModel): + """Pydantic schema for GWGalaxyScore model.""" + + id: Optional[int] = None + gw_galaxyid: int + score_type: Optional[GwGalaxyScoreType] = None + score: Optional[float] = None + + model_config = ConfigDict(from_attributes=True) + + +class GWGalaxyListSchema(BaseModel): + """Pydantic schema for GWGalaxyList model.""" + + id: Optional[int] = None + graceid: str + groupname: str + submitterid: Optional[int] = None + reference: Optional[str] = None + alertid: Optional[str] = None + doi_url: Optional[str] = None + doi_id: Optional[int] = None + + model_config = ConfigDict(from_attributes=True) + + +class GWGalaxyEntrySchema(BaseModel): + """Pydantic schema for GWGalaxyEntry model.""" + + id: Optional[int] = None + listid: int + name: str + score: float + position: Optional[str] = None + rank: int + info: Optional[Dict[str, Any]] = None + + model_config = ConfigDict(from_attributes=True) + + +class GalaxyPosition(BaseModel): + """Schema for position data using RA and Dec.""" + + ra: float = Field(..., description="Right ascension in degrees") + dec: float = Field(..., description="Declination in degrees") + + +class GalaxyEntryCreate(BaseModel): + """Schema for creating a new GWGalaxyEntry.""" + + name: str = Field(..., description="Name of the galaxy") + score: float = Field(..., description="Score or probability value for this galaxy") + position: Optional[str] = Field( + None, description="WKT representation of position (e.g., 'POINT(10.5 -20.3)')" + ) + ra: Optional[float] = Field(None, description="Right ascension in degrees") + dec: Optional[float] = Field(None, description="Declination in degrees") + rank: int = Field(..., description="Rank of this galaxy in the list") + info: Optional[Dict[str, Any]] = Field( + None, description="Additional information about the galaxy" + ) + + @model_validator(mode="after") + def check_position_data(self) -> "GalaxyEntryCreate": + """Validate that either position or ra/dec are provided.""" + position = self.position + ra = self.ra + dec = self.dec + + # If position string is provided, validate it's in correct format + if position is not None: + if ( + not all(x in position for x in ["POINT", "(", ")", " "]) + or "," in position + ): + raise ValueError("Position must be in WKT format: 'POINT(lon lat)'") + return self + + # If no position string, ra and dec must both be provided + if ra is None or dec is None: + raise ValueError( + "Either position string or both ra and dec must be provided" + ) + + return self + + +class DOICreator(BaseModel): + """Schema for a DOI creator/author.""" + + name: str = Field(..., description="Author name") + affiliation: str = Field(..., description="Author affiliation") + orcid: Optional[str] = Field(None, description="ORCID identifier") + gnd: Optional[str] = Field(None, description="GND identifier") + + +class PostEventGalaxiesRequest(BaseModel): + """Schema for posting galaxy entries for a GW event.""" + + graceid: str = Field(..., description="Grace ID of the GW event") + timesent_stamp: str = Field(..., description="Timestamp of the event in ISO format") + groupname: Optional[str] = Field(None, description="Group name for the galaxy list") + reference: Optional[str] = Field(None, description="Reference for the galaxy list") + request_doi: Optional[bool] = Field(False, description="Whether to 
request a DOI") + creators: Optional[List[DOICreator]] = Field( + None, description="List of creators with name and affiliation" + ) + doi_group_id: Optional[Union[int, str]] = Field( + None, description="ID or name of the DOI group" + ) + galaxies: List[GalaxyEntryCreate] = Field(..., description="List of galaxy entries") + + @field_validator("timesent_stamp") + @classmethod + def validate_timestamp(cls, v: str) -> str: + """Validate that the timestamp is in a valid ISO format.""" + try: + date_parse(v) + return v + except Exception: + raise ValueError( + "Time format must be %Y-%m-%dT%H:%M:%S.%f, e.g. 2019-05-01T12:00:00.00" + ) + + +class PostEventGalaxiesResponse(BaseModel): + """Schema for the response when posting galaxy entries.""" + + message: str = Field(..., description="Success message") + errors: List[Any] = Field(..., description="List of errors encountered") + warnings: List[Any] = Field(..., description="List of warnings encountered") diff --git a/server/schemas/icecube.py b/server/schemas/icecube.py new file mode 100644 index 00000000..71b7281b --- /dev/null +++ b/server/schemas/icecube.py @@ -0,0 +1,153 @@ +from pydantic import BaseModel, ConfigDict, Field +from typing import Optional, List +from datetime import datetime + + +class IceCubeNoticeCreateSchema(BaseModel): + """Schema for creating a new IceCube notice.""" + + ref_id: str = Field(..., description="Reference ID for the IceCube notice") + graceid: str = Field(..., description="Grace ID of the associated GW event") + alert_datetime: Optional[datetime] = Field( + None, description="Date and time of the alert" + ) + observation_start: Optional[datetime] = Field( + None, description="Start time of the observation period" + ) + observation_stop: Optional[datetime] = Field( + None, description="End time of the observation period" + ) + pval_generic: Optional[float] = Field( + None, description="Generic p-value for the event" + ) + pval_bayesian: Optional[float] = Field( + None, description="Bayesian p-value for the event" + ) + most_probable_direction_ra: Optional[float] = Field( + None, description="Right ascension of most probable direction" + ) + most_probable_direction_dec: Optional[float] = Field( + None, description="Declination of most probable direction" + ) + flux_sens_low: Optional[float] = Field( + None, description="Lower bound of flux sensitivity" + ) + flux_sens_high: Optional[float] = Field( + None, description="Upper bound of flux sensitivity" + ) + sens_energy_range_low: Optional[float] = Field( + None, description="Lower bound of sensitivity energy range" + ) + sens_energy_range_high: Optional[float] = Field( + None, description="Upper bound of sensitivity energy range" + ) + + +class IceCubeNoticeCoincEventCreateSchema(BaseModel): + """Schema for creating a new IceCube notice coincident event.""" + + event_dt: Optional[float] = Field(None, description="Event time difference") + ra: Optional[float] = Field(None, description="Right ascension of the event") + dec: Optional[float] = Field(None, description="Declination of the event") + containment_probability: Optional[float] = Field( + None, description="Probability of event containment" + ) + event_pval_generic: Optional[float] = Field( + None, description="Generic p-value for this specific event" + ) + event_pval_bayesian: Optional[float] = Field( + None, description="Bayesian p-value for this specific event" + ) + ra_uncertainty: Optional[float] = Field( + None, description="Uncertainty in right ascension" + ) + uncertainty_shape: Optional[str] = 
Field( + None, description="Shape of the uncertainty region" + ) + + +class IceCubeNoticeRequestSchema(BaseModel): + """Schema for submitting an IceCube notice with associated events.""" + + notice_data: IceCubeNoticeCreateSchema = Field(..., description="Main notice data") + events_data: List[IceCubeNoticeCoincEventCreateSchema] = Field( + ..., description="List of coincident events" + ) + + +class IceCubeNoticeSchema(BaseModel): + """Schema for returning an IceCube notice.""" + + id: int = Field(..., description="Unique identifier for the notice") + ref_id: str = Field(..., description="Reference ID for the IceCube notice") + graceid: str = Field(..., description="Grace ID of the associated GW event") + alert_datetime: Optional[datetime] = Field( + None, description="Date and time of the alert" + ) + datecreated: Optional[datetime] = Field( + None, description="Date when the notice was created" + ) + observation_start: Optional[datetime] = Field( + None, description="Start time of the observation period" + ) + observation_stop: Optional[datetime] = Field( + None, description="End time of the observation period" + ) + pval_generic: Optional[float] = Field( + None, description="Generic p-value for the event" + ) + pval_bayesian: Optional[float] = Field( + None, description="Bayesian p-value for the event" + ) + most_probable_direction_ra: Optional[float] = Field( + None, description="Right ascension of most probable direction" + ) + most_probable_direction_dec: Optional[float] = Field( + None, description="Declination of most probable direction" + ) + flux_sens_low: Optional[float] = Field( + None, description="Lower bound of flux sensitivity" + ) + flux_sens_high: Optional[float] = Field( + None, description="Upper bound of flux sensitivity" + ) + sens_energy_range_low: Optional[float] = Field( + None, description="Lower bound of sensitivity energy range" + ) + sens_energy_range_high: Optional[float] = Field( + None, description="Upper bound of sensitivity energy range" + ) + + model_config = ConfigDict(from_attributes=True) + + +class IceCubeNoticeCoincEventSchema(BaseModel): + """Schema for returning an IceCube notice coincident event.""" + + id: int = Field(..., description="Unique identifier for the event") + icecube_notice_id: int = Field( + ..., description="ID of the associated IceCube notice" + ) + datecreated: Optional[datetime] = Field( + None, description="Date when the event was created" + ) + event_dt: Optional[float] = Field(None, description="Event time difference") + ra: Optional[float] = Field(None, description="Right ascension of the event") + dec: Optional[float] = Field(None, description="Declination of the event") + containment_probability: Optional[float] = Field( + None, description="Probability of event containment" + ) + event_pval_generic: Optional[float] = Field( + None, description="Generic p-value for this specific event" + ) + event_pval_bayesian: Optional[float] = Field( + None, description="Bayesian p-value for this specific event" + ) + ra_uncertainty: Optional[float] = Field( + None, description="Uncertainty in right ascension" + ) + uncertainty_shape: Optional[str] = Field( + None, description="Shape of the uncertainty region" + ) + + model_config = ConfigDict(from_attributes=True) diff --git a/server/schemas/instrument.py b/server/schemas/instrument.py new file mode 100644 index 00000000..1a20d1ae --- /dev/null +++ b/server/schemas/instrument.py @@ -0,0 +1,87 @@ +from pydantic import BaseModel, ConfigDict, Field, field_validator +from typing import 
Optional, List, Tuple +from datetime import datetime +from server.core.enums.instrumenttype import InstrumentType as instrument_type_enum + + +class InstrumentSchema(BaseModel): + """Schema for returning an instrument.""" + + id: int = Field(..., description="Unique identifier for the instrument") + instrument_name: str = Field(..., description="Name of the instrument") + nickname: Optional[str] = Field( + None, description="Nickname or short name for the instrument" + ) + instrument_type: instrument_type_enum = Field( + ..., description="Type of the instrument" + ) + datecreated: Optional[datetime] = Field( + None, description="Date when the instrument was created" + ) + submitterid: Optional[int] = Field( + None, description="ID of the user who submitted the instrument" + ) + + model_config = ConfigDict(from_attributes=True, use_enum_values=False) + + @field_validator("instrument_type", mode="before") + def serialize_enum(cls, value): + if isinstance(value, instrument_type_enum): + return value.value # Convert enum to its name (e.g., "photometric") + return value + + +class InstrumentCreate(BaseModel): + """Schema for creating a new instrument.""" + + instrument_name: str = Field(..., description="Name of the instrument") + nickname: Optional[str] = Field( + None, description="Nickname or short name for the instrument" + ) + instrument_type: instrument_type_enum = Field( + ..., description="Type of the instrument" + ) + + +class InstrumentUpdate(BaseModel): + """Schema for updating an instrument.""" + + instrument_name: Optional[str] = Field( + None, description="Updated name of the instrument" + ) + nickname: Optional[str] = Field( + None, description="Updated nickname or short name for the instrument" + ) + instrument_type: Optional[instrument_type_enum] = Field( + None, description="Updated type of the instrument" + ) + + +class FootprintCCDSchema(BaseModel): + """Schema for returning a footprint CCD.""" + + id: int = Field(..., description="Unique identifier for the footprint") + instrumentid: int = Field(..., description="ID of the associated instrument") + footprint: Optional[str] = Field( + None, description="WKT representation of the footprint" + ) + + model_config = ConfigDict(from_attributes=True) + + +class FootprintCCDCreate(BaseModel): + """Schema for creating a new footprint CCD.""" + + instrumentid: int = Field(..., description="ID of the associated instrument") + footprint: str = Field(..., description="WKT representation of the footprint") + + @field_validator("footprint") + def validate_footprint(cls, value): + if value: + try: + from shapely.wkt import loads + + loads(value) # Attempt to parse the WKT + except Exception as e: + raise ValueError(f"Invalid WKT format: {e}") + return value diff --git a/server/schemas/pointing.py b/server/schemas/pointing.py new file mode 100644 index 00000000..ef381e34 --- /dev/null +++ b/server/schemas/pointing.py @@ -0,0 +1,346 @@ +from pydantic import ( + BaseModel, + ConfigDict, + Field, + model_validator, + field_validator, + field_serializer, +) +from typing import List, Dict, Any, Optional, Union +from datetime import datetime + +from server.core.enums import Bandpass as bandpass_enum +from server.core.enums import DepthUnit as depth_unit_enum +from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum + + +class PointingBase(BaseModel): + """Base schema for pointing data.""" + + position: Optional[Union[str, Any]] = None # Can be WKBElement or string + ra: Optional[float] = None + dec: Optional[float] = 
None + instrumentid: Optional[int] = None + depth: Optional[float] = None + depth_err: Optional[float] = None + depth_unit: Optional[Union[depth_unit_enum, str, int]] = None + band: Optional[Union[bandpass_enum, str, int]] = None + pos_angle: Optional[float] = None + time: Optional[datetime] = None + status: Optional[Union[pointing_status_enum, str, int]] = "completed" + central_wave: Optional[float] = None + bandwidth: Optional[float] = None + + # Serializers to convert enum values to string names + @field_serializer("status") + def serialize_status(self, status): + """Convert enum to string for JSON response.""" + if isinstance(status, pointing_status_enum): + return status.name + return status + + @field_serializer("depth_unit") + def serialize_depth_unit(self, depth_unit): + """Convert enum to string for JSON response.""" + if isinstance(depth_unit, depth_unit_enum): + return depth_unit.name + return depth_unit + + @field_serializer("band") + def serialize_band(self, band): + """Convert enum to string for JSON response.""" + if isinstance(band, bandpass_enum): + return band.name + return band + + @field_serializer("position") + def serialize_position(self, position): + """Convert WKBElement to string for JSON response.""" + if position and hasattr(position, "data"): + import shapely.wkb + + try: + geom = shapely.wkb.loads(bytes(position.data)) + return str(geom) + except Exception: + pass + return position + + # Validators to convert string values to enum values + @field_validator("status", mode="before") + @classmethod + def validate_status(cls, value): + if isinstance(value, str): + try: + return pointing_status_enum[value] + except KeyError: + # Return the string value if it's not a valid enum + return value + return value + + @field_validator("depth_unit", mode="before") + @classmethod + def validate_depth_unit(cls, value): + if isinstance(value, str): + try: + return depth_unit_enum[value] + except KeyError: + # Return the string value if it's not a valid enum + return value + elif isinstance(value, int): + try: + return depth_unit_enum(value) + except ValueError: + # Return the int value if it's not a valid enum + return value + return value + + @field_validator("band", mode="before") + @classmethod + def validate_band(cls, value): + if isinstance(value, str): + try: + return bandpass_enum[value] + except KeyError: + # Return the string value if it's not a valid enum + return value + elif isinstance(value, int): + try: + return bandpass_enum(value) + except ValueError: + # Return the int value if it's not a valid enum + return value + return value + + model_config = ConfigDict( + from_attributes=True, + arbitrary_types_allowed=True, + json_encoders={ + pointing_status_enum: lambda v: v.name if v else None, + depth_unit_enum: lambda v: v.name if v else None, + bandpass_enum: lambda v: v.name if v else None, + datetime: lambda v: v.isoformat() if v else None, + }, + ) + + +class PointingResponse(BaseModel): + """Schema for pointing creation response.""" + + pointing_ids: List[int] + ERRORS: List[Any] = [] + WARNINGS: List[Any] = [] + DOI: Optional[str] = None + + model_config = ConfigDict(from_attributes=True) + + +class PointingSchema(PointingBase): + """Schema for returning a pointing.""" + + id: Optional[int] = None + submitterid: Optional[int] = None + datecreated: Optional[datetime] = None + dateupdated: Optional[datetime] = None + doi_url: Optional[str] = None + doi_id: Optional[int] = None + + model_config = ConfigDict( + from_attributes=True, + arbitrary_types_allowed=True, 
+ json_encoders={ + pointing_status_enum: lambda v: v.name if v else None, + depth_unit_enum: lambda v: v.name if v else None, + bandpass_enum: lambda v: v.name if v else None, + datetime: lambda v: v.isoformat() if v else None, + }, + ) + + +class PointingCreate(PointingBase): + """Schema for creating a new pointing with comprehensive validation.""" + + id: Optional[int] = Field(None, description="ID for updating an existing pointing") + + @model_validator(mode="after") + def validate_pointing_data(self): + """Comprehensive validation for pointing creation or update.""" + errors = [] + is_update = self.id is not None + + # For updates, we're more lenient with validation + if not is_update: + # Full validation for new pointings + + # Validate required fields based on status + if self.status == pointing_status_enum.completed: + if not self.depth: + errors.append("depth is required for completed observations") + if not self.depth_unit: + errors.append("depth_unit is required for completed observations") + if not self.band: + errors.append("band is required for completed observations") + if not self.time: + errors.append("time is required for completed observations") + + elif self.status == pointing_status_enum.planned: + if not self.time: + errors.append("time is required for planned observations") + + # Validate position (ra/dec or position string) - required for new pointings + if not self.position and not (self.ra is not None and self.dec is not None): + errors.append( + "Position information required (either position string or ra/dec coordinates)" + ) + + else: + # For updates, only validate fields if they're being changed to completed status + if self.status == pointing_status_enum.completed: + # Only require these fields if they're not already set in the database + # The service layer will handle checking existing values + pass + + # Validate position format if provided as string (for both new and updates) + if self.position and not ( + self.position + and all(x in self.position for x in ["POINT", "(", ")", " "]) + and "," not in self.position + ): + errors.append( + 'Invalid position argument. Must be decimal format ra/RA, dec/DEC, or geometry type "POINT(RA DEC)"' + ) + + # Convert ra/dec to position if provided (for both new and updates) + if self.ra is not None and self.dec is not None: + if not isinstance(self.ra, (int, float)) or not isinstance( + self.dec, (int, float) + ): + errors.append( + "Invalid position argument. Must be decimal format ra/RA, dec/DEC" + ) + else: + self.position = f"POINT({self.ra} {self.dec})" + + # Validate numeric fields (for both new and updates) + if self.depth is not None and not isinstance(self.depth, (int, float)): + errors.append("Invalid depth. Must be decimal") + + if self.depth_err is not None and not isinstance(self.depth_err, (int, float)): + errors.append("Invalid depth_err. Must be decimal") + + if self.pos_angle is not None and not isinstance(self.pos_angle, (int, float)): + errors.append("Invalid pos_angle. 
Must be decimal") + + if errors: + raise ValueError("; ".join(errors)) + + return self + + +class PointingCreateRequest(BaseModel): + """Schema for the complete pointing creation request.""" + + graceid: str = Field(..., description="Grace ID of the GW event") + pointing: Optional[PointingCreate] = Field( + None, description="Single pointing object" + ) + pointings: Optional[List[PointingCreate]] = Field( + None, description="List of pointing objects" + ) + request_doi: Optional[bool] = Field(False, description="Whether to request a DOI") + creators: Optional[List[Dict[str, str]]] = Field( + None, description="List of creators for the DOI" + ) + doi_group_id: Optional[int] = Field(None, description="DOI author group ID") + doi_url: Optional[str] = Field( + None, description="Optional DOI URL if already exists" + ) + + @model_validator(mode="after") + def validate_request(self): + """Validate the request has either pointing or pointings.""" + if not self.pointing and not self.pointings: + raise ValueError("Either pointing or pointings must be provided") + + if self.pointing and self.pointings: + raise ValueError("Cannot provide both pointing and pointings") + + # Validate DOI creators if request_doi is True + if self.request_doi and self.creators: + for creator in self.creators: + if "name" not in creator or "affiliation" not in creator: + raise ValueError( + "name and affiliation are required for each creator in the list" + ) + + return self + + +class PointingUpdate(BaseModel): + """Schema for updating a pointing.""" + + status: Union[pointing_status_enum, str] = Field( + ..., description="New status for the pointings" + ) + ids: List[int] = Field(..., description="List of pointing IDs to update") + + @field_validator("status", mode="before") + @classmethod + def validate_status(cls, value): + if isinstance(value, str): + try: + return pointing_status_enum[value] + except KeyError: + raise ValueError( + f"Invalid status: {value}. 
Valid values are: {[s.name for s in pointing_status_enum]}" + ) + return value + + @model_validator(mode="after") + def validate_update(self): + """Validate update request.""" + if not self.ids: + raise ValueError("At least one pointing ID must be provided") + + # Currently only support cancelling + if self.status != pointing_status_enum.cancelled: + raise ValueError("Only 'cancelled' status updates are currently supported") + + return self + + +class CancelAllRequest(BaseModel): + """Schema for cancelling all pointings.""" + + graceid: str = Field(..., description="Grace ID of the GW event") + instrumentid: int = Field(..., description="Instrument ID to cancel pointings for") + + +class DOIRequest(BaseModel): + """Schema for requesting a DOI.""" + + graceid: Optional[str] = Field(None, description="Grace ID of the GW event") + id: Optional[int] = Field(None, description="Pointing ID") + ids: Optional[List[int]] = Field(None, description="List of pointing IDs") + doi_group_id: Optional[str] = Field(None, description="DOI author group ID") + creators: Optional[List[Dict[str, str]]] = Field( + None, description="List of creators for the DOI" + ) + doi_url: Optional[str] = Field( + None, description="Optional DOI URL if already exists" + ) + + @model_validator(mode="after") + def validate_doi_request(self): + """Validate DOI request parameters.""" + if not self.graceid and not self.id and not self.ids: + raise ValueError("Please provide either graceid, id, or ids parameter") + + if self.creators: + for creator in self.creators: + if "name" not in creator or "affiliation" not in creator: + raise ValueError( + "name and affiliation are required for each creator in the list" + ) + + return self diff --git a/server/schemas/users.py b/server/schemas/users.py new file mode 100644 index 00000000..9b91f24d --- /dev/null +++ b/server/schemas/users.py @@ -0,0 +1,61 @@ +from pydantic import BaseModel, ConfigDict, Field +from typing import Optional, Dict, Any +from datetime import datetime + + +class UserSchema(BaseModel): + """Schema for returning a user.""" + + id: int = Field(..., description="Unique identifier for the user") + username: str = Field(..., description="Username of the user") + firstname: Optional[str] = Field(None, description="First name of the user") + lastname: Optional[str] = Field(None, description="Last name of the user") + email: Optional[str] = Field(None, description="Email address of the user") + datecreated: Optional[datetime] = Field( + None, description="Date when the user account was created" + ) + + model_config = ConfigDict(from_attributes=True) + + +class UserGroupSchema(BaseModel): + """Schema for returning a user group association.""" + + id: int = Field(..., description="Unique identifier for the user group association") + userid: int = Field(..., description="ID of the user") + groupid: int = Field(..., description="ID of the group") + role: Optional[str] = Field(None, description="Role of the user within the group") + + model_config = ConfigDict(from_attributes=True) + + +class GroupSchema(BaseModel): + """Schema for returning a group.""" + + id: int = Field(..., description="Unique identifier for the group") + name: str = Field(..., description="Name of the group") + datecreated: Optional[datetime] = Field( + None, description="Date when the group was created" + ) + + model_config = ConfigDict(from_attributes=True) + + +class UserActionSchema(BaseModel): + """Schema for returning a user action log entry.""" + + id: int = Field(..., description="Unique identifier 
for the user action") + userid: int = Field(..., description="ID of the user who performed the action") + ipaddress: Optional[str] = Field( + None, description="IP address from which the action was performed" + ) + url: Optional[str] = Field(None, description="URL that was accessed") + time: Optional[datetime] = Field( + None, description="Time when the action was performed" + ) + jsonvals: Optional[Dict[str, Any]] = Field( + None, description="Additional JSON data for the action" + ) + method: Optional[str] = Field(None, description="HTTP method used for the action") + + model_config = ConfigDict(from_attributes=True) diff --git a/server/services/__init__.py b/server/services/__init__.py new file mode 100644 index 00000000..a70b3029 --- /dev/null +++ b/server/services/__init__.py @@ -0,0 +1 @@ +# Services package diff --git a/server/services/pointing_service.py b/server/services/pointing_service.py new file mode 100644 index 00000000..eb0531cb --- /dev/null +++ b/server/services/pointing_service.py @@ -0,0 +1,203 @@ +""" +Pointing business logic services. +Contains logic that requires database access and can't be moved to Pydantic validators. +""" + +from sqlalchemy.orm import Session +from typing import List, Optional, Tuple, TYPE_CHECKING +from datetime import datetime + +if TYPE_CHECKING: + from server.schemas.pointing import PointingCreate + +from server.db.models.pointing import Pointing +from server.db.models.instrument import Instrument +from server.db.models.pointing_event import PointingEvent +from server.db.models.gw_alert import GWAlert +from server.db.models.users import Users +from server.db.models.doi_author import DOIAuthor +from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum +from server.utils.error_handling import validation_exception, not_found_exception +from server.utils.function import pointing_crossmatch, create_pointing_doi + + +class PointingService: + """Service class for pointing business logic.""" + + @staticmethod + def validate_graceid(graceid: str, db: Session) -> str: + """Validate that a graceid exists in the database.""" + valid_alerts = db.query(GWAlert).filter(GWAlert.graceid == graceid).all() + if len(valid_alerts) == 0: + raise validation_exception( + message="Invalid graceid", + errors=[f"The graceid '{graceid}' does not exist in the database"], + ) + return graceid + + @staticmethod + def validate_instrument(instrument_id: int, db: Session) -> Instrument: + """Validate that an instrument exists.""" + instrument = db.query(Instrument).filter(Instrument.id == instrument_id).first() + if not instrument: + raise not_found_exception(f"Instrument with ID {instrument_id} not found") + return instrument + + @staticmethod + def get_instruments_dict(db: Session) -> dict: + """Get a dictionary of instruments for validation.""" + dbinsts = db.query(Instrument.instrument_name, Instrument.id).all() + return {inst.id: inst.instrument_name for inst in dbinsts} + + @staticmethod + def validate_instrument_reference( + pointing_data: "PointingCreate", instruments_dict: dict + ) -> int: + """Validate and resolve instrument reference to ID.""" + if pointing_data.instrumentid is None: + raise validation_exception( + message="Field instrumentid is required", + errors=["instrumentid must be provided"], + ) + + inst = pointing_data.instrumentid + if isinstance(inst, int): + if inst not in instruments_dict: + raise validation_exception( + message="Invalid instrumentid", + errors=[f"Instrument with ID {inst} not found"], + ) + return inst + 
else: + # Handle string instrument names + for inst_id, inst_name in instruments_dict.items(): + if inst_name == inst: + return inst_id + raise validation_exception( + message="Invalid instrumentid", + errors=[f"Instrument '{inst}' not found"], + ) + + @staticmethod + def check_duplicate_pointing( + pointing: Pointing, existing_pointings: List[Pointing] + ) -> bool: + """Check if a pointing is a duplicate of existing pointings.""" + return pointing_crossmatch(pointing, existing_pointings) + + @staticmethod + def create_pointing_from_schema( + pointing_data: "PointingCreate", user_id: int, instrument_id: int + ) -> Pointing: + """Create a Pointing model instance from schema data.""" + pointing = Pointing() + + # Set basic fields + pointing.position = pointing_data.position + pointing.depth = pointing_data.depth + pointing.depth_err = pointing_data.depth_err + pointing.depth_unit = pointing_data.depth_unit + pointing.status = pointing_data.status or pointing_status_enum.completed + pointing.band = pointing_data.band + pointing.central_wave = pointing_data.central_wave + pointing.bandwidth = pointing_data.bandwidth + pointing.instrumentid = instrument_id + pointing.pos_angle = pointing_data.pos_angle + pointing.time = pointing_data.time + pointing.submitterid = user_id + pointing.datecreated = datetime.now() + + return pointing + + @staticmethod + def handle_planned_pointing_update( + pointing_data, user_id: int, db: Session + ) -> Optional[Pointing]: + """Handle updating a planned pointing to completed.""" + if not hasattr(pointing_data, "id") or pointing_data.id is None: + return None + + pointing_id = int(pointing_data.id) + + # Find the planned pointing + planned_pointing = ( + db.query(Pointing) + .filter(Pointing.id == pointing_id, Pointing.submitterid == user_id) + .first() + ) + + if not planned_pointing: + raise validation_exception( + message="Pointing validation error", + errors=[ + f"Pointing with ID {pointing_id} not found or not owned by you" + ], + ) + + if planned_pointing.status in [ + pointing_status_enum.completed, + pointing_status_enum.cancelled, + ]: + raise validation_exception( + message="Pointing validation error", + errors=[ + f"This pointing has already been {planned_pointing.status.name}" + ], + ) + + # Update planned pointing with new data + if pointing_data.time: + planned_pointing.time = pointing_data.time + if pointing_data.pos_angle is not None: + planned_pointing.pos_angle = pointing_data.pos_angle + + planned_pointing.status = pointing_status_enum.completed + planned_pointing.dateupdated = datetime.now() + + return planned_pointing + + @staticmethod + def prepare_doi_creators( + creators: Optional[List[dict]], + doi_group_id: Optional[int], + user: Users, + db: Session, + ) -> List[dict]: + """Prepare DOI creators list.""" + if creators: + return creators + elif doi_group_id: + valid, creators_list = DOIAuthor.construct_creators( + doi_group_id, user.id, db + ) + if not valid: + raise validation_exception( + message="Invalid DOI group ID", + errors=["Make sure you are the User associated with the DOI group"], + ) + return creators_list + else: + return [{"name": f"{user.firstname} {user.lastname}", "affiliation": ""}] + + @staticmethod + def create_doi_for_pointings( + pointings: List[Pointing], graceid: str, creators: List[dict], db: Session + ) -> Tuple[Optional[int], Optional[str]]: + """Create a DOI for a list of pointings.""" + if not pointings: + return None, None + + # Get instruments for the pointings + insts = ( + db.query(Instrument) + 
.filter(Instrument.id.in_([p.instrumentid for p in pointings])) + .all() + ) + inst_set = list(set([i.instrument_name for i in insts])) + + # Get normalized graceid + normalized_gid = GWAlert.alternatefromgraceid(graceid) + + # Create the DOI + result = create_pointing_doi(pointings, normalized_gid, creators, inst_set) + return result diff --git a/server/utils/__init__.py b/server/utils/__init__.py new file mode 100644 index 00000000..84095a64 --- /dev/null +++ b/server/utils/__init__.py @@ -0,0 +1 @@ +# Utils package initialization diff --git a/server/utils/email.py b/server/utils/email.py new file mode 100644 index 00000000..e1473262 --- /dev/null +++ b/server/utils/email.py @@ -0,0 +1,224 @@ +import secrets +import jwt +from datetime import datetime, timedelta +from typing import Optional +from fastapi import HTTPException +from sqlalchemy.orm import Session +from pathlib import Path +import smtplib +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart + +from server.config import settings + +# Email configuration from central settings +EMAIL_TOKEN_EXPIRE_HOURS = 24 +EMAIL_SECRET_KEY = settings.JWT_SECRET_KEY +EMAIL_ALGORITHM = settings.JWT_ALGORITHM +SMTP_SERVER = settings.MAIL_SERVER +SMTP_PORT = settings.MAIL_PORT +SMTP_USERNAME = settings.MAIL_USERNAME +SMTP_PASSWORD = settings.MAIL_PASSWORD +SENDER_EMAIL = settings.MAIL_DEFAULT_SENDER +BASE_URL = "http://localhost:8000" # This should be configured in settings as well + + +def create_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """ + Create a JWT token with optional expiration + """ + to_encode = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(hours=EMAIL_TOKEN_EXPIRE_HOURS) + + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, EMAIL_SECRET_KEY, algorithm=EMAIL_ALGORITHM) + + return encoded_jwt + + +def verify_token(token: str) -> dict: + """ + Verify a JWT token and return the payload + """ + try: + payload = jwt.decode(token, EMAIL_SECRET_KEY, algorithms=[EMAIL_ALGORITHM]) + return payload + except jwt.PyJWTError: + raise HTTPException(status_code=400, detail="Invalid token") + + +def send_account_validation_email(user, db: Session) -> None: + """ + Send an account validation email to a user + """ + # Generate token + token = create_token( + data={"sub": user.email, "id": user.id}, + expires_delta=timedelta(hours=EMAIL_TOKEN_EXPIRE_HOURS), + ) + + # Create verification URL + verification_url = f"{BASE_URL}/verify-account?token={token}" + + # Email subject and content + subject = "Verify your GWTM account" + + # Email content - HTML + html_content = f""" + + + + + +
+            <h1>Welcome to GWTM!</h1>
+            <p>
+                Thank you for registering with the Gravitational-Wave Treasure Map.
+            </p>
+            <p>
+                Please click the button below to verify your account:
+            </p>
+            <p>
+                <a href="{verification_url}">Verify Account</a>
+            </p>
+            <p>This link will expire in {EMAIL_TOKEN_EXPIRE_HOURS} hours.</p>
+            <p>
+                If you did not register for a GWTM account, please ignore this email.
+            </p>
+ + + """ + + # Email content - Plain text + text_content = f""" + Welcome to GWTM! + + Thank you for registering with the Gravitational-Wave Treasure Map. + + Please click the link below to verify your account: + {verification_url} + + This link will expire in {EMAIL_TOKEN_EXPIRE_HOURS} hours. + + If you did not register for a GWTM account, please ignore this email. + """ + + # In a development environment, we might just log the verification URL + print(f"Verification URL for {user.email}: {verification_url}") + + # In a production environment, we would send the actual email + try: + # Create message + message = MIMEMultipart("alternative") + message["Subject"] = subject + message["From"] = SENDER_EMAIL + message["To"] = user.email + + # Attach parts + part1 = MIMEText(text_content, "plain") + part2 = MIMEText(html_content, "html") + message.attach(part1) + message.attach(part2) + + # Connect to server and send + # In a real environment, use a proper production setup + # This is commented out to avoid actual email sending + """ + with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as server: + server.starttls() + server.login(SMTP_USERNAME, SMTP_PASSWORD) + server.sendmail(SENDER_EMAIL, user.email, message.as_string()) + """ + + # For development, we could save the email to a file + # email_path = Path(f"./emails/{user.id}_{datetime.now().strftime('%Y%m%d%H%M%S')}.html") + # email_path.parent.mkdir(exist_ok=True) + # email_path.write_text(html_content) + + # Update the user's token + user.token = token + db.commit() + + return True + except Exception as e: + print(f"Error sending email: {e}") + return False + + +def send_password_reset_email(user, db: Session) -> None: + """ + Send a password reset email to a user + """ + # Generate token + token = create_token( + data={"sub": user.email, "id": user.id, "type": "password_reset"}, + expires_delta=timedelta(hours=1), # Shorter expiry for password resets + ) + + # Create reset URL + reset_url = f"{BASE_URL}/reset-password?token={token}" + + # Email subject and content + subject = "Reset your GWTM password" + + # Email content - HTML + html_content = f""" + + + + + +
+            <h1>Password Reset Request</h1>
+            <p>
+                We received a request to reset your GWTM password.
+            </p>
+            <p>
+                Please click the button below to reset your password:
+            </p>
+            <p>
+                <a href="{reset_url}">Reset Password</a>
+            </p>
+            <p>This link will expire in 1 hour.</p>
+            <p>
+                If you did not request a password reset, please ignore this email.
+            </p>
+ + + """ + + # Email content - Plain text + text_content = f""" + Password Reset Request + + We received a request to reset your GWTM password. + + Please click the link below to reset your password: + {reset_url} + + This link will expire in 1 hour. + + If you did not request a password reset, please ignore this email. + """ + + # In a development environment, we might just log the reset URL + print(f"Password reset URL for {user.email}: {reset_url}") + + # Store the token in the user record + user.reset_token = token + user.reset_token_expires = datetime.utcnow() + timedelta(hours=1) + db.commit() diff --git a/server/utils/error_handling.py b/server/utils/error_handling.py new file mode 100644 index 00000000..0c8329eb --- /dev/null +++ b/server/utils/error_handling.py @@ -0,0 +1,68 @@ +from fastapi import HTTPException, status +from typing import List, Dict, Any, Optional, Union + + +class ErrorDetail: + """Standardized error detail structure""" + + def __init__(self, message: str, code: str = None, params: Dict[str, Any] = None): + self.message = message + self.code = code + self.params = params or {} + + def to_dict(self) -> Dict[str, Any]: + result = {"message": self.message} + if self.code: + result["code"] = self.code + if self.params: + result["params"] = self.params + return result + + +def validation_exception( + message: str = "Validation error", + errors: List[Union[str, Dict, ErrorDetail]] = None, +) -> HTTPException: + """Create a standardized validation error exception""" + detail = {"message": message} + + if errors: + formatted_errors = [] + for error in errors: + if isinstance(error, str): + formatted_errors.append({"message": error}) + elif isinstance(error, ErrorDetail): + formatted_errors.append(error.to_dict()) + elif isinstance(error, dict): + formatted_errors.append(error) + detail["errors"] = formatted_errors + + return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) + + +def permission_exception(message: str = "Permission denied") -> HTTPException: + """Create a standardized permission error exception""" + return HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail={"message": message} + ) + + +def not_found_exception(message: str = "Resource not found") -> HTTPException: + """Create a standardized not found exception""" + return HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail={"message": message} + ) + + +def server_exception(message: str = "Internal server error") -> HTTPException: + """Create a standardized server error exception""" + return HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"message": message} + ) + + +def conflict_exception(message: str = "Resource conflict") -> HTTPException: + """Create a standardized conflict exception""" + return HTTPException( + status_code=status.HTTP_409_CONFLICT, detail={"message": message} + ) diff --git a/server/utils/function.py b/server/utils/function.py new file mode 100644 index 00000000..71a116c4 --- /dev/null +++ b/server/utils/function.py @@ -0,0 +1,594 @@ +import io +import json + +import re + +import ephem +import geoalchemy2 +import numpy as np +import math +from typing import List, Dict, Any, Tuple, Optional + +import requests + +from server import config +from server.core.enums.pointingstatus import PointingStatus +from server.db.models.pointing import Pointing +from server.schemas.pointing import PointingSchema + + +def isInt(s) -> bool: + """Check if a value can be converted to an integer.""" + try: + int(s) + return True + except 
(ValueError, TypeError): + return False + + +def isFloat(s) -> bool: + """Check if a value can be converted to a float.""" + try: + float(s) + return True + except (ValueError, TypeError): + return False + + +def get_farrate_farunit(far: float) -> Tuple[float, str]: + """ + Convert FAR (False Alarm Rate) to human readable format. + + Args: + far: False Alarm Rate in Hz + + Returns: + Tuple of (rate, unit) where unit is a time unit (years, days, hours, etc.) + """ + far_rate_dict = { + "second": 1 / far, + "minute": 1 / far / 60, + "hour": 1 / far / 3600, + "day": 1 / far / 3600 / 24, + "month": 1 / far / 3600 / 24 / 30, + "year": 1 / far / 3600 / 24 / 365, + "decade": 1 / far / 3600 / 24 / 365 / 10, + "century": 1 / far / 3600 / 24 / 365 / 100, + "millennium": 1 / far / 3600 / 24 / 365 / 1000, + } + + if far_rate_dict["second"] < 60: + return far_rate_dict["second"], "seconds" + if far_rate_dict["minute"] < 60: + return far_rate_dict["minute"], "minutes" + if far_rate_dict["hour"] < 24: + return far_rate_dict["hour"], "hours" + if far_rate_dict["day"] < 30: + return far_rate_dict["day"], "days" + if far_rate_dict["month"] < 12: + return far_rate_dict["month"], "months" + if far_rate_dict["year"] < 10: + return far_rate_dict["year"], "years" + if far_rate_dict["decade"] < 10: + return far_rate_dict["decade"], "decades" + if far_rate_dict["century"] < 10: + return far_rate_dict["century"], "centuries" + + return far_rate_dict["millennium"], "millennia" + + +def sanatize_pointing(position: str) -> Tuple[float, float]: + """ + Extract RA and Dec from a position string. + + Args: + position: String representation of a point, e.g., "POINT(123.456 -45.678)" + + Returns: + Tuple of (ra, dec) as floats + """ + try: + coords = position.split("POINT(")[1].split(")")[0].split(" ") + ra = float(coords[0]) + dec = float(coords[1]) + return ra, dec + except (IndexError, ValueError): + return 0.0, 0.0 + + +def sanatize_footprint_ccds( + footprint_ccds: List[str], +) -> List[List[Tuple[float, float]]]: + """ + Convert footprint strings to coordinate lists. + + Args: + footprint_ccds: List of footprint strings (WKT format) + + Returns: + List of coordinate lists, where each coordinate is an (ra, dec) tuple + """ + result = [] + for footprint in footprint_ccds: + try: + # Extract coordinates from POLYGON string + coords_str = re.search(r"POLYGON\(\((.*)\)\)", footprint).group(1) + coords_pairs = coords_str.split(",") + coords = [] + for pair in coords_pairs: + x, y = map(float, pair.strip().split(" ")) + coords.append((x, y)) + result.append(coords) + except (AttributeError, ValueError): + continue + return result + + +def project_footprint( + footprint: List[Tuple[float, float]], + ra: float, + dec: float, + pos_angle: Optional[float] = None, +) -> List[Tuple[float, float]]: + """ + Project a footprint to a new position with optional rotation. 
+ + Args: + footprint: List of (ra, dec) tuples defining the footprint + ra: Right ascension of the center + dec: Declination of the center + pos_angle: Position angle for rotation (degrees) + + Returns: + Projected footprint as a list of (ra, dec) tuples + """ + # Calculate center of the footprint + ra_vals = [p[0] for p in footprint] + dec_vals = [p[1] for p in footprint] + center_ra = np.mean(ra_vals) + center_dec = np.mean(dec_vals) + + # Calculate offsets from center + ra_offsets = [p[0] - center_ra for p in footprint] + dec_offsets = [p[1] - center_dec for p in footprint] + + # Apply rotation if needed + if pos_angle is not None and pos_angle != 0: + angle_rad = math.radians(pos_angle) + rotated_offsets = [] + for i in range(len(ra_offsets)): + new_ra_offset = ra_offsets[i] * math.cos(angle_rad) - dec_offsets[ + i + ] * math.sin(angle_rad) + new_dec_offset = ra_offsets[i] * math.sin(angle_rad) + dec_offsets[ + i + ] * math.cos(angle_rad) + rotated_offsets.append((new_ra_offset, new_dec_offset)) + ra_offsets = [offset[0] for offset in rotated_offsets] + dec_offsets = [offset[1] for offset in rotated_offsets] + + # Apply offsets to new center + projected_footprint = [] + for i in range(len(footprint)): + projected_ra = ra + ra_offsets[i] / math.cos(math.radians(dec)) + projected_dec = dec + dec_offsets[i] + projected_footprint.append((projected_ra, projected_dec)) + + # Make sure the polygon is closed + if projected_footprint[0] != projected_footprint[-1]: + projected_footprint.append(projected_footprint[0]) + + return projected_footprint + + +def polygons2footprints( + polygons: List[List[List[float]]], time: float = 0 +) -> List[Dict[str, Any]]: + """ + Convert list of polygon coordinates to footprint objects with time. + + Args: + polygons: List of polygon coordinate lists + time: Time value to associate with the footprints + + Returns: + List of footprint objects with 'polygon' and 'time' keys + """ + footprints = [] + for poly in polygons: + # Convert from [lon, lat] to [ra, dec] + footprint = [(coord[0], coord[1]) for coord in poly] + footprints.append({"polygon": footprint, "time": time}) + return footprints + + +def sanatize_gal_info(galaxy_entry, galaxy_list) -> Dict[str, Any]: + """ + Format galaxy information for display. + + Args: + galaxy_entry: A galaxy entry object + galaxy_list: The parent galaxy list object + + Returns: + Formatted galaxy information as a dictionary + """ + info_dict = {} + + if hasattr(galaxy_entry, "info") and galaxy_entry.info: + # Check if info is already a dict (SQLAlchemy JSON column) or a string + if isinstance(galaxy_entry.info, dict): + info_dict = galaxy_entry.info.copy() + elif isinstance(galaxy_entry.info, str): + try: + info_dict = json.loads(galaxy_entry.info) + except json.JSONDecodeError: + info_dict = {} + else: + # Handle other types by converting to empty dict + info_dict = {} + + # Add galaxy list information + info_dict["Group"] = ( + galaxy_list.groupname if hasattr(galaxy_list, "groupname") else "" + ) + info_dict["Score"] = galaxy_entry.score if hasattr(galaxy_entry, "score") else "" + info_dict["Rank"] = galaxy_entry.rank if hasattr(galaxy_entry, "rank") else "" + + return info_dict + + +def sanatize_icecube_event(event, notice) -> Dict[str, Any]: + """ + Format IceCube event information for display. 
+ + Args: + event: An IceCube event object + notice: The parent IceCube notice object + + Returns: + Formatted IceCube event information as a dictionary + """ + info_dict = { + "Event ID": event.id if hasattr(event, "id") else "", + "Notice ID": notice.id if hasattr(notice, "id") else "", + "RA": event.ra if hasattr(event, "ra") else "", + "Dec": event.dec if hasattr(event, "dec") else "", + "Uncertainty": event.ra_uncertainty if hasattr(event, "ra_uncertainty") else "", + "Event p-value (generic)": ( + event.event_pval_generic if hasattr(event, "event_pval_generic") else "" + ), + "Event p-value (Bayesian)": ( + event.event_pval_bayesian if hasattr(event, "event_pval_bayesian") else "" + ), + "Event DT": event.event_dt if hasattr(event, "event_dt") else "", + } + + return info_dict + + +def sanatize_candidate_info(candidate, ra, dec) -> Dict[str, Any]: + """ + Format candidate information for display. + + Args: + candidate: A candidate object + ra: Right ascension + dec: Declination + + Returns: + Formatted candidate information as a dictionary + """ + info_dict = { + "Candidate Name": ( + candidate.candidate_name if hasattr(candidate, "candidate_name") else "" + ), + "TNS Name": candidate.tns_name if hasattr(candidate, "tns_name") else "", + "TNS URL": candidate.tns_url if hasattr(candidate, "tns_url") else "", + "RA": ra, + "Dec": dec, + "Discovery Date": ( + candidate.discovery_date.isoformat() + if hasattr(candidate, "discovery_date") and candidate.discovery_date + else "" + ), + "Discovery Magnitude": ( + candidate.discovery_magnitude + if hasattr(candidate, "discovery_magnitude") + else "" + ), + "Associated Galaxy": ( + candidate.associated_galaxy + if hasattr(candidate, "associated_galaxy") + else "" + ), + "Associated Galaxy Redshift": ( + candidate.associated_galaxy_redshift + if hasattr(candidate, "associated_galaxy_redshift") + else "" + ), + "Associated Galaxy Distance": ( + candidate.associated_galaxy_distance + if hasattr(candidate, "associated_galaxy_distance") + else "" + ), + } + + return info_dict + + +def sanatize_XRT_source_info(source) -> Dict[str, Any]: + """ + Format XRT source information for display. + + Args: + source: An XRT source object + + Returns: + Formatted XRT source information as a dictionary + """ + info_dict = { + "Alert Identifier": source.get("alert_identifier", ""), + "RA": source.get("right_ascension", ""), + "Dec": source.get("declination", ""), + "Significance": source.get("significance", ""), + "URL": source.get("url", ""), + } + + return info_dict + + +def by_chunk(items: List[Any], n: int) -> List[List[Any]]: + """ + Split a list into chunks of size n. 
+ + Args: + items: The list to split + n: The size of each chunk + + Returns: + List of chunks + """ + chunks = [] + for i in range(0, len(items), n): + chunks.append(items[i : i + n]) + return chunks + + +def floatNone(i): + if i is not None: + try: + return float(i) + except: # noqa: E722 + return 0.0 + else: + return None + + +def pointing_crossmatch(pointing, otherpointings, dist_thresh=None): + if dist_thresh is None: + + filtered_pointings = [ + x + for x in otherpointings + if ( + x.status == pointing.status + and x.instrumentid == int(pointing.instrumentid) + and x.band == pointing.band + and x.time == pointing.time + and x.pos_angle == floatNone(pointing.pos_angle) + ) + ] + + for p in filtered_pointings: + p_pos = str(geoalchemy2.shape.to_shape(p.position)) + if sanatize_pointing(p_pos) == sanatize_pointing(pointing.position): + return True + + else: + + p_ra, p_dec = sanatize_pointing(pointing.position) + + filtered_pointings = [ + x + for x in otherpointings + if ( + x.status == pointing.status + and x.instrumentid == int(pointing.instrumentid) + and x.band == pointing.band + ) + ] + + for p in filtered_pointings: + ra, dec = sanatize_pointing(str(geoalchemy2.shape.to_shape(p.position))) + sep = 206264.806 * (float(ephem.separation((ra, dec), (p_ra, p_dec)))) + if sep < dist_thresh: + return True + + return False + + +def create_doi(payload): + ACCESS_TOKEN = config.settings.ZENODO_ACCESS_KEY + data = payload["data"] + data_file = payload["data_file"] + files = payload["files"] + headers = payload["headers"] + + r = requests.post( + "https://zenodo.org/api/deposit/depositions", + params={"access_token": ACCESS_TOKEN}, + json={}, + headers=headers, + ) + + if r.status_code == 403: + return None, None + + d_id = r.json()["id"] + r = requests.post( + "https://zenodo.org/api/deposit/depositions/%s/files" % d_id, + params={"access_token": ACCESS_TOKEN}, + data=data_file, + files=files, + ) + r = requests.put( + "https://zenodo.org/api/deposit/depositions/%s" % d_id, + data=json.dumps(data), + params={"access_token": ACCESS_TOKEN}, + headers=headers, + ) + r = requests.post( + "https://zenodo.org/api/deposit/depositions/%s/actions/publish" % d_id, + params={"access_token": ACCESS_TOKEN}, + ) + + return_json = r.json() + try: + doi_url = return_json["doi_url"] + except: # noqa: E722 + doi_url = None + return int(d_id), doi_url + + +def create_pointing_doi( + points: List[Pointing], + graceid: str, + creators: List[Dict[str, str]], + instrument_names: List[str], +) -> Tuple[int, Optional[str]]: + """ + Create a DOI for pointings. + + Args: + points: List of pointing objects + graceid: Grace ID of the event + creators: List of creator dictionaries + instrument_names: List of instrument names + + Returns: + Tuple of (doi_id, doi_url) + """ + points_json = [] + + for p in points: + if p.status == PointingStatus.completed: + points_json.append(PointingSchema.from_orm(p)) + + if len(instrument_names) > 1: + inst_str = "These observations were taken on the" + for i in instrument_names: + if i == instrument_names[len(instrument_names) - 1]: + inst_str += " and " + i + else: + inst_str += " " + i + "," + + inst_str += " instruments." + else: + inst_str = ( + "These observations were taken on the " + + instrument_names[0] + + " instrument." 
+ ) + + if len(points_json): + payload = { + "data": { + "metadata": { + "title": "Submitted Completed pointings to the Gravitational Wave Treasure Map for event " + + graceid, + "upload_type": "dataset", + "creators": creators, + "description": "Attached in a .json file is the completed pointing information for " + + str(len(points_json)) + + " observation(s) for the EM counterpart search associated with the gravitational wave event " + + graceid + + ". " + + inst_str, + } + }, + "data_file": {"name": "completed_pointings_" + graceid + ".json"}, + "files": { + "file": json.dumps([p.model_dump(mode="json") for p in points_json]) + }, + "headers": {"Content-Type": "application/json"}, + } + + d_id, url = create_doi(payload) + return d_id, url + + return None, None + + +def create_galaxy_score_doi( + galaxies: List[Any], + creators: List[Dict[str, str]], + reference: Optional[str], + graceid: str, + alert_type: str, +) -> Tuple[int, Optional[str]]: + """ + Create a DOI for galaxy scores. + + Args: + galaxies: List of galaxy objects + creators: List of creator dictionaries + reference: Reference information + graceid: Grace ID of the event + alert_type: Type of the alert + + Returns: + Tuple of (doi_id, doi_url) + """ + import uuid + from datetime import datetime + + # Create a unique identifier + doi_suffix = str(uuid.uuid4())[:8] + doi_prefix = "10.5072" # Test DOI prefix + doi = f"{doi_prefix}/gwtm.galaxy.{doi_suffix}" + + # Format the current date + date = datetime.now().strftime("%Y-%m-%d") + + # Create a title + title = f"Galaxy candidates for {graceid} - {alert_type}" + + # Prepare metadata for DOI service + metadata = { + "doi": doi, + "creators": creators, + "titles": [{"title": title}], + "publisher": "Gravitational-Wave Treasure Map", + "publicationYear": date[:4], + "resourceType": {"resourceTypeGeneral": "Dataset"}, + "descriptions": [ + { + "description": f"Galaxy candidates for gravitational-wave event {graceid}", + "descriptionType": "Abstract", + } + ], + } + + # Add reference if provided + if reference: + metadata["relatedIdentifiers"] = [ + { + "relatedIdentifier": reference, + "relatedIdentifierType": "DOI", + "relationType": "References", + } + ] + + # In a real implementation, we would make an API call to the DOI service + # e.g., DataCite, with the metadata + # Here we're simulating that + + # This would be the DOI URL from the service + doi_url = f"https://doi.org/{doi}" + + # This would be the ID returned from the DOI service + # Here we're generating a random number based on the UUID + doi_id = int(doi_suffix, 16) % 1000000 + + return doi_id, doi_url diff --git a/server/utils/gwtm_io.py b/server/utils/gwtm_io.py new file mode 100644 index 00000000..e9af05f7 --- /dev/null +++ b/server/utils/gwtm_io.py @@ -0,0 +1,299 @@ +import fsspec +import json +import os +import tempfile + + +def _get_fs(source, config): + """ + Get the appropriate filesystem based on source. 
+ + Args: + source: Storage source ('s3' or 'abfs') + config: Configuration object with credentials + + Returns: + Filesystem object + """ + try: + if source == "s3": + return fsspec.filesystem( + "s3", key=config.AWS_ACCESS_KEY_ID, secret=config.AWS_SECRET_ACCESS_KEY + ) + if source == "abfs": + return fsspec.filesystem( + "abfs", + account_name=config.AZURE_ACCOUNT_NAME, + account_key=config.AZURE_ACCOUNT_KEY, + ) + except Exception as e: + raise Exception(f"Error in creating {source} filesystem: {str(e)}") + + +def download_gwtm_file(filename, source="s3", config=None, decode=True): + """ + Download a file from the GWTM storage. + + Args: + filename: File path/name to download + source: Storage source ('s3' or 'abfs') + config: Configuration object with credentials + decode: Whether to decode the file content to UTF-8 + + Returns: + File content (string if decode=True, bytes if decode=False) + """ + fs = _get_fs(source=source, config=config) + + if source == "s3" and f"{config.AWS_BUCKET}/" not in filename: + filename = f"{config.AWS_BUCKET}/{filename}" + + try: + s3file = fs.open(filename, "rb") + + with s3file as _file: + if decode: + return _file.read().decode("utf-8") + else: + return _file.read() + except Exception as e: + # In development mode, we might want to simulate file access + # This would allow the system to function without actual cloud storage + if hasattr(config, "DEVELOPMENT_MODE") and config.DEVELOPMENT_MODE: + # Check if we have a local development directory + dev_dir = getattr(config, "DEVELOPMENT_STORAGE_DIR", "./dev_storage") + local_path = os.path.join(dev_dir, filename.split("/")[-1]) + + # If the file exists locally, return its contents + if os.path.exists(local_path): + with open(local_path, "rb") as f: + content = f.read() + return content.decode("utf-8") if decode else content + + # If we're not in development mode or couldn't find a local file + raise Exception(f"Error reading {source} file {filename}: {str(e)}") + + +def upload_gwtm_file(content, filename, source="s3", config=None): + """ + Upload a file to GWTM storage. + + Args: + content: File content to upload + filename: Destination file path/name + source: Storage source ('s3' or 'abfs') + config: Configuration object with credentials + + Returns: + True if upload successful + """ + # In development mode, we might want to simulate file upload + if hasattr(config, "DEVELOPMENT_MODE") and config.DEVELOPMENT_MODE: + dev_dir = getattr(config, "DEVELOPMENT_STORAGE_DIR", "./dev_storage") + os.makedirs(dev_dir, exist_ok=True) + local_path = os.path.join(dev_dir, filename.split("/")[-1]) + + # Write to local file + mode = "wb" if isinstance(content, bytes) else "w" + with open(local_path, mode) as f: + f.write(content) + return True + + # Normal cloud storage upload + fs = _get_fs(source=source, config=config) + + if source == "s3" and f"{config.AWS_BUCKET}/" not in filename: + filename = f"{config.AWS_BUCKET}/{filename}" + + try: + mode = "wb" if isinstance(content, bytes) else "w" + with fs.open(filename, mode) as f: + f.write(content) + return True + except Exception as e: + raise Exception(f"Error uploading to {source} file {filename}: {str(e)}") + + +def list_gwtm_bucket(container, source="s3", config=None): + """ + List contents of a bucket/container. 
+ + Args: + container: Container/folder to list + source: Storage source ('s3' or 'abfs') + config: Configuration object with credentials + + Returns: + List of files in the container + """ + # In development mode, we might want to simulate bucket listing + if hasattr(config, "DEVELOPMENT_MODE") and config.DEVELOPMENT_MODE: + dev_dir = getattr(config, "DEVELOPMENT_STORAGE_DIR", "./dev_storage") + container_dir = os.path.join(dev_dir, container) + + if os.path.exists(container_dir) and os.path.isdir(container_dir): + return sorted( + [os.path.join(container, f) for f in os.listdir(container_dir)] + ) + elif os.path.exists(dev_dir): + # If the specific container doesn't exist, list all files that match the prefix + return sorted([f for f in os.listdir(dev_dir) if f.startswith(container)]) + return [] + + # Normal cloud storage listing + fs = _get_fs(source=source, config=config) + + try: + if source == "s3": + bucket_content = fs.ls(f"{config.AWS_BUCKET}/{container}") + ret = [] + for b in bucket_content: + split_b = b.split(f"{config.AWS_BUCKET}/")[1] + if split_b != f"{container}/": + ret.append(split_b) + return sorted(ret) + + ret = fs.ls(container) + return sorted(ret) + except Exception as e: + # If listing fails (e.g., container doesn't exist), return empty list + return [] + + +def delete_gwtm_files(keys, source="s3", config=None): + """ + Delete files from GWTM storage. + + Args: + keys: Single key or list of keys to delete + source: Storage source ('s3' or 'abfs') + config: Configuration object with credentials + + Returns: + True if delete successful + """ + # In development mode, we might want to simulate file deletion + if hasattr(config, "DEVELOPMENT_MODE") and config.DEVELOPMENT_MODE: + dev_dir = getattr(config, "DEVELOPMENT_STORAGE_DIR", "./dev_storage") + + # Convert single key to list + if isinstance(keys, str): + keys = [keys] + + # Delete local files + for key in keys: + local_path = os.path.join(dev_dir, key.split("/")[-1]) + if os.path.exists(local_path): + os.remove(local_path) + return True + + # Normal cloud storage deletion + if source == "s3": + if isinstance(keys, list): + for i, k in enumerate(keys): + if f"{config.AWS_BUCKET}/" not in k: + keys[i] = f"{config.AWS_BUCKET}/{k}" + elif isinstance(keys, str) and f"{config.AWS_BUCKET}/" not in keys: + keys = f"{config.AWS_BUCKET}/{keys}" + keys = [keys] # Convert to list for consistency + + fs = _get_fs(source=source, config=config) + + try: + for k in keys: + fs.rm(k) + return True + except Exception as e: + raise Exception(f"Error deleting from {source}: {str(e)}") + + +def get_cached_file(key, config): + """ + Get a cached file from storage. 
+ + Args: + key: Cache key + config: Configuration object with credentials + + Returns: + File content or None if not found + """ + # In development mode, we might want to use a local cache + if hasattr(config, "DEVELOPMENT_MODE") and config.DEVELOPMENT_MODE: + dev_dir = getattr(config, "DEVELOPMENT_STORAGE_DIR", "./dev_storage") + cache_dir = os.path.join(dev_dir, "cache") + cache_file = os.path.join(cache_dir, key.split("/")[-1]) + + if os.path.exists(cache_file): + with open(cache_file, "r") as f: + return f.read() + return None + + # Normal cloud storage cache access + source = config.STORAGE_BUCKET_SOURCE + + try: + cached_files = list_gwtm_bucket("cache", source, config) + + if key in cached_files: + return download_gwtm_file(key, source, config) + else: + return None + except Exception: + return None + + +def set_cached_file(key, contents, config): + """ + Set a cached file in storage. + + Args: + key: Cache key + contents: Content to cache (will be JSON serialized) + config: Configuration object with credentials + + Returns: + True if successful + """ + # In development mode, we might want to use a local cache + if hasattr(config, "DEVELOPMENT_MODE") and config.DEVELOPMENT_MODE: + dev_dir = getattr(config, "DEVELOPMENT_STORAGE_DIR", "./dev_storage") + cache_dir = os.path.join(dev_dir, "cache") + os.makedirs(cache_dir, exist_ok=True) + + cache_file = os.path.join(cache_dir, key.split("/")[-1]) + + with open(cache_file, "w") as f: + json.dump(contents, f) + return True + + # Normal cloud storage cache setting + source = config.STORAGE_BUCKET_SOURCE + + try: + return upload_gwtm_file(json.dumps(contents), key, source, config) + except Exception: + return False + + +def download_to_temp_file(filename, source="s3", config=None): + """ + Download a file to a temporary file and return the path. + Useful for binary files like FITS files that need to be processed by external libraries. + + Args: + filename: File to download + source: Storage source ('s3' or 'abfs') + config: Configuration object with credentials + + Returns: + Path to temporary file + """ + content = download_gwtm_file(filename, source, config, decode=False) + + # Create a temporary file + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(content) + temp_file.close() + + return temp_file.name diff --git a/server/utils/pointing.py b/server/utils/pointing.py new file mode 100644 index 00000000..86227f9e --- /dev/null +++ b/server/utils/pointing.py @@ -0,0 +1,192 @@ +""" +Pointing business logic utilities. +Contains logic that requires database access and can't be moved to Pydantic validators. 
+""" + +from sqlalchemy.orm import Session +from typing import List, Optional, Tuple, TYPE_CHECKING +from datetime import datetime + +if TYPE_CHECKING: + from server.schemas.pointing import PointingCreate + +from server.db.models.pointing import Pointing +from server.db.models.instrument import Instrument +from server.db.models.pointing_event import PointingEvent +from server.db.models.gw_alert import GWAlert +from server.db.models.users import Users +from server.db.models.doi_author import DOIAuthor +from server.core.enums.pointingstatus import PointingStatus as pointing_status_enum +from server.utils.error_handling import validation_exception, not_found_exception +from server.utils.function import pointing_crossmatch, create_pointing_doi + + +def validate_graceid(graceid: str, db: Session) -> str: + """Validate that a graceid exists in the database.""" + valid_alerts = db.query(GWAlert).filter(GWAlert.graceid == graceid).all() + if len(valid_alerts) == 0: + raise validation_exception( + message="Invalid graceid", + errors=[f"The graceid '{graceid}' does not exist in the database"], + ) + return graceid + + +def validate_instrument(instrument_id: int, db: Session) -> Instrument: + """Validate that an instrument exists.""" + instrument = db.query(Instrument).filter(Instrument.id == instrument_id).first() + if not instrument: + raise not_found_exception(f"Instrument with ID {instrument_id} not found") + return instrument + + +def get_instruments_dict(db: Session) -> dict: + """Get a dictionary of instruments for validation.""" + dbinsts = db.query(Instrument.instrument_name, Instrument.id).all() + return {inst.id: inst.instrument_name for inst in dbinsts} + + +def validate_instrument_reference( + pointing_data: "PointingCreate", instruments_dict: dict +) -> int: + """Validate and resolve instrument reference to ID.""" + if pointing_data.instrumentid is None: + raise validation_exception( + message="Field instrumentid is required", + errors=["instrumentid must be provided"], + ) + + inst = pointing_data.instrumentid + if isinstance(inst, int): + if inst not in instruments_dict: + raise validation_exception( + message="Invalid instrumentid", + errors=[f"Instrument with ID {inst} not found"], + ) + return inst + else: + # Handle string instrument names + for inst_id, inst_name in instruments_dict.items(): + if inst_name == inst: + return inst_id + raise validation_exception( + message="Invalid instrumentid", errors=[f"Instrument '{inst}' not found"] + ) + + +def check_duplicate_pointing( + pointing: Pointing, existing_pointings: List[Pointing] +) -> bool: + """Check if a pointing is a duplicate of existing pointings.""" + return pointing_crossmatch(pointing, existing_pointings) + + +def create_pointing_from_schema( + pointing_data: "PointingCreate", user_id: int, instrument_id: int +) -> Pointing: + """Create a Pointing model instance from schema data.""" + pointing = Pointing() + + # Set basic fields + pointing.position = pointing_data.position + pointing.depth = pointing_data.depth + pointing.depth_err = pointing_data.depth_err + pointing.depth_unit = pointing_data.depth_unit + pointing.status = pointing_data.status or pointing_status_enum.completed + pointing.band = pointing_data.band + pointing.central_wave = pointing_data.central_wave + pointing.bandwidth = pointing_data.bandwidth + pointing.instrumentid = instrument_id + pointing.pos_angle = pointing_data.pos_angle + pointing.time = pointing_data.time + pointing.submitterid = user_id + pointing.datecreated = datetime.now() + + return 
pointing + + +def handle_planned_pointing_update( + pointing_data, user_id: int, db: Session +) -> Optional[Pointing]: + """Handle updating a planned pointing to completed.""" + if not hasattr(pointing_data, "id") or pointing_data.id is None: + return None + + pointing_id = int(pointing_data.id) + + # Find the planned pointing + planned_pointing = ( + db.query(Pointing) + .filter(Pointing.id == pointing_id, Pointing.submitterid == user_id) + .first() + ) + + if not planned_pointing: + raise validation_exception( + message="Pointing validation error", + errors=[f"Pointing with ID {pointing_id} not found or not owned by you"], + ) + + if planned_pointing.status in [ + pointing_status_enum.completed, + pointing_status_enum.cancelled, + ]: + raise validation_exception( + message="Pointing validation error", + errors=[f"This pointing has already been {planned_pointing.status.name}"], + ) + + # Update planned pointing with new data + if pointing_data.time: + planned_pointing.time = pointing_data.time + if pointing_data.pos_angle is not None: + planned_pointing.pos_angle = pointing_data.pos_angle + + planned_pointing.status = pointing_status_enum.completed + planned_pointing.dateupdated = datetime.now() + + return planned_pointing + + +def prepare_doi_creators( + creators: Optional[List[dict]], + doi_group_id: Optional[int], + user: Users, + db: Session, +) -> List[dict]: + """Prepare DOI creators list.""" + if creators: + return creators + elif doi_group_id: + valid, creators_list = DOIAuthor.construct_creators(doi_group_id, user.id, db) + if not valid: + raise validation_exception( + message="Invalid DOI group ID", + errors=["Make sure you are the User associated with the DOI group"], + ) + return creators_list + else: + return [{"name": f"{user.firstname} {user.lastname}", "affiliation": ""}] + + +def create_doi_for_pointings( + pointings: List[Pointing], graceid: str, creators: List[dict], db: Session +) -> Tuple[Optional[int], Optional[str]]: + """Create a DOI for a list of pointings.""" + if not pointings: + return None, None + + # Get instruments for the pointings + insts = ( + db.query(Instrument) + .filter(Instrument.id.in_([p.instrumentid for p in pointings])) + .all() + ) + inst_set = list(set([i.instrument_name for i in insts])) + + # Get normalized graceid + normalized_gid = GWAlert.alternatefromgraceid(graceid) + + # Create the DOI + result = create_pointing_doi(pointings, normalized_gid, creators, inst_set) + return result diff --git a/server/utils/spectral.py b/server/utils/spectral.py new file mode 100644 index 00000000..ab25a8ee --- /dev/null +++ b/server/utils/spectral.py @@ -0,0 +1,312 @@ +""" +Spectral range utilities for the GWTM application. +These functions handle conversions between different spectral representations +like wavelength, frequency, and energy. 
+""" + +from typing import Tuple, Optional +import numpy as np +from server.core.enums.bandpass import Bandpass +from enum import IntEnum + +# Constants +SPEED_OF_LIGHT = 299792458 # m/s +PLANCK_CONSTANT = 6.62607015e-34 # J*s +ELECTRON_VOLT = 1.602176634e-19 # J + + +class SpectralRangeHandler: + """ + Values for the central wave and bandwidth were taken from: + http://svo2.cab.inta-csic.es/theory/fps/index.php?mode=browse + notated by the 'source' field in the following dictionary + for central_wavelength I used the lam_cen + for bandwidth I used the FWHM + + Our base for the central wavelengths and bandwidths will be stored in + Angstroms + + There are following static methods to convert the Angstrom values into ranges for + frequency in Hz + energy in eV + """ + + class spectralrangetype(IntEnum): + wavelength = 1 + energy = 2 + frequency = 3 + + # Bandpass wavelength dictionary with central wavelengths and bandwidths in Angstroms + bandpass_wavelength_dictionary = { + Bandpass.U: { + "source": "CTIO/SOI.bessel_U", + "central_wave": 3614.82, + "bandwidth": 617.24, + }, + Bandpass.B: { + "source": "CTIO/SOI.bessel_B", + "central_wave": 4317.0, + "bandwidth": 991.48, + }, + Bandpass.V: { + "source": "CTIO/SOI.bessel_V", + "central_wave": 5338.65, + "bandwidth": 810.65, + }, + Bandpass.R: { + "source": "CTIO/SOI.bessel_R", + "central_wave": 6311.86, + "bandwidth": 1220.89, + }, + Bandpass.I: { + "source": "CTIO/SOI.bessel_I", + "central_wave": 8748.91, + "bandwidth": 2940.57, + }, + Bandpass.J: { + "source": "CTIO/ANDICAM/J", + "central_wave": 12457.00, + "bandwidth": 1608.86, + }, + Bandpass.H: { + "source": "CTIO/ANDICAM/H", + "central_wave": 16333.11, + "bandwidth": 2969.21, + }, + Bandpass.K: { + "source": "CTIO/ANDICAM/K", + "central_wave": 21401.72, + "bandwidth": 2894.54, + }, + Bandpass.u: { + "source": "CTIO/DECam.u_filter", + "central_wave": 3552.98, + "bandwidth": 885.05, + }, + Bandpass.g: { + "source": "CTIO/DECam.g_filter", + "central_wave": 4730.50, + "bandwidth": 1503.06, + }, + Bandpass.r: { + "source": "CTIO/DECam.r_filter", + "central_wave": 6415.40, + "bandwidth": 1487.58, + }, + Bandpass.i: { + "source": "CTIO/DECam.i_filter", + "central_wave": 7836.21, + "bandwidth": 1468.29, + }, + Bandpass.z: { + "source": "CTIO/DECam.z_filter", + "central_wave": 9258.37, + "bandwidth": 1521.09, + }, + Bandpass.UVW1: { + "source": "Swift/UVOT.UVW1", + "central_wave": 2629.35, + "bandwidth": 656.60, + }, + Bandpass.UVW2: { + "source": "Swift/UVOT.UVW2", + "central_wave": 2089.16, + "bandwidth": 498.25, + }, + Bandpass.UVM2: { + "source": "Swift/UVOT.UVM2", + "central_wave": 2245.78, + "bandwidth": 498.25, + }, + Bandpass.clear: { + "source": "Generic/clear", + "central_wave": 2634.44, + "bandwidth": 3230.16, + }, + Bandpass.open: { + "source": "Generic/open", + "central_wave": 5500.0, + "bandwidth": 8000.0, + }, + Bandpass.other: { + "source": "Generic/other", + "central_wave": 5500.0, + "bandwidth": 8000.0, + }, + } + + @staticmethod + def wavetoWaveRange(central_wave=None, bandwidth=None, bandpass_enum=None): + """Method that returns the wavelength range from the central_wave and bandwidth, or bandpass""" + if central_wave is None and bandwidth is None and bandpass_enum is not None: + bp = SpectralRangeHandler.bandpass_wavelength_dictionary[bandpass_enum] + central_wave = bp["central_wave"] + bandwidth = bp["bandwidth"] + + wave_min = central_wave - (bandwidth / 2.0) + wave_max = central_wave + (bandwidth / 2.0) + + return wave_min, wave_max + + @staticmethod + def 
wavetoEnergy(central_wave=None, bandwidth=None, bandpass_enum=None): + """Method that returns the corresponding wave range to energy in eV""" + wave_min, wave_max = SpectralRangeHandler.wavetoWaveRange( + central_wave, bandwidth, bandpass_enum + ) + + ev_max = 12398 / wave_min + ev_min = 12398 / wave_max + + return ev_min, ev_max + + @staticmethod + def wavetoFrequency(central_wave=None, bandwidth=None, bandpass_enum=None): + """Method that returns the corresponding wave range to frequency in Hz""" + wave_min, wave_max = SpectralRangeHandler.wavetoWaveRange( + central_wave, bandwidth, bandpass_enum + ) + + freq_max = 2997924580000000000.0 / wave_min + freq_min = 2997924580000000000.0 / wave_max + + return freq_min, freq_max + + +def waveToFreq(wave: float) -> float: + """ + Convert wavelength (in Angstroms) to frequency (in Hz). + + Args: + wave: Wavelength in Angstroms + + Returns: + Frequency in Hz + """ + wavelength_m = wave * 1e-10 # Convert from Angstroms to meters + return SPEED_OF_LIGHT / wavelength_m + + +def freqToWave(freq: float) -> float: + """ + Convert frequency (in Hz) to wavelength (in Angstroms). + + Args: + freq: Frequency in Hz + + Returns: + Wavelength in Angstroms + """ + wavelength_m = SPEED_OF_LIGHT / freq # Wavelength in meters + return wavelength_m * 1e10 # Convert to Angstroms + + +def waveToEnergy(wave: float) -> float: + """ + Convert wavelength (in Angstroms) to energy (in eV). + + Args: + wave: Wavelength in Angstroms + + Returns: + Energy in eV + """ + wavelength_m = wave * 1e-10 # Convert from Angstroms to meters + freq = SPEED_OF_LIGHT / wavelength_m # Calculate frequency + energy_J = PLANCK_CONSTANT * freq # Calculate energy in Joules + return energy_J / ELECTRON_VOLT # Convert to eV + + +def energyToWave(energy: float) -> float: + """ + Convert energy (in eV) to wavelength (in Angstroms). + + Args: + energy: Energy in eV + + Returns: + Wavelength in Angstroms + """ + energy_J = energy * ELECTRON_VOLT # Convert to Joules + freq = energy_J / PLANCK_CONSTANT # Calculate frequency + wavelength_m = SPEED_OF_LIGHT / freq # Calculate wavelength in meters + return wavelength_m * 1e10 # Convert to Angstroms + + +def freqToEnergy(freq: float) -> float: + """ + Convert frequency (in Hz) to energy (in eV). + + Args: + freq: Frequency in Hz + + Returns: + Energy in eV + """ + energy_J = PLANCK_CONSTANT * freq # Calculate energy in Joules + return energy_J / ELECTRON_VOLT # Convert to eV + + +def energyToFreq(energy: float) -> float: + """ + Convert energy (in eV) to frequency (in Hz). + + Args: + energy: Energy in eV + + Returns: + Frequency in Hz + """ + energy_J = energy * ELECTRON_VOLT # Convert to Joules + return energy_J / PLANCK_CONSTANT # Calculate frequency + + +def wavetoWaveRange( + bandpass_enum: Bandpass = None, central_wave: float = None, bandwidth: float = None +) -> Tuple[float, float]: + """ + Get the wavelength range for a specific bandpass using SpectralRangeHandler. + + Args: + bandpass_enum: Bandpass enum value + central_wave: Central wavelength in Angstroms (alternative to bandpass) + bandwidth: Bandwidth in Angstroms (alternative to bandpass) + + Returns: + Tuple of (min_wavelength, max_wavelength) in Angstroms + """ + return SpectralRangeHandler.wavetoWaveRange(central_wave, bandwidth, bandpass_enum) + + +def wavetoEnergy( + bandpass_enum: Bandpass = None, central_wave: float = None, bandwidth: float = None +) -> Tuple[float, float]: + """ + Get the energy range for a specific bandpass using SpectralRangeHandler. 
+ + Args: + bandpass_enum: Bandpass enum value + central_wave: Central wavelength in Angstroms (alternative to bandpass) + bandwidth: Bandwidth in Angstroms (alternative to bandpass) + + Returns: + Tuple of (min_energy, max_energy) in eV + """ + return SpectralRangeHandler.wavetoEnergy(central_wave, bandwidth, bandpass_enum) + + +def wavetoFrequency( + bandpass_enum: Bandpass = None, central_wave: float = None, bandwidth: float = None +) -> Tuple[float, float]: + """ + Get the frequency range for a specific bandpass using SpectralRangeHandler. + + Args: + bandpass_enum: Bandpass enum value + central_wave: Central wavelength in Angstroms (alternative to bandpass) + bandwidth: Bandwidth in Angstroms (alternative to bandpass) + + Returns: + Tuple of (min_frequency, max_frequency) in Hz + """ + return SpectralRangeHandler.wavetoFrequency(central_wave, bandwidth, bandpass_enum) diff --git a/src/ajaxrequests.py b/src/ajaxrequests.py index 6d4aa72f..b69f7ed9 100644 --- a/src/ajaxrequests.py +++ b/src/ajaxrequests.py @@ -879,7 +879,7 @@ def plot_renormalized_skymap(): 'dec': dec, 'time':p.time, 'depth':p.depth, - 'depth_unit':p.depth_unit, + 'depth_unit':p.DepthUnit, 'band':p.band, 'status':p.status })) @@ -1062,7 +1062,7 @@ def get_pointing_fromID(): pointing_json['ra'] = ra pointing_json['dec'] = dec pointing_json['graceid'] = pointing.graceid - pointing_json['instrument'] = str(pointing.instrumentid)+'_'+enums.instrument_type(pointing.instrument_type).name + pointing_json['instrument'] = str(pointing.instrumentid)+'_'+enums.instrument_type(pointing.InstrumentType).name pointing_json['band'] = pointing.band.name pointing_json['depth'] = pointing.depth pointing_json['depth_err'] = pointing.depth_err diff --git a/src/forms.py b/src/forms.py index 69f3d923..1034a4ca 100644 --- a/src/forms.py +++ b/src/forms.py @@ -220,7 +220,7 @@ def populate_instruments(self): query = models.instrument.query.all() self.instruments.choices = [(None, 'Select')] for a in query: - self.instruments.choices.append((str(a.id)+"_"+a.instrument_type.name, a.instrument_name)) + self.instruments.choices.append((str(a.id) +"_" + a.InstrumentType.name, a.instrument_name)) def populate_creator_groups(self, current_userid): dag = models.doi_author_group.query.filter_by(userid=current_userid).all() @@ -421,7 +421,7 @@ def construct_alertform(self, args): self.inst_cov.append({'name':inst_name, 'value':inst.id}) self.depth_unit=[] - for dp in list(set([x.depth_unit for x in pointing_info if x.status == enums.pointing_status.completed and x.instrumentid != 49 and x.depth_unit is not None])): + for dp in list(set([x.DepthUnit for x in pointing_info if x.status == enums.pointing_status.completed and x.instrumentid != 49 and x.DepthUnit is not None])): self.depth_unit.append({'name':str(dp), 'value':dp.name}) diff --git a/src/models.py b/src/models.py index 6848942a..6da3aae7 100644 --- a/src/models.py +++ b/src/models.py @@ -826,7 +826,7 @@ def from_json(self, p, dbinsts, userid, planned_pointings, otherpointings): # d self.position = planned_pointing.position self.depth = planned_pointing.depth self.depth_err = planned_pointing.depth_err - self.depth_unit = planned_pointing.depth_unit + self.depth_unit = planned_pointing.DepthUnit self.status = enums.pointing_status.completed self.band = planned_pointing.band self.central_wave = planned_pointing.central_wave diff --git a/tests/fastapi/conftest.py b/tests/fastapi/conftest.py new file mode 100644 index 00000000..d399ce7a --- /dev/null +++ b/tests/fastapi/conftest.py @@ -0,0 +1,138 @@ 
+""" +Pytest configuration and fixtures for FastAPI tests. +Automatically loads test data before running tests. +""" + +import os +import sys +import subprocess +import time +import pytest +import requests +from pathlib import Path + +# Add server directory to Python path so tests can import local code +server_dir = Path(__file__).parent.parent.parent / "server" +if str(server_dir) not in sys.path: + sys.path.insert(0, str(server_dir)) + + +def load_test_data(): + """Load test data using the existing restore-db script.""" + # Get the path to test-data.sql + test_data_path = Path(__file__).parent.parent / "test-data.sql" + + if not test_data_path.exists(): + pytest.skip(f"Test data file not found: {test_data_path}") + + # Get the path to restore-db script + restore_script = Path(__file__).parent.parent.parent / "gwtm-helm" / "restore-db" + + if not restore_script.exists(): + pytest.skip(f"restore-db script not found: {restore_script}") + + try: + # Run the restore-db script with proper environment and stdin + print(f"Loading test data from {test_data_path}") + print(f"Running: {restore_script} {test_data_path}") + + # Try with a shorter timeout first to see what happens + result = subprocess.run( + [str(restore_script), str(test_data_path)], + capture_output=True, + text=True, + timeout=10, # Short timeout to debug + cwd=Path(__file__).parent.parent.parent, # Run from project root + stdin=subprocess.DEVNULL, # Prevent hanging on input + env=os.environ.copy(), # Pass through environment + ) + + if result.returncode != 0: + print(f"restore-db failed with return code {result.returncode}") + print(f"stdout: {result.stdout}") + print(f"stderr: {result.stderr}") + pytest.skip("Failed to load test data") + + print("Test data loaded successfully") + print(f"stdout: {result.stdout}") + + except subprocess.TimeoutExpired as e: + print(f"Test data loading timed out after 10 seconds") + print(f"Command: {e.cmd}") + # Try to get partial output + if hasattr(e, "stdout") and e.stdout: + print(f"Partial stdout: {e.stdout}") + if hasattr(e, "stderr") and e.stderr: + print(f"Partial stderr: {e.stderr}") + pytest.skip("Test data loading timed out - check kubectl access") + except Exception as e: + print(f"Error loading test data: {e}") + pytest.skip(f"Error loading test data: {e}") + + +def wait_for_api(): + """Wait for the FastAPI server to be ready.""" + api_url = os.getenv("API_BASE_URL", "http://localhost:8000") + health_url = f"{api_url}/health" + + max_attempts = 30 + for attempt in range(max_attempts): + try: + response = requests.get(health_url, timeout=5) + if response.status_code == 200: + print(f"API is ready at {api_url}") + return + except requests.exceptions.RequestException: + pass + + if attempt < max_attempts - 1: + print(f"Waiting for API to be ready... ({attempt + 1}/{max_attempts})") + time.sleep(2) + + pytest.skip(f"API not ready after {max_attempts} attempts") + + +@pytest.fixture(scope="session", autouse=True) +def setup_test_environment(): + """ + Session-level fixture that runs once before all tests. + Waits for API to be ready and loads test data. 
+ """ + print("Setting up test environment...") + + # Wait for API to be ready + wait_for_api() + + # Load test data + load_test_data() + + print("Test environment setup complete") + + +@pytest.fixture(scope="function") +def api_base_url(): + """Provide the API base URL for tests.""" + return os.getenv("API_BASE_URL", "http://localhost:8000") + + +@pytest.fixture(scope="function") +def api_headers(): + """Provide common API headers.""" + return {"Content-Type": "application/json", "Accept": "application/json"} + + +@pytest.fixture(scope="function") +def test_tokens(): + """Provide test API tokens from test data.""" + return { + "admin": "test_token_admin_001", + "user": "test_token_user_002", + "scientist": "test_token_sci_003", + "invalid": "invalid_token_123", + } + + +@pytest.fixture(scope="function") +def test_graceids(): + """Provide known GraceIDs from test data.""" + return ["S190425z", "S190426c", "MS230101a", "GW190521", "MS190425a"] diff --git a/tests/fastapi/test_admin.py b/tests/fastapi/test_admin.py new file mode 100644 index 00000000..9b33ef75 --- /dev/null +++ b/tests/fastapi/test_admin.py @@ -0,0 +1,81 @@ +""" +Test admin endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. +""" + +import os +import requests +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") + + +class TestAdminEndpoints: + """Test class for admin-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{endpoint}" + + def test_fixdata_as_admin_get(self): + """Test the fixdata endpoint with GET as admin.""" + response = requests.get( + self.get_url("/fixdata"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "message" in data + assert data["message"] == "success" + + def test_fixdata_as_admin_post(self): + """Test the fixdata endpoint with POST as admin.""" + response = requests.post( + self.get_url("/fixdata"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "message" in data + assert data["message"] == "success" + + def test_fixdata_as_non_admin(self): + """Test that non-admin users cannot use the fixdata endpoint.""" + # Test with regular user token + response = requests.get( + self.get_url("/fixdata"), headers={"api_token": self.user_token} + ) + + assert response.status_code == 403 + + # Test with scientist token + response = requests.get( + self.get_url("/fixdata"), headers={"api_token": self.scientist_token} + ) + + assert response.status_code == 403 + + def test_fixdata_without_auth(self): + """Test that authentication is required for fixdata endpoint.""" + response = requests.get(self.get_url("/fixdata")) + assert response.status_code == 401 + + def test_fixdata_with_invalid_token(self): + """Test fixdata with invalid API token.""" + response = requests.get( + self.get_url("/fixdata"), headers={"api_token": self.invalid_token} + ) + assert response.status_code == 401 + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_candidate.py b/tests/fastapi/test_candidate.py new file mode 
100644 index 00000000..abb946aa --- /dev/null +++ b/tests/fastapi/test_candidate.py @@ -0,0 +1,917 @@ +""" +Test candidate endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. +""" + +import os +import requests +import json +from datetime import datetime +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestCandidateEndpoints: + """Test class for candidate-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + # Known GraceIDs from test data + KNOWN_GRACEIDS = ["S190425z", "S190426c", "MS230101a", "GW190521", "MS190425a"] + + def test_get_candidates_no_params(self): + """Test getting candidates without any parameters.""" + response = requests.get( + self.get_url("/candidate"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should return candidates that exist in test data + for candidate in data: + assert "id" in candidate + assert "candidate_name" in candidate + assert "graceid" in candidate + + def test_get_candidate_by_id(self): + """Test getting a specific candidate by ID.""" + # First create a candidate to ensure we have one + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "Test_SN_001", + "ra": 123.456, + "dec": -12.345, + "discovery_date": "2019-04-25T12:00:00.000000", + "discovery_magnitude": 21.5, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "r", + }, + } + + create_response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_id = create_response.json()["candidate_ids"][0] + + # Now get it by ID + response = requests.get( + self.get_url("/candidate"), + params={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["id"] == candidate_id + assert data[0]["candidate_name"] == "Test_SN_001" + + def test_get_candidates_by_graceid(self): + """Test getting candidates filtered by graceid.""" + response = requests.get( + self.get_url("/candidate"), + params={"graceid": "S190425z"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # All returned candidates should have the specified graceid + for candidate in data: + assert candidate["graceid"] == "S190425z" + + def test_get_candidates_by_multiple_ids(self): + """Test getting candidates filtered by multiple IDs.""" + # First create a couple of candidates + candidate_ids = [] + for i in range(2): + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": f"Test_Multi_{i}", + "ra": 123.456 + i, + "dec": -12.345 + i, + "discovery_date": "2019-04-25T12:00:00.000000", + "discovery_magnitude": 21.5 + i, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "r", + }, + } + + create_response = requests.post( + 
self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_ids.extend(create_response.json()["candidate_ids"]) + + # Now get them by IDs + ids_param = json.dumps(candidate_ids) # JSON array format + response = requests.get( + self.get_url("/candidate"), + params={"ids": ids_param}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) >= 2 + returned_ids = [c["id"] for c in data] + for cid in candidate_ids: + assert cid in returned_ids + + def test_get_candidates_by_user_id(self): + """Test getting candidates filtered by user ID.""" + response = requests.get( + self.get_url("/candidate"), + params={"userid": 1}, # Admin user ID + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # All returned candidates should be submitted by user 1 + for candidate in data: + assert candidate["submitterid"] == 1 + + def test_get_candidates_by_date_range(self): + """Test getting candidates filtered by discovery date range.""" + response = requests.get( + self.get_url("/candidate"), + params={ + "discovery_date_after": "2019-04-01T00:00:00.000000", + "discovery_date_before": "2019-05-01T00:00:00.000000", + }, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Check that returned candidates are within the date range + for candidate in data: + if candidate.get("discovery_date"): + discovery_date = datetime.fromisoformat( + candidate["discovery_date"].replace("Z", "+00:00") + ) + assert datetime(2019, 4, 1) <= discovery_date <= datetime(2019, 5, 1) + + def test_get_candidates_by_magnitude_range(self): + """Test getting candidates filtered by magnitude range.""" + response = requests.get( + self.get_url("/candidate"), + params={"discovery_magnitude_gt": 20.0, "discovery_magnitude_lt": 23.0}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Check that returned candidates are within the magnitude range + for candidate in data: + if candidate.get("discovery_magnitude"): + mag = candidate["discovery_magnitude"] + assert 20.0 < mag < 23.0 + + def test_post_single_candidate(self): + """Test posting a single candidate.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_2019abc", + "ra": 150.789, + "dec": -25.456, + "discovery_date": "2019-04-25T14:30:00.000000", + "discovery_magnitude": 22.1, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "g", + "tns_name": "2019abc", + "tns_url": "https://www.wis-tns.org/object/2019abc", + "associated_galaxy": "NGC1234", + "associated_galaxy_redshift": 0.05, + "associated_galaxy_distance": 200.5, + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "candidate_ids" in data + assert len(data["candidate_ids"]) == 1 + assert isinstance(data["candidate_ids"][0], int) + assert len(data.get("ERRORS", [])) == 0 + + def test_post_multiple_candidates(self): + """Test posting multiple candidates.""" + candidate_data = { + 
"graceid": "S190425z", + "candidates": [ + { + "candidate_name": "SN_2019def", + "ra": 155.123, + "dec": -30.789, + "discovery_date": "2019-04-25T15:00:00.000000", + "discovery_magnitude": 21.8, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "r", + }, + { + "candidate_name": "SN_2019ghi", + "ra": 160.456, + "dec": -35.123, + "discovery_date": "2019-04-25T16:00:00.000000", + "discovery_magnitude": 22.3, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "i", + }, + ], + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "candidate_ids" in data + assert len(data["candidate_ids"]) == 2 + assert len(data.get("ERRORS", [])) == 0 + + def test_post_candidate_with_position_string(self): + """Test posting candidate with position as WKT string.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_2019jkl", + "position": "POINT(165.789 40.123)", + "discovery_date": "2019-04-25T17:00:00.000000", + "discovery_magnitude": 21.5, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "V", + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["candidate_ids"]) == 1 + assert len(data.get("ERRORS", [])) == 0 + + def test_post_candidate_with_spectral_regime(self): + """Test posting candidate with wavelength regime.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_2019mno", + "ra": 170.456, + "dec": 45.789, + "discovery_date": "2019-04-25T18:00:00.000000", + "discovery_magnitude": 20.9, + "magnitude_unit": "ab_mag", + "wavelength_regime": [4000, 7000], + "wavelength_unit": "angstrom", + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["candidate_ids"]) == 1 + assert len(data.get("ERRORS", [])) == 0 + + def test_post_candidate_invalid_graceid(self): + """Test posting candidate with invalid graceid.""" + candidate_data = { + "graceid": "INVALID_GID", + "candidate": { + "candidate_name": "SN_Invalid", + "ra": 123.456, + "dec": -12.345, + "discovery_date": "2019-04-25T12:00:00.000000", + "discovery_magnitude": 21.5, + "magnitude_unit": "ab_mag", + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Invalid 'graceid'" in response.json()["message"] + + def test_post_candidate_missing_required_fields(self): + """Test posting candidate with missing required fields.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_Incomplete", + "ra": 123.456, + "dec": -12.345, + # Missing discovery_date and discovery_magnitude + "magnitude_unit": "ab_mag", + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + missing_fields = [ + field["params"]["field"] for field in response.json()["errors"] + ] + assert "discovery_date" in str(missing_fields) + assert 
"discovery_magnitude" in str(missing_fields) + + def test_post_candidate_invalid_position(self): + """Test posting candidate with invalid position data.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_BadPos", + # Missing both position and ra/dec + "discovery_date": "2019-04-25T12:00:00.000000", + "discovery_magnitude": 21.5, + "magnitude_unit": "ab_mag", + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json()["errors"][0]["message"] + assert "Either position or both ra and dec must be provided" in data + + def test_put_candidate(self): + """Test updating an existing candidate.""" + # First create a candidate + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_ToUpdate", + "ra": 175.123, + "dec": -50.456, + "discovery_date": "2019-04-25T19:00:00.000000", + "discovery_magnitude": 22.0, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "B", + }, + } + + create_response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_id = create_response.json()["candidate_ids"][0] + + # Now update it + update_data = { + "id": candidate_id, + "candidate": { + "candidate_name": "SN_Updated", + "discovery_magnitude": 21.5, + "tns_name": "2019updated", + "associated_galaxy": "NGC5678", + }, + } + + response = requests.put( + self.get_url("/candidate"), + json=update_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "candidate" in data + assert data["candidate"]["candidate_name"] == "SN_Updated" + assert data["candidate"]["discovery_magnitude"] == 21.5 + + def test_put_candidate_nonexistent(self): + """Test updating a non-existent candidate.""" + update_data = { + "id": 99999, # Non-existent ID + "candidate": {"candidate_name": "SN_NotFound", "discovery_magnitude": 21.5}, + } + + response = requests.put( + self.get_url("/candidate"), + json=update_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "No candidate found" in response.json()["message"] + + def test_put_candidate_unauthorized(self): + """Test updating another user's candidate.""" + # Create candidate as admin + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_AdminOwned", + "ra": 180.789, + "dec": -55.123, + "discovery_date": "2019-04-25T20:00:00.000000", + "discovery_magnitude": 21.7, + "magnitude_unit": "ab_mag", + }, + } + + create_response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_id = create_response.json()["candidate_ids"][0] + + # Try to update as different user + update_data = { + "id": candidate_id, + "candidate": {"candidate_name": "SN_Hijacked"}, + } + + response = requests.put( + self.get_url("/candidate"), + json=update_data, + headers={"api_token": self.user_token}, + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "Unable to alter" in response.json()["message"] + + def test_delete_candidate_single(self): + """Test deleting a single candidate.""" + # First create a candidate + candidate_data = 
{ + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_ToDelete", + "ra": 185.456, + "dec": -60.789, + "discovery_date": "2019-04-25T21:00:00.000000", + "discovery_magnitude": 23.1, + "magnitude_unit": "ab_mag", + }, + } + + create_response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_id = create_response.json()["candidate_ids"][0] + + # Now delete it + response = requests.delete( + self.get_url("/candidate"), + json={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "Successfully deleted" in data["message"] + assert candidate_id in data["deleted_ids"] + + def test_delete_candidate_multiple(self): + """Test deleting multiple candidates.""" + # First create multiple candidates + candidate_ids = [] + for i in range(3): + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": f"SN_MultiDelete_{i}", + "ra": 190.0 + i, + "dec": -65.0 - i, + "discovery_date": "2019-04-25T22:00:00.000000", + "discovery_magnitude": 22.5 + i * 0.1, + "magnitude_unit": "ab_mag", + }, + } + + create_response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_ids.extend(create_response.json()["candidate_ids"]) + + # Now delete them + ids_param = json.dumps(candidate_ids) + response = requests.delete( + self.get_url("/candidate"), + json={"ids": candidate_ids}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "Successfully deleted" in data["message"] + assert len(data["deleted_ids"]) == 3 + for cid in candidate_ids: + assert cid in data["deleted_ids"] + + def test_delete_candidate_nonexistent(self): + """Test deleting a non-existent candidate.""" + response = requests.delete( + self.get_url("/candidate"), + json={"id": 99999}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "No candidate found" in response.json()["message"] + + def test_delete_candidate_unauthorized(self): + """Test deleting another user's candidate.""" + # Create candidate as admin + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_AdminProtected", + "ra": 195.123, + "dec": -70.456, + "discovery_date": "2019-04-25T23:00:00.000000", + "discovery_magnitude": 21.2, + "magnitude_unit": "ab_mag", + }, + } + + create_response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_id = create_response.json()["candidate_ids"][0] + + # Try to delete as different user + response = requests.delete( + self.get_url("/candidate"), + json={"id": candidate_id}, + headers={"api_token": self.user_token}, + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "Unauthorized" in response.json()["message"] + + def test_candidate_unauthorized_access(self): + """Test that unauthorized requests are rejected.""" + # Request without API token + response = requests.get(self.get_url("/candidate")) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + # Request with invalid API token + response = requests.get( + 
self.get_url("/candidate"), headers={"api_token": self.invalid_token} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_candidates_with_different_tokens(self): + """Test access with different valid API tokens.""" + # All authenticated users should be able to read candidates + for token in [self.admin_token, self.user_token, self.scientist_token]: + response = requests.get( + self.get_url("/candidate"), headers={"api_token": token} + ) + assert response.status_code == status.HTTP_200_OK + + def test_post_candidate_with_different_users(self): + """Test creating candidates as different users.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_UserSubmitted", + "ra": 200.456, + "dec": -75.789, + "discovery_date": "2019-04-26T00:00:00.000000", + "discovery_magnitude": 22.8, + "magnitude_unit": "ab_mag", + }, + } + + # Submit as regular user + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.user_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["candidate_ids"]) == 1 + + # Verify the submitter ID by getting the candidate back + candidate_id = data["candidate_ids"][0] + get_response = requests.get( + self.get_url("/candidate"), + params={"id": candidate_id}, + headers={"api_token": self.user_token}, + ) + assert get_response.status_code == status.HTTP_200_OK + candidate = get_response.json()[0] + assert candidate["submitterid"] == 2 # User token corresponds to user ID 2 + + +class TestCandidateAPIValidation: + """Test validation of candidate API endpoints.""" + + admin_token = "test_token_admin_001" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_invalid_magnitude_unit(self): + """Test creating candidate with invalid magnitude unit.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_InvalidUnit", + "ra": 123.456, + "dec": -12.345, + "discovery_date": "2019-04-25T12:00:00.000000", + "discovery_magnitude": 21.5, + "magnitude_unit": "invalid_unit", + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert "Invalid magnitude unit" in response.json()["errors"][0]["message"] + + def test_invalid_date_format(self): + """Test creating candidate with invalid date format.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_BadDate", + "ra": 123.456, + "dec": -12.345, + "discovery_date": "invalid-date-format", + "discovery_magnitude": 21.5, + "magnitude_unit": "ab_mag", + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert ( + "Invalid discovery_date format" in response.json()["errors"][0]["message"] + ) + + def test_missing_position_data(self): + """Test creating candidate without position or coordinates.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_NoPos", + # No position, ra, or dec + "discovery_date": "2019-04-25T12:00:00.000000", + "discovery_magnitude": 21.5, + "magnitude_unit": "ab_mag", + }, + } + + response = requests.post( + 
self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ( + "Either position or both ra and dec must be provided" + in response.json()["errors"][0]["message"] + ) + + +class TestCandidateAPIIntegration: + """Integration tests for candidate API workflows.""" + + admin_token = "test_token_admin_001" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_complete_candidate_workflow(self): + """Test complete workflow: create, read, update, delete candidate.""" + # Step 1: Create candidate + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_WorkflowTest", + "ra": 205.789, + "dec": -80.123, + "discovery_date": "2019-04-26T01:00:00.000000", + "discovery_magnitude": 21.9, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "r", + }, + } + + create_response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_id = create_response.json()["candidate_ids"][0] + + # Step 2: Read candidate + get_response = requests.get( + self.get_url("/candidate"), + params={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + assert get_response.status_code == status.HTTP_200_OK + candidate = get_response.json()[0] + assert candidate["candidate_name"] == "SN_WorkflowTest" + + # Step 3: Update candidate + update_data = { + "id": candidate_id, + "candidate": { + "discovery_magnitude": 21.5, + "tns_name": "2019workflow", + "associated_galaxy": "NGC_Workflow", + }, + } + update_response = requests.put( + self.get_url("/candidate"), + json=update_data, + headers={"api_token": self.admin_token}, + ) + assert update_response.status_code == status.HTTP_200_OK + assert update_response.json()["candidate"]["discovery_magnitude"] == 21.5 + + # Step 4: Delete candidate + delete_response = requests.delete( + self.get_url("/candidate"), + json={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + assert delete_response.status_code == status.HTTP_200_OK + assert candidate_id in delete_response.json()["deleted_ids"] + + # Step 5: Verify deletion + verify_response = requests.get( + self.get_url("/candidate"), + params={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + assert verify_response.status_code == status.HTTP_200_OK + assert len(verify_response.json()) == 0 # Should be empty + + def test_bulk_operations(self): + """Test bulk creation and deletion of candidates.""" + # Bulk create + candidates_data = { + "graceid": "S190425z", + "candidates": [ + { + "candidate_name": f"SN_Bulk_{i}", + "ra": 210.0 + i, + "dec": -85.0 - i, + "discovery_date": "2019-04-26T02:00:00.000000", + "discovery_magnitude": 22.0 + i * 0.1, + "magnitude_unit": "ab_mag", + } + for i in range(5) + ], + } + + create_response = requests.post( + self.get_url("/candidate"), + json=candidates_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + candidate_ids = create_response.json()["candidate_ids"] + assert len(candidate_ids) == 5 + + # Bulk delete + delete_response = requests.delete( + self.get_url("/candidate"), + json={"ids": candidate_ids}, + headers={"api_token": self.admin_token}, + ) + assert delete_response.status_code == status.HTTP_200_OK + assert 
len(delete_response.json()["deleted_ids"]) == 5 + + def test_candidate_with_all_optional_fields(self): + """Test creating candidate with all optional fields.""" + candidate_data = { + "graceid": "S190425z", + "candidate": { + "candidate_name": "SN_Complete", + "ra": 215.456, + "dec": 85.789, + "discovery_date": "2019-04-26T03:00:00.000000", + "discovery_magnitude": 20.5, + "magnitude_unit": "ab_mag", + "magnitude_bandpass": "V", + "magnitude_central_wave": 5500.0, + "magnitude_bandwidth": 1000.0, + "tns_name": "2019complete", + "tns_url": "https://www.wis-tns.org/object/2019complete", + "associated_galaxy": "NGC_Complete", + "associated_galaxy_redshift": 0.1, + "associated_galaxy_distance": 450.5, + }, + } + + response = requests.post( + self.get_url("/candidate"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["candidate_ids"]) == 1 + assert len(response.json().get("ERRORS", [])) == 0 + + # Verify all fields were set correctly + candidate_id = response.json()["candidate_ids"][0] + get_response = requests.get( + self.get_url("/candidate"), + params={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + candidate = get_response.json()[0] + assert candidate["tns_name"] == "2019complete" + assert candidate["associated_galaxy"] == "NGC_Complete" + assert candidate["associated_galaxy_redshift"] == 0.1 + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_doi.py b/tests/fastapi/test_doi.py new file mode 100644 index 00000000..4998e5c0 --- /dev/null +++ b/tests/fastapi/test_doi.py @@ -0,0 +1,462 @@ +import os +import requests +import datetime +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestDOIEndpoints: + """Test class for DOI-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + # Known GraceIDs from test data + KNOWN_GRACEIDS = ["S190425z", "S190426c", "MS230101a", "GW190521", "MS190425a"] + + def create_completed_pointing(self, graceid, token): + """Helper method to create a completed pointing that's eligible for DOI.""" + pointing_data = { + "graceid": graceid, + "pointing": { + "ra": 123.456 + + ( + datetime.datetime.now().microsecond / 1000000 + ), # Add some randomness + "dec": -12.345 + (datetime.datetime.now().microsecond / 1000000), + "instrumentid": 1, + "depth": 20.5, + "depth_unit": "ab_mag", + "time": datetime.datetime.now().isoformat(), + "status": "completed", # This is crucial - must be completed + "pos_angle": 0.0, + "band": "r", + }, + } + + response = requests.post( + self.get_url("/pointings"), json=pointing_data, headers={"api_token": token} + ) + + if response.status_code != status.HTTP_200_OK: + pytest.fail(f"Failed to create test pointing: {response.text}") + + return response.json()["pointing_ids"][0] + + def test_request_doi_with_single_id(self): + """Test requesting a DOI with a single pointing ID.""" + # First create a completed pointing + pointing_id = self.create_completed_pointing( + self.KNOWN_GRACEIDS[0], self.admin_token + ) + + # Now request a DOI for it + doi_data 
= { + "id": pointing_id, + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}], + } + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "DOI_URL" in data + + def test_request_doi_with_multiple_ids(self): + """Test requesting a DOI with multiple pointing IDs.""" + # Create multiple completed pointings + pointing_ids = [] + for _ in range(2): + pointing_id = self.create_completed_pointing( + self.KNOWN_GRACEIDS[0], self.admin_token + ) + pointing_ids.append(pointing_id) + + # Now request a DOI for them + doi_data = { + "ids": pointing_ids, + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}], + } + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "DOI_URL" in data + + def test_request_doi_with_graceid(self): + """Test requesting a DOI with a graceid.""" + graceid = self.KNOWN_GRACEIDS[0] + + # Create multiple completed pointings for the same graceid + for _ in range(2): + self.create_completed_pointing(graceid, self.admin_token) + + # Now request a DOI for all pointings with this graceid + doi_data = { + "graceid": graceid, + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}], + } + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "DOI_URL" in data + + def test_request_doi_with_existing_url(self): + """Test requesting a DOI with an existing DOI URL.""" + # First create a completed pointing + pointing_id = self.create_completed_pointing( + self.KNOWN_GRACEIDS[0], self.admin_token + ) + + # Now request a DOI with an existing URL + doi_data = { + "id": pointing_id, + "doi_url": "https://doi.org/10.5281/zenodo.example", + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}], + } + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "DOI_URL" in data + assert data["DOI_URL"] == "https://doi.org/10.5281/zenodo.example" + + def test_request_doi_with_doi_group_id(self): + """Test requesting a DOI with a DOI group ID.""" + # First create a completed pointing + pointing_id = self.create_completed_pointing( + self.KNOWN_GRACEIDS[0], self.admin_token + ) + + # Get DOI author groups for the user + response = requests.get( + self.get_url("/doi_author_groups"), headers={"api_token": self.admin_token} + ) + + if response.status_code != status.HTTP_200_OK or len(response.json()) == 0: + pytest.skip("No DOI author groups available for testing") + + group_id = response.json()[0]["id"] + + # Now request a DOI with the group ID + doi_data = {"id": pointing_id, "doi_group_id": group_id} + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + # If the group has valid authors, this should succeed + if 
response.status_code == status.HTTP_200_OK: + data = response.json() + assert "DOI_URL" in data + else: + # The group might not have valid authors in test data + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "validation error" in response.json()["message"] + + def test_request_doi_without_creators(self): + """Test requesting a DOI without specifying creators.""" + # First create a completed pointing + pointing_id = self.create_completed_pointing( + self.KNOWN_GRACEIDS[0], self.admin_token + ) + + # Now request a DOI without specifying creators + doi_data = {"id": pointing_id} + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "DOI_URL" in data + + def test_request_doi_with_invalid_creators(self): + """Test requesting a DOI with invalid creators.""" + # First create a completed pointing + pointing_id = self.create_completed_pointing( + self.KNOWN_GRACEIDS[0], self.admin_token + ) + + # Now request a DOI with invalid creators (missing affiliation) + doi_data = { + "id": pointing_id, + "creators": [ + { + "name": "Test Author" + # Missing affiliation + } + ], + } + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Request validation error" in response.json()["message"] + + def test_request_doi_with_insufficient_params(self): + """Test requesting a DOI with insufficient parameters.""" + # Request a DOI without any identifier + doi_data = { + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}] + } + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Request validation error" in response.json()["message"] + + def test_request_doi_for_others_pointings(self): + """Test that user can only request DOIs for their own pointings.""" + # First create a completed pointing as admin + pointing_id = self.create_completed_pointing( + self.KNOWN_GRACEIDS[0], self.admin_token + ) + + # Now try to request a DOI for it as a different user + doi_data = { + "id": pointing_id, + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}], + } + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.user_token}, # Different user + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "No valid pointings found for DOI request" in response.json()["message"] + + def test_request_doi_for_planned_pointing(self): + """Test that DOI cannot be requested for planned pointings.""" + # Create a planned pointing (not completed) + pointing_data = { + "graceid": self.KNOWN_GRACEIDS[0], + "pointing": { + "ra": 123.456, + "dec": -12.345, + "instrumentid": 1, + "depth": 20.5, + "depth_unit": "ab_mag", + "time": ( + datetime.datetime.now() + datetime.timedelta(days=1) + ).isoformat(), + "status": "planned", # Not completed + "band": "r", + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + pointing_id = response.json()["pointing_ids"][0] + + # Now try to request a DOI for the planned pointing + doi_data = { + "id": 
pointing_id, + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}], + } + + response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "No valid pointings found for DOI request" in response.json()["message"] + + def test_get_doi_pointings(self): + """Test getting all pointings with DOIs.""" + # First create a completed pointing + pointing_id = self.create_completed_pointing( + self.KNOWN_GRACEIDS[0], self.admin_token + ) + + # Request a DOI for the pointing (will likely fail in test environment, but endpoint should work) + doi_data = { + "id": pointing_id, + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}], + } + + # Request a DOI for the pointing + doi_response = requests.post( + self.get_url("/request_doi"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert doi_response.status_code == status.HTTP_200_OK + + # In test environment, DOI creation may fail, but endpoint should still work + doi_data_result = doi_response.json() + assert "DOI_URL" in doi_data_result + assert "WARNINGS" in doi_data_result + + # Now get all DOI pointings + response = requests.get( + self.get_url("/doi_pointings"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "pointings" in data + assert isinstance(data["pointings"], list) + + # In test environment, DOI creation may fail, so pointing may not have DOI + # Just verify the endpoint works and returns the expected structure + pointing_ids = [p["id"] for p in data["pointings"]] + + # If DOI was successfully created, the pointing should be in the list + # If not, that's also acceptable in test environment + if doi_data_result.get("DOI_URL"): + assert pointing_id in pointing_ids + else: + # DOI creation failed (expected in test), so pointing won't be in DOI list + # This is acceptable - we've verified the endpoints work correctly + pass + + def test_get_doi_author_groups(self): + """Test getting DOI author groups.""" + response = requests.get( + self.get_url("/doi_author_groups"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + # Each group should have an id and name + for group in data: + assert "id" in group + assert "name" in group + + def test_get_doi_authors(self): + """Test getting DOI authors for a group.""" + # First get available groups + groups_response = requests.get( + self.get_url("/doi_author_groups"), headers={"api_token": self.admin_token} + ) + + if ( + groups_response.status_code != status.HTTP_200_OK + or len(groups_response.json()) == 0 + ): + pytest.skip("No DOI author groups available for testing") + + group_id = groups_response.json()[0]["id"] + + # Now get authors for this group + response = requests.get( + self.get_url(f"/doi_authors/{group_id}"), + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + # Each author should have name and affiliation + for author in data: + assert "name" in author + assert "affiliation" in author + assert "author_groupid" in author + assert author["author_groupid"] == group_id + + def test_get_doi_authors_for_others_group(self): + """Test that user can only access their own DOI author groups.""" + # First get 
available groups for admin + groups_response = requests.get( + self.get_url("/doi_author_groups"), headers={"api_token": self.admin_token} + ) + + if ( + groups_response.status_code != status.HTTP_200_OK + or len(groups_response.json()) == 0 + ): + pytest.skip("No DOI author groups available for testing") + + group_id = groups_response.json()[0]["id"] + + # Now try to get authors for this group as a different user + response = requests.get( + self.get_url(f"/doi_authors/{group_id}"), + headers={"api_token": self.user_token}, # Different user + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "You don't have permission" in response.json()["message"] + + def test_authentication_required(self): + """Test that authentication is required for all endpoints.""" + endpoints = ["/doi_pointings", "/doi_author_groups"] + + for endpoint in endpoints: + response = requests.get(self.get_url(endpoint)) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + # Test POST endpoint + response = requests.post(self.get_url("/request_doi"), json={"id": 1}) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_event.py b/tests/fastapi/test_event.py new file mode 100644 index 00000000..daa0dcb9 --- /dev/null +++ b/tests/fastapi/test_event.py @@ -0,0 +1,502 @@ +""" +Test event endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. +""" + +import os +import requests +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestEventEndpoints: + """Test class for event-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + # Known GraceIDs from test data + KNOWN_GRACEIDS = ["S190425z", "S190426c", "MS230101a", "GW190521", "MS190425a"] + + def test_get_candidate_events_no_params(self): + """Test getting candidate events without any parameters.""" + response = requests.get( + self.get_url("/candidate/event"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should return candidates that exist in test data + for candidate in data: + assert "id" in candidate + assert "candidate_name" in candidate + + def test_get_candidate_events_by_user_id(self): + """Test getting candidate events filtered by user ID.""" + response = requests.get( + self.get_url("/candidate/event"), + params={"user_id": 1}, # Admin user ID + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # All returned candidates should be submitted by user 1 + for candidate in data: + assert candidate["submitterid"] == 1 + + def test_post_candidate_event(self): + """Test creating a new candidate event.""" + candidate_data = { + "graceid": "S190425z", # Using a known GraceID + "candidate_name": "Test Candidate Event", + "ra": 123.456, + "dec": -12.345, + } + + response = requests.post( + self.get_url("/candidate/event"), + json=candidate_data, + 
headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "message" in data + assert "id" in data + assert "Candidate created successfully" in data["message"] + assert isinstance(data["id"], int) + + # Store the candidate ID for later tests + self.candidate_id = data["id"] + + def test_update_candidate_event(self): + """Test updating an existing candidate event.""" + # First create a candidate to update + if not hasattr(self, "candidate_id"): + self.test_post_candidate_event() + + update_data = { + "graceid": "S190425z", + "candidate_name": "Updated Candidate Event", + "ra": 124.567, + "dec": -13.456, + } + + response = requests.put( + self.get_url(f"/candidate/event/{self.candidate_id}"), + json=update_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "message" in data + assert "Candidate updated successfully" in data["message"] + + # Verify the update worked + response = requests.get( + self.get_url("/candidate/event"), + params={"id": self.candidate_id}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 1 + assert data[0]["candidate_name"] == "Updated Candidate Event" + # ra and dec may be in position property, we validate them separately if available + if "ra" in data[0]: + assert abs(data[0]["ra"] - 124.567) < 0.001 + if "dec" in data[0]: + assert abs(data[0]["dec"] - (-13.456)) < 0.001 + + def test_delete_candidate_event(self): + """Test deleting a candidate event.""" + # First create a candidate to delete + if not hasattr(self, "candidate_id"): + self.test_post_candidate_event() + + response = requests.delete( + self.get_url(f"/candidate/event/{self.candidate_id}"), + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "message" in data + assert "Candidate deleted successfully" in data["message"] + + # Verify the candidate was deleted + response = requests.get( + self.get_url("/candidate/event"), + params={"id": self.candidate_id}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 0 # Should be empty + + def test_post_candidate_event_as_different_user(self): + """Test creating a candidate event as a different user.""" + candidate_data = { + "graceid": "S190425z", # Required field + "candidate_name": "User Test Candidate", + "ra": 130.456, + "dec": -15.345, + } + + response = requests.post( + self.get_url("/candidate/event"), + json=candidate_data, + headers={"api_token": self.user_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "message" in data + assert "id" in data + + # Store this candidate ID + self.user_candidate_id = data["id"] + + # Verify the user ID is correct + response = requests.get( + self.get_url("/candidate/event"), + params={"id": self.user_candidate_id}, + headers={"api_token": self.user_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 1 + assert data[0]["submitterid"] == 2 # User token corresponds to user ID 2 + + def test_update_candidate_event_unauthorized(self): + """Test updating a candidate event created by a different user.""" + # First create a candidate as admin + candidate_data = { + "graceid": "S190425z", # 
Required field + "candidate_name": "Admin Candidate", + "ra": 140.456, + "dec": -20.345, + } + + response = requests.post( + self.get_url("/candidate/event"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + admin_candidate_id = response.json()["id"] + + # Try to update as regular user + update_data = { + "graceid": "S190425z", # Required field + "candidate_name": "Hijacked Candidate", + } + + response = requests.put( + self.get_url(f"/candidate/event/{admin_candidate_id}"), + json=update_data, + headers={"api_token": self.user_token}, + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "Not authorized" in response.json()["message"] + + def test_delete_candidate_event_unauthorized(self): + """Test deleting a candidate event created by a different user.""" + # First create a candidate as admin + candidate_data = { + "graceid": "S190425z", # Required field + "candidate_name": "Admin Protected Candidate", + "ra": 150.456, + "dec": -25.345, + } + + response = requests.post( + self.get_url("/candidate/event"), + json=candidate_data, + headers={"api_token": self.admin_token}, + ) + + admin_candidate_id = response.json()["id"] + + # Try to delete as regular user + response = requests.delete( + self.get_url(f"/candidate/event/{admin_candidate_id}"), + headers={"api_token": self.user_token}, + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "Not authorized" in response.json()["message"] + + def test_candidate_event_unauthorized_access(self): + """Test that unauthorized requests are rejected.""" + # Request without API token + response = requests.get(self.get_url("/candidate/event")) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + # Request with invalid API token + response = requests.get( + self.get_url("/candidate/event"), headers={"api_token": self.invalid_token} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_candidate_events_with_different_tokens(self): + """Test access with different valid API tokens.""" + # All authenticated users should be able to read candidates + for token in [self.admin_token, self.user_token, self.scientist_token]: + response = requests.get( + self.get_url("/candidate/event"), headers={"api_token": token} + ) + assert response.status_code == status.HTTP_200_OK + + def test_post_candidate_event_missing_required_fields(self): + """Test creating a candidate event with missing required fields.""" + incomplete_data = { + "candidate_name": "Incomplete Candidate", + # Missing graceid, ra, dec + } + + response = requests.post( + self.get_url("/candidate/event"), + json=incomplete_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST # Validation error + assert "missing" in response.json()["errors"][0]["params"]["type"] + assert "graceid" in response.json()["errors"][0]["params"]["field"] + + def test_update_candidate_event_nonexistent(self): + """Test updating a non-existent candidate event.""" + update_data = { + "graceid": "S190425z", # Required field + "candidate_name": "NonExistent Candidate", + } + + response = requests.put( + self.get_url("/candidate/event/99999"), # Non-existent ID + json=update_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Candidate not found" in response.json()["message"] + + def test_delete_candidate_event_nonexistent(self): + """Test deleting a non-existent candidate event.""" + response = 
requests.delete( + self.get_url("/candidate/event/99999"), # Non-existent ID + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Candidate not found" in response.json()["message"] + + +class TestEventAPIValidation: + """Test validation of event API endpoints.""" + + admin_token = "test_token_admin_001" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_invalid_ra_dec(self): + """Test creating candidate event with invalid ra/dec values.""" + invalid_data = { + "graceid": "S190425z", # Required field + "candidate_name": "Invalid Coordinates", + "ra": "not-a-number", + "dec": "also-not-a-number", + } + + response = requests.post( + self.get_url("/candidate/event"), + json=invalid_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST # Validation error + errors = response.json()["errors"] + assert any("ra" in str(e) for e in errors) + assert any("dec" in str(e) for e in errors) + + def test_out_of_range_coordinates(self): + """Test creating candidate event with out-of-range coordinates.""" + invalid_data = { + "graceid": "S190425z", # Required field + "candidate_name": "Out of Range Coordinates", + "ra": 400.0, # RA should be 0-360 + "dec": -100.0, # Dec should be -90 to +90 + } + + response = requests.post( + self.get_url("/candidate/event"), + json=invalid_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST # Validation error + errors = response.json()["errors"] + assert any("ra" in str(e) for e in errors) or any( + "dec" in str(e) for e in errors + ) + + +class TestEventAPIIntegration: + """Integration tests for event API workflows.""" + + admin_token = "test_token_admin_001" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_complete_event_workflow(self): + """Test complete workflow: create, read, update, delete candidate event.""" + # Step 1: Create candidate event + create_data = { + "graceid": "S190425z", # Required field + "candidate_name": "Workflow Test Candidate", + "ra": 160.0, + "dec": -30.0, + } + + create_response = requests.post( + self.get_url("/candidate/event"), + json=create_data, + headers={"api_token": self.admin_token}, + ) + + assert create_response.status_code == status.HTTP_200_OK + candidate_id = create_response.json()["id"] + + # Step 2: Read the candidate + read_response = requests.get( + self.get_url("/candidate/event"), + params={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + + assert read_response.status_code == status.HTTP_200_OK + candidate_data = read_response.json()[0] + assert candidate_data["candidate_name"] == "Workflow Test Candidate" + if "ra" in candidate_data: + assert candidate_data["ra"] == 160.0 + if "dec" in candidate_data: + assert candidate_data["dec"] == -30.0 + + # Step 3: Update the candidate + update_data = { + "graceid": "S190425z", # Required field + "candidate_name": "Updated Workflow Candidate", + } + + update_response = requests.put( + self.get_url(f"/candidate/event/{candidate_id}"), + json=update_data, + headers={"api_token": self.admin_token}, + ) + + assert update_response.status_code == status.HTTP_200_OK + assert "Candidate updated successfully" in update_response.json()["message"] + + # Verify the update + verify_response = requests.get( + 
self.get_url("/candidate/event"), + params={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + + updated_data = verify_response.json()[0] + assert updated_data["candidate_name"] == "Updated Workflow Candidate" + # Original fields should be unchanged + if "ra" in updated_data: + assert updated_data["ra"] == 160.0 + if "dec" in updated_data: + assert updated_data["dec"] == -30.0 + + # Step 4: Delete the candidate + delete_response = requests.delete( + self.get_url(f"/candidate/event/{candidate_id}"), + headers={"api_token": self.admin_token}, + ) + + assert delete_response.status_code == status.HTTP_200_OK + assert "Candidate deleted successfully" in delete_response.json()["message"] + + # Verify deletion + verify_deletion = requests.get( + self.get_url("/candidate/event"), + params={"id": candidate_id}, + headers={"api_token": self.admin_token}, + ) + + assert verify_deletion.status_code == status.HTTP_200_OK + assert len(verify_deletion.json()) == 0 # Should be empty + + def test_multiple_candidate_events_for_same_user(self): + """Test creating multiple candidate events for the same user.""" + # Create several candidates + candidate_ids = [] + for i in range(3): + data = { + "graceid": "S190425z", # Required field + "candidate_name": f"Multi Test Candidate {i}", + "ra": 170.0 + i, + "dec": -40.0 - i, + } + + response = requests.post( + self.get_url("/candidate/event"), + json=data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + candidate_ids.append(response.json()["id"]) + + # Verify all candidates were created for the admin user + response = requests.get( + self.get_url("/candidate/event"), + params={"user_id": 1}, # Admin user ID + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + candidates = response.json() + + # Check if all our created candidates are in the response + created_candidates = [c for c in candidates if c["id"] in candidate_ids] + assert len(created_candidates) == 3 + + # Clean up + for candidate_id in candidate_ids: + requests.delete( + self.get_url(f"/candidate/event/{candidate_id}"), + headers={"api_token": self.admin_token}, + ) + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_gw_alert.py b/tests/fastapi/test_gw_alert.py new file mode 100644 index 00000000..0a098e2d --- /dev/null +++ b/tests/fastapi/test_gw_alert.py @@ -0,0 +1,322 @@ +""" +Test GW alert endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. 
+""" + +import os +import requests +import json +import datetime +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestGWAlertEndpoints: + """Test class for GW alert-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + # Known GraceIDs from test data + KNOWN_GRACEIDS = ["S190425z", "S190426c", "GW190521"] + + def test_query_alerts_no_params(self): + """Test querying alerts without any parameters.""" + response = requests.get( + self.get_url("/query_alerts"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should return all alerts in test data + assert len(data) >= 4 # At least our known graceids + + def test_query_alerts_by_graceid(self): + """Test querying alerts by graceid.""" + for graceid in self.KNOWN_GRACEIDS: + response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) > 0 + # All returned alerts should have the specified graceid + for alert in data: + assert alert["graceid"] == graceid + + def test_query_alerts_by_alert_type(self): + """Test querying alerts by alert type.""" + # Test for common alert types + alert_types = ["Initial", "Update", "Retraction"] + + for alert_type in alert_types: + response = requests.get( + self.get_url("/query_alerts"), + params={"alert_type": alert_type}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + # Skip if no alerts of this type in test data + if len(data) > 0: + # All returned alerts should have the specified alert_type + for alert in data: + assert alert["alert_type"] == alert_type + + def test_query_alerts_graceid_and_alert_type(self): + """Test querying alerts by both graceid and alert type.""" + response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": "S190425z", "alert_type": "Initial"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + # Skip if no matching alerts in test data + if len(data) > 0: + # All returned alerts should match both parameters + for alert in data: + assert alert["graceid"] == "S190425z" + assert alert["alert_type"] == "Initial" + + def test_query_alerts_without_auth(self): + """Test that authentication is required.""" + response = requests.get(self.get_url("/query_alerts")) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_query_alerts_with_invalid_token(self): + """Test with invalid API token.""" + response = requests.get( + self.get_url("/query_alerts"), headers={"api_token": self.invalid_token} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_post_alert_as_admin(self): + """Test posting a new GW alert as admin.""" + alert_data = { + "graceid": 
f"TEST{datetime.datetime.now().strftime('%y%m%d%H%M%S')}", + "alternateid": "", + "role": "test", + "observing_run": "O4", + "description": "Test alert creation", + "alert_type": "Initial", + "far": 1e-9, + "group": "CBC", + "detectors": "H1,L1", + "prob_hasns": 0.95, + "prob_hasremenant": 0.9, + "prob_bns": 0.8, + "prob_nsbh": 0.1, + "prob_bbh": 0.05, + "prob_terrestrial": 0.05, + "skymap_fits_url": "https://example.com/skymap.fits", + "avgra": 123.456, + "avgdec": -12.345, + "time_of_signal": datetime.datetime.now().isoformat(), + "distance": 100.0, + "distance_error": 10.0, + } + + response = requests.post( + self.get_url("/post_alert"), + json=alert_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "graceid" in data + assert data["graceid"] == alert_data["graceid"] + assert data["role"] == "test" + assert data["alert_type"] == "Initial" + + def test_post_alert_as_non_admin(self): + """Test that only admin can post alerts.""" + alert_data = { + "graceid": f"TEST{datetime.datetime.now().strftime('%y%m%d%H%M%S')}", + "role": "test", + "alert_type": "Initial", + } + + response = requests.post( + self.get_url("/post_alert"), + json=alert_data, + headers={"api_token": self.user_token}, # Non-admin user + ) + + # Should fail with 403 Forbidden + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_get_gw_skymap(self): + """Test getting a GW skymap FITS file.""" + for graceid in self.KNOWN_GRACEIDS: + response = requests.get( + self.get_url("/gw_skymap"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + # If skymap exists, should return 200, otherwise 404 + if response.status_code == status.HTTP_200_OK: + # Should return binary data with FITS header + assert response.headers["Content-Type"] == "application/fits" + assert response.headers["Content-Disposition"].startswith( + "attachment; filename=" + ) + assert len(response.content) > 0 + break # Found a valid skymap, no need to try others + else: + # Graceid might not have a skymap in test data + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Error in retrieving skymap file" in response.json()["message"] + + def test_get_gw_contour(self): + """Test getting GW contour data.""" + for graceid in self.KNOWN_GRACEIDS: + response = requests.get( + self.get_url("/gw_contour"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + # If contour exists, should return 200, otherwise 404 + if response.status_code == status.HTTP_200_OK: + # Should return JSON data + assert response.headers["Content-Type"] == "application/json" + # Try to parse as JSON to confirm it's valid + try: + json_data = response.json() + assert isinstance(json_data, dict) + except json.JSONDecodeError: + # If it's not valid JSON, the test should fail + assert False, "Response is not valid JSON" + break # Found a valid contour, no need to try others + else: + # Graceid might not have a contour in test data + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Error in retrieving Contour file" in response.json()["message"] + + def test_get_grb_moc_file(self): + """Test getting a GRB MOC file.""" + instruments = ["gbm", "lat", "bat"] + + for graceid in self.KNOWN_GRACEIDS: + for instrument in instruments: + response = requests.get( + self.get_url("/grb_moc_file"), + params={"graceid": graceid, "instrument": instrument}, + headers={"api_token": self.admin_token}, + ) + + # If MOC file 
exists, should return 200, otherwise 404 + if response.status_code == status.HTTP_200_OK: + # Should return JSON data + assert response.headers["Content-Type"] == "application/json" + # Try to parse as JSON to confirm it's valid + try: + json_data = response.json() + assert isinstance(json_data, dict) + except json.JSONDecodeError: + # If it's not valid JSON, the test should fail + assert False, "Response is not valid JSON" + return # Found a valid MOC file, test is complete + + # If we get here, no MOC files were found for any graceid/instrument combination + # This is expected in test data, so we'll skip this test + pytest.skip("No GRB MOC files found in test data") + + def test_get_grb_moc_file_invalid_instrument(self): + """Test getting a GRB MOC file with invalid instrument.""" + response = requests.get( + self.get_url("/grb_moc_file"), + params={"graceid": "S190425z", "instrument": "invalid"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ( + "Valid instruments are in ['gbm', 'lat', 'bat']" + in response.json()["message"] + ) + + def test_del_test_alerts_as_non_admin(self): + """Test that only admin can delete test alerts.""" + response = requests.post( + self.get_url("/del_test_alerts"), + headers={"api_token": self.user_token}, # Non-admin user + ) + + # Should fail with 403 Forbidden + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_alerts_with_different_tokens(self): + """Test alert queries with different valid API tokens.""" + # All authenticated users should be able to query alerts + for token in [self.admin_token, self.user_token, self.scientist_token]: + response = requests.get( + self.get_url("/query_alerts"), headers={"api_token": token} + ) + assert response.status_code == status.HTTP_200_OK + + def test_alert_data_format(self): + """Test that alert data is returned in the correct format.""" + response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": "S190425z"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + alerts = response.json() + + if len(alerts) == 0: + pytest.skip("No alerts found for S190425z in test data") + + alert = alerts[0] + + # Check required fields + required_fields = ["id", "graceid", "alert_type", "datecreated", "role"] + for field in required_fields: + assert field in alert + + # Check data types + assert isinstance(alert["id"], int) + assert isinstance(alert["graceid"], str) + assert isinstance(alert["alert_type"], str) + + # Make sure time fields are parseable as ISO 8601 + for time_field in ["datecreated", "time_of_signal", "timesent"]: + if time_field in alert and alert[time_field]: + try: + datetime.datetime.fromisoformat( + alert[time_field].replace("Z", "+00:00") + ) + except ValueError: + assert False, f"Time field {time_field} is not in ISO 8601 format" + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_gw_galaxy.py b/tests/fastapi/test_gw_galaxy.py new file mode 100644 index 00000000..b13f493c --- /dev/null +++ b/tests/fastapi/test_gw_galaxy.py @@ -0,0 +1,428 @@ +""" +Test GW galaxy endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. 
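+These tests run over HTTP against a live FastAPI server; set API_BASE_URL to point at it (default http://localhost:8000).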
+""" + +import os +import requests +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestGWGalaxyEndpoints: + """Test class for GW galaxy-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + # Known GraceIDs from test data + KNOWN_GRACEIDS = ["S190425z", "S190426c", "MS230101a", "GW190521", "MS190425a"] + + # Test timestamps (based on test data) + test_timestamp = "2019-04-25T08:18:05.123456" + + def test_get_event_galaxies_no_params(self): + """Test getting event galaxies without proper parameters.""" + response = requests.get( + self.get_url("/event_galaxies"), headers={"api_token": self.admin_token} + ) + + # Should fail with 422 - Missing required parameter + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_get_event_galaxies_with_graceid(self): + """Test getting event galaxies with valid graceid.""" + for graceid in self.KNOWN_GRACEIDS: + response = requests.get( + self.get_url("/event_galaxies"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Note: Test may pass even if empty list - depends on test data + + def test_get_event_galaxies_with_invalid_timesent(self): + """Test getting event galaxies with invalid timestamp.""" + response = requests.get( + self.get_url("/event_galaxies"), + params={ + "graceid": self.KNOWN_GRACEIDS[0], + "timesent_stamp": "invalid-timestamp", + }, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Error parsing date" in response.json()["message"] + + def test_get_event_galaxies_with_nonexistent_timesent(self): + """Test getting event galaxies with timestamp that doesn't match any alert.""" + response = requests.get( + self.get_url("/event_galaxies"), + params={ + "graceid": self.KNOWN_GRACEIDS[0], + "timesent_stamp": "2099-01-01T12:00:00.000000", # Future date + }, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Invalid 'timesent_stamp' for event" in response.json()["message"] + + def test_get_event_galaxies_with_score_filters(self): + """Test getting event galaxies with score filters.""" + graceid = self.KNOWN_GRACEIDS[0] + + # First try to post some galaxies to ensure test data + self.post_test_galaxy_data(graceid) + + # Now query with score filters + response = requests.get( + self.get_url("/event_galaxies"), + params={"graceid": graceid, "score_gt": 0.5, "score_lt": 1.0}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + # Check that all returned galaxies have scores in the specified range + for galaxy in data: + if "score" in galaxy: + assert 0.5 <= galaxy["score"] <= 1.0 + + def test_post_event_galaxies(self): + """Test posting event galaxies.""" + graceid = self.KNOWN_GRACEIDS[0] + response = self.post_test_galaxy_data(graceid) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "message" in 
data + assert "Successful adding of" in data["message"] + assert graceid in data["message"] + assert "List ID:" in data["message"] + assert len(data["errors"]) == 0 + assert len(data["warnings"]) == 0 + + def test_post_event_galaxies_with_doi(self): + """Test posting event galaxies with DOI request.""" + graceid = self.KNOWN_GRACEIDS[0] + + # Get timestamp from a valid alert + alert_response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + if alert_response.status_code != 200 or len(alert_response.json()) == 0: + pytest.skip(f"No alerts found for {graceid} in test data") + + alert = alert_response.json()[0] + timesent_stamp = alert.get("timesent") + + if not timesent_stamp: + pytest.skip( + f"Alert for {graceid} does not have timesent field in test data" + ) + + # Create galaxy data with DOI request + galaxy_data = { + "graceid": graceid, + "timesent_stamp": timesent_stamp, + "groupname": "Test Group", + "reference": "Test Reference", + "request_doi": True, + "creators": [{"name": "Test Author", "affiliation": "Test University"}], + "galaxies": [ + { + "name": "NGC 123", + "ra": 123.456, + "dec": -12.345, + "score": 0.9, + "rank": 1, + "info": {"redshift": 0.01}, + }, + { + "name": "NGC 456", + "position": "POINT(45.678 -67.890)", + "score": 0.8, + "rank": 2, + "info": {"redshift": 0.02}, + }, + ], + } + + response = requests.post( + self.get_url("/event_galaxies"), + json=galaxy_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "message" in data + assert "Successful adding of" in data["message"] + # Note: DOI creation might fail in test environment + + def test_post_event_galaxies_with_invalid_data(self): + """Test posting event galaxies with invalid data.""" + graceid = self.KNOWN_GRACEIDS[0] + + # Get timestamp from a valid alert + alert_response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + if alert_response.status_code != 200 or len(alert_response.json()) == 0: + pytest.skip(f"No alerts found for {graceid} in test data") + + alert = alert_response.json()[0] + timesent_stamp = alert.get("timesent") + + if not timesent_stamp: + pytest.skip( + f"Alert for {graceid} does not have timesent field in test data" + ) + + # Create galaxy data with invalid galaxy position + galaxy_data = { + "graceid": graceid, + "timesent_stamp": timesent_stamp, + "galaxies": [ + { + "name": "Invalid Galaxy", + # Missing both position and ra/dec + "score": 0.9, + "rank": 1, + } + ], + } + + response = requests.post( + self.get_url("/event_galaxies"), + json=galaxy_data, + headers={"api_token": self.admin_token}, + ) + + # Should return 200 but with errors + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_remove_event_galaxies(self): + """Test removing event galaxies.""" + # First post some galaxies to get a list ID + graceid = self.KNOWN_GRACEIDS[0] + post_response = self.post_test_galaxy_data(graceid) + + if post_response.status_code != 200: + pytest.skip("Failed to post test galaxy data") + + post_data = post_response.json() + list_id = int(post_data["message"].split("List ID: ")[1].strip()) + + # Now try to delete them + response = requests.delete( + self.get_url("/remove_event_galaxies"), + params={"listid": list_id}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = 
response.json() + assert "message" in data + assert "Successfully deleted your galaxy list" in data["message"] + + # Verify deletion by trying to get the galaxies + get_response = requests.get( + self.get_url("/event_galaxies"), + params={"graceid": graceid, "listid": list_id}, + headers={"api_token": self.admin_token}, + ) + + assert get_response.status_code == status.HTTP_200_OK + data = get_response.json() + assert len(data) == 0 # Should be empty after deletion + + def test_remove_event_galaxies_unauthorized(self): + """Test that user can only remove their own galaxy lists.""" + # First post some galaxies with admin token + graceid = self.KNOWN_GRACEIDS[0] + post_response = self.post_test_galaxy_data(graceid) + + if post_response.status_code != 200: + pytest.skip("Failed to post test galaxy data") + + post_data = post_response.json() + list_id = int(post_data["message"].split("List ID: ")[1].strip()) + + # Now try to delete them with a different user token + response = requests.delete( + self.get_url("/remove_event_galaxies"), + params={"listid": list_id}, + headers={"api_token": self.user_token}, # Different user + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert ( + "You can only delete information related to your API token" + in response.json()["message"] + ) + + def test_remove_nonexistent_event_galaxies(self): + """Test trying to remove nonexistent event galaxies.""" + response = requests.delete( + self.get_url("/remove_event_galaxies"), + params={"listid": 99999}, # Nonexistent ID + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "No galaxies found" in response.json()["message"] + + def test_get_glade_galaxies_no_params(self): + """Test getting GLADE galaxies without parameters.""" + response = requests.get( + self.get_url("/glade"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should return some galaxies from GLADE catalog + assert len(data) > 0 + + def test_get_glade_galaxies_by_position(self): + """Test getting GLADE galaxies near a position.""" + response = requests.get( + self.get_url("/glade"), + params={"ra": 123.456, "dec": -12.345}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should return galaxies sorted by distance from the specified position + + def test_get_glade_galaxies_by_name(self): + """Test getting GLADE galaxies by name.""" + response = requests.get( + self.get_url("/glade"), + params={"name": "NGC"}, # Common prefix for galaxy names + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + # Should return galaxies with names containing "NGC" + if len(data) > 0: + for galaxy in data: + # Check if any of the name fields contain "NGC" + has_name_match = False + for name_field in [ + "_2mass_name", + "gwgc_name", + "hyperleda_name", + "sdssdr12_name", + ]: + if ( + name_field in galaxy + and galaxy[name_field] + and "NGC" in galaxy[name_field] + ): + has_name_match = True + break + assert has_name_match + + def test_authentication_required(self): + """Test that authentication is required for all endpoints.""" + endpoints = ["/event_galaxies?graceid=S190425z", "/glade"] + + for endpoint in endpoints: + response = 
requests.get(self.get_url(endpoint)) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def post_test_galaxy_data(self, graceid): + """Helper method to post test galaxy data.""" + # Get timestamp from a valid alert + alert_response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + if alert_response.status_code != 200 or len(alert_response.json()) == 0: + pytest.skip(f"No alerts found for {graceid} in test data") + + alert = alert_response.json()[0] + timesent_stamp = alert.get("timesent") + + if not timesent_stamp: + pytest.skip( + f"Alert for {graceid} does not have timesent field in test data" + ) + + # Create test galaxy data + galaxy_data = { + "graceid": graceid, + "timesent_stamp": timesent_stamp, + "groupname": "Test Group", + "reference": "Test Reference", + "galaxies": [ + { + "name": "Test Galaxy 1", + "ra": 123.456, + "dec": -12.345, + "score": 0.9, + "rank": 1, + "info": {"redshift": 0.01}, + }, + { + "name": "Test Galaxy 2", + "position": "POINT(45.678 -67.890)", + "score": 0.8, + "rank": 2, + "info": {"redshift": 0.02}, + }, + { + "name": "Test Galaxy 3", + "ra": 200.123, + "dec": 30.456, + "score": 0.7, + "rank": 3, + "info": {"redshift": 0.03}, + }, + ], + } + + print(f"Galaxy data for {graceid}: {galaxy_data} - POSTing...") + return requests.post( + self.get_url("/event_galaxies"), + json=galaxy_data, + headers={"api_token": self.admin_token}, + ) + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_health.py b/tests/fastapi/test_health.py new file mode 100644 index 00000000..2ea27ae6 --- /dev/null +++ b/tests/fastapi/test_health.py @@ -0,0 +1,82 @@ +""" +Test health check endpoints with real requests to the FastAPI application. 
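+The health endpoints are unauthenticated; the tests only need the server reachable at API_BASE_URL (default http://localhost:8000).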
+""" + +import os +import requests +import datetime +import pytest +from fastapi import status + + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") + + +class TestHealthEndpoints: + """Test class for health check API endpoints.""" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{endpoint}" + + def test_health_endpoint(self): + """Test the basic health check endpoint.""" + response = requests.get(self.get_url("/health")) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "status" in data + assert data["status"] == "ok" + assert "time" in data + + # Verify time is recent (within last minute) + time_str = data["time"] + time = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00")) + now = datetime.datetime.now(datetime.timezone.utc) + difference = now - time.astimezone(datetime.timezone.utc) + assert difference.total_seconds() < 60 # Within the last minute + + def test_service_status_endpoint(self): + """Test the detailed service status endpoint.""" + response = requests.get(self.get_url("/service-status")) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + + # Check that the response has the expected structure + assert "database_status" in data + assert "redis_status" in data + assert "details" in data + assert "database" in data["details"] + assert "redis" in data["details"] + + # Check database details + database = data["details"]["database"] + assert "host" in database + assert "port" in database + assert "name" in database + + # Check redis details + redis = data["details"]["redis"] + assert "host" in redis + assert "port" in redis + assert "url" in redis + + # If database connection is successful, status should be "connected" + # If not, there should be an error message + if data["database_status"] == "connected": + assert data["database_status"] == "connected" + else: + assert "error" in database + + # Similarly for Redis + if data["redis_status"] == "connected": + assert data["redis_status"] == "connected" + else: + assert "error" in redis + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_icecube.py b/tests/fastapi/test_icecube.py new file mode 100644 index 00000000..cc4f9921 --- /dev/null +++ b/tests/fastapi/test_icecube.py @@ -0,0 +1,241 @@ +""" +Test IceCube endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. 
+""" + +import os +import requests +import datetime +import uuid +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestIceCubeEndpoints: + """Test class for IceCube-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_post_icecube_notice_as_admin(self): + """Test posting an IceCube notice as admin.""" + # Generate a unique reference ID + ref_id = f"IceCube-{uuid.uuid4()}" + notice_data = { + "ref_id": ref_id, + "graceid": "S190425z", # Use a known GraceID from test data + "alert_datetime": datetime.datetime.now().isoformat(), + "observation_start": ( + datetime.datetime.now() - datetime.timedelta(hours=1) + ).isoformat(), + "observation_stop": datetime.datetime.now().isoformat(), + "pval_generic": 0.01, + "pval_bayesian": 0.02, + "most_probable_direction_ra": 123.456, + "most_probable_direction_dec": -12.345, + "flux_sens_low": 1e-10, + "flux_sens_high": 1e-9, + "sens_energy_range_low": 100, + "sens_energy_range_high": 1000, + } + + events_data = [ + { + "event_dt": 0.5, + "ra": 123.456, + "dec": -12.345, + "containment_probability": 0.9, + "event_pval_generic": 0.015, + "event_pval_bayesian": 0.025, + "ra_uncertainty": 0.5, + "uncertainty_shape": "circle", + }, + { + "event_dt": 1.0, + "ra": 124.567, + "dec": -13.456, + "containment_probability": 0.85, + "event_pval_generic": 0.02, + "event_pval_bayesian": 0.03, + "ra_uncertainty": 0.6, + "uncertainty_shape": "circle", + }, + ] + + data = {"notice_data": notice_data, "events_data": events_data} + + response = requests.post( + self.get_url("/post_icecube_notice"), + json=data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert "icecube_notice" in result + assert "icecube_notice_events" in result + assert result["icecube_notice"]["ref_id"] == ref_id + assert len(result["icecube_notice_events"]) == 2 + + def test_post_icecube_notice_as_non_admin(self): + """Test that only admin can post IceCube notices.""" + ref_id = f"IceCube-{uuid.uuid4()}" + + notice_data = { + "ref_id": ref_id, + "graceid": "S190425z", + "alert_datetime": datetime.datetime.now().isoformat(), + } + + events_data = [{"event_dt": 0.5, "ra": 123.456, "dec": -12.345}] + + data = {"notice_data": notice_data, "events_data": events_data} + + response = requests.post( + self.get_url("/post_icecube_notice"), + json=data, + headers={"api_token": self.user_token}, # Non-admin user + ) + + # Should fail with 403 Forbidden + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_post_duplicate_icecube_notice(self): + """Test posting a duplicate IceCube notice.""" + # First post a notice + ref_id = f"IceCube-{uuid.uuid4()}" + + notice_data = { + "ref_id": ref_id, + "graceid": "S190425z", + "alert_datetime": datetime.datetime.now().isoformat(), + } + + events_data = [{"event_dt": 0.5, "ra": 123.456, "dec": -12.345}] + + data = {"notice_data": notice_data, "events_data": events_data} + + response = requests.post( + self.get_url("/post_icecube_notice"), + json=data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + + # Now 
post again with the same ref_id + response = requests.post( + self.get_url("/post_icecube_notice"), + json=data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert "event already exists" in result["icecube_notice"]["message"] + + def test_post_icecube_notice_invalid_graceid(self): + """Test posting an IceCube notice with an invalid GraceID.""" + ref_id = f"IceCube-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + + notice_data = { + "ref_id": ref_id, + "graceid": "INVALID123", # Invalid GraceID + "alert_datetime": datetime.datetime.now().isoformat(), + } + + events_data = [{"event_dt": 0.5, "ra": 123.456, "dec": -12.345}] + + data = {"notice_data": notice_data, "events_data": events_data} + + response = requests.post( + self.get_url("/post_icecube_notice"), + json=data, + headers={"api_token": self.admin_token}, + ) + + # Note: The endpoint might accept invalid GraceIDs depending on implementation + # If it validates GraceIDs, the response should indicate an error + if response.status_code != 200: + assert response.status_code in [400, 404, 422, 500] + else: + # If it accepts the invalid GraceID, verify the notice was created + result = response.json() + assert "icecube_notice" in result + assert result["icecube_notice"]["ref_id"] == ref_id + assert result["icecube_notice"]["graceid"] == "INVALID123" + + def test_post_icecube_notice_missing_fields(self): + """Test posting an IceCube notice with missing required fields.""" + # Post with minimal required fields according to schema + ref_id = f"IceCube-{uuid.uuid4()}" + + notice_data = { + "ref_id": ref_id, + "graceid": "S190425z", + # Missing other fields, but they are Optional in the schema + } + + events_data = [] # No events + + data = {"notice_data": notice_data, "events_data": events_data} + + response = requests.post( + self.get_url("/post_icecube_notice"), + json=data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert "icecube_notice" in result + assert result["icecube_notice"]["ref_id"] == ref_id + assert len(result["icecube_notice_events"]) == 0 + + def test_post_icecube_notice_without_auth(self): + """Test that authentication is required.""" + ref_id = f"IceCube-{uuid.uuid4()}" + + notice_data = {"ref_id": ref_id, "graceid": "S190425z"} + + events_data = [] + + data = {"notice_data": notice_data, "events_data": events_data} + + response = requests.post(self.get_url("/post_icecube_notice"), json=data) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_post_icecube_notice_with_invalid_token(self): + """Test with invalid API token.""" + ref_id = f"IceCube-{uuid.uuid4()}" + + notice_data = {"ref_id": ref_id, "graceid": "S190425z"} + + events_data = [] + + data = {"notice_data": notice_data, "events_data": events_data} + + response = requests.post( + self.get_url("/post_icecube_notice"), + json=data, + headers={"api_token": self.invalid_token}, + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_instrument.py b/tests/fastapi/test_instrument.py new file mode 100644 index 00000000..5c0d8d78 --- /dev/null +++ b/tests/fastapi/test_instrument.py @@ -0,0 +1,607 @@ +""" +Instrument API tests using regular HTTP requests. +These tests hit the actual API endpoints running on the server. 
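+Note: this module also imports InstrumentType from the server package, so the server code must be importable in the test environment in addition to the API being reachable.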
+""" + +import pytest +import requests +import os +from server.core.enums.instrumenttype import InstrumentType +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestInstrumentAPI: + """Test suite for instrument-related API endpoints using HTTP requests.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_get_all_instruments(self): + """Test getting all instruments without filters.""" + response = requests.get( + self.get_url("/instruments"), headers={"api_token": self.admin_token} + ) + assert response.status_code == status.HTTP_200_OK + instruments = response.json() + assert len(instruments) == 3 # We have 3 test instruments + assert all("id" in inst for inst in instruments) + assert all("instrument_name" in inst for inst in instruments) + + def test_get_instrument_by_id(self): + """Test getting a specific instrument by ID.""" + response = requests.get( + self.get_url("/instruments?id=1"), headers={"api_token": self.admin_token} + ) + assert response.status_code == status.HTTP_200_OK + instruments = response.json() + assert len(instruments) == 1 + assert instruments[0]["id"] == 1 + assert instruments[0]["instrument_name"] == "Test Optical Telescope" + assert instruments[0]["nickname"] == "TOT" + + def test_get_instruments_by_ids(self): + """Test getting multiple instruments by IDs.""" + response = requests.get( + self.get_url("/instruments?ids=[1,2]"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + instruments = response.json() + assert len(instruments) == 2 + ids = [inst["id"] for inst in instruments] + assert 1 in ids + assert 2 in ids + + def test_get_instruments_by_name_filter(self): + """Test getting instruments by name filter.""" + response = requests.get( + self.get_url("/instruments?name=Optical"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + instruments = response.json() + assert len(instruments) == 1 + assert "Optical" in instruments[0]["instrument_name"] + + def test_get_instruments_by_type(self): + """Test getting instruments by type.""" + response = requests.get( + self.get_url(f"/instruments?type={InstrumentType.photometric.value}"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + instruments = response.json() + # We have 2 photometric instruments + assert len(instruments) == 2 + assert all( + inst["instrument_type"] == InstrumentType.photometric.value + for inst in instruments + ) + + def test_get_instruments_with_invalid_ids_format(self): + """Test error handling for invalid IDs format.""" + response = requests.get( + self.get_url("/instruments?ids=invalid"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Invalid ids format" in response.json()["message"] + + def test_get_instruments_without_auth(self): + """Test that authentication is required.""" + response = requests.get(self.get_url("/instruments")) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert "API token is required" in response.json()["message"] + + def 
test_get_instruments_with_invalid_token(self): + """Test with invalid API token.""" + response = requests.get( + self.get_url("/instruments"), headers={"api_token": self.invalid_token} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert "Invalid API token" in response.json()["message"] + + def test_get_footprints_all(self): + """Test getting all footprints.""" + response = requests.get( + self.get_url("/footprints"), headers={"api_token": self.admin_token} + ) + assert response.status_code == status.HTTP_200_OK + footprints = response.json() + assert len(footprints) == 3 # We have 3 test footprints + assert all("id" in fp for fp in footprints) + assert all("instrumentid" in fp for fp in footprints) + assert all("footprint" in fp for fp in footprints) + + def test_get_footprints_by_id(self): + """Test getting footprints for a specific instrument ID.""" + response = requests.get( + self.get_url("/footprints?id=1"), headers={"api_token": self.admin_token} + ) + assert response.status_code == status.HTTP_200_OK + footprints = response.json() + assert len(footprints) == 1 + assert footprints[0]["instrumentid"] == 1 + # Check that footprint is returned as WKT string + assert "POLYGON" in footprints[0]["footprint"] + + def test_get_footprints_by_name(self): + """Test getting footprints by instrument name.""" + response = requests.get( + self.get_url("/footprints?name=Optical"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + footprints = response.json() + assert len(footprints) == 1 + assert footprints[0]["instrumentid"] == 1 + + def test_create_instrument(self): + """Test creating a new instrument.""" + new_instrument = { + "instrument_name": "New Test Telescope", + "nickname": "NTT", + "instrument_type": InstrumentType.photometric.value, + } + response = requests.post( + self.get_url("/instruments"), + json=new_instrument, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + created = response.json() + assert created["instrument_name"] == new_instrument["instrument_name"] + assert created["nickname"] == new_instrument["nickname"] + assert created["instrument_type"] == new_instrument["instrument_type"] + assert created["submitterid"] == 1 # Admin user ID + assert "id" in created + assert "datecreated" in created + + # Store the created instrument ID for cleanup in other tests + self._created_instrument_id = created["id"] + + def test_create_instrument_as_different_user(self): + """Test creating an instrument as a different user.""" + new_instrument = { + "instrument_name": "User Test Telescope", + "nickname": "UTT", + "instrument_type": InstrumentType.spectroscopic.value, + } + response = requests.post( + self.get_url("/instruments"), + json=new_instrument, + headers={"api_token": self.user_token}, + ) + assert response.status_code == status.HTTP_200_OK + created = response.json() + assert created["submitterid"] == 2 # Test user ID + + def test_create_instrument_without_auth(self): + """Test that authentication is required for creation.""" + new_instrument = { + "instrument_name": "Unauthorized Telescope", + "nickname": "UT", + "instrument_type": InstrumentType.photometric.value, + } + response = requests.post(self.get_url("/instruments"), json=new_instrument) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_create_footprint(self): + """Test creating a new footprint for an instrument.""" + # First, create an instrument to add footprint to + 
new_instrument = { + "instrument_name": "Footprint Test Telescope 5535", + "nickname": "FTT 5535", + "instrument_type": InstrumentType.photometric.value, + } + inst_response = requests.post( + self.get_url("/instruments"), + json=new_instrument, + headers={"api_token": self.admin_token}, + ) + assert inst_response.status_code == status.HTTP_200_OK + instrument_id = inst_response.json()["id"] + + # Now create a footprint + new_footprint = { + "instrumentid": instrument_id, + "footprint": "POLYGON((-3 -3, 3 -3, 3 3, -3 3, -3 -3))", + } + response = requests.post( + self.get_url("/footprints"), + json=new_footprint, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + created = response.json() + assert created["instrumentid"] == instrument_id + assert "POLYGON" in created["footprint"] + + def test_create_footprint_for_nonexistent_instrument(self): + """Test creating footprint for non-existent instrument.""" + new_footprint = { + "instrumentid": 9999, # Non-existent ID + "footprint": "POLYGON((-1 -1, 1 -1, 1 1, -1 1, -1 -1))", + } + response = requests.post( + self.get_url("/footprints"), + json=new_footprint, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "not found" in response.json()["message"] + + def test_create_footprint_for_others_instrument(self): + """Test that users can't add footprints to instruments they don't own.""" + # Try to add footprint to instrument with ID 1 (owned by admin) using user token + new_footprint = { + "instrumentid": 1, + "footprint": "POLYGON((-1 -1, 1 -1, 1 1, -1 1, -1 -1))", + } + response = requests.post( + self.get_url("/footprints"), + json=new_footprint, + headers={"api_token": self.user_token}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "don't have permission" in response.json()["message"] + + def test_create_footprint_without_auth(self): + """Test that authentication is required for footprint creation.""" + new_footprint = { + "instrumentid": 1, + "footprint": "POLYGON((-1 -1, 1 -1, 1 1, -1 1, -1 -1))", + } + response = requests.post(self.get_url("/footprints"), json=new_footprint) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_instruments_with_complex_query(self): + """Test complex query combining multiple filters.""" + # First create some test data + photometric_instruments = [ + { + "instrument_name": f"Photometric {i}", + "nickname": f"P{i}", + "instrument_type": InstrumentType.photometric.value, + } + for i in range(2) + ] + for inst in photometric_instruments: + requests.post( + self.get_url("/instruments"), + json=inst, + headers={"api_token": self.admin_token}, + ) + + # Query by type and name pattern + response = requests.get( + self.get_url( + f"/instruments?type={InstrumentType.photometric.value}&name=Photometric" + ), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + instruments = response.json() + assert len(instruments) >= 2 # At least the ones we just created + assert all("Photometric" in inst["instrument_name"] for inst in instruments) + + def test_instrument_data_format(self): + """Test that instrument data is returned in the correct format.""" + response = requests.get( + self.get_url("/instruments?id=1"), headers={"api_token": self.admin_token} + ) + assert response.status_code == status.HTTP_200_OK + instrument = response.json()[0] + + # Check all required fields are present + required_fields = [ + "id", + 
"instrument_name", + "instrument_type", + "datecreated", + "submitterid", + ] + for field in required_fields: + assert field in instrument + + # Check data types + assert isinstance(instrument["id"], int) + assert isinstance(instrument["instrument_name"], str) + assert isinstance(instrument["instrument_type"], int) + assert isinstance(instrument["submitterid"], int) + + # Optional fields + if "nickname" in instrument: + assert isinstance(instrument["nickname"], str) + + def test_footprint_data_format(self): + """Test that footprint data is returned in the correct format.""" + response = requests.get( + self.get_url("/footprints?id=1"), headers={"api_token": self.admin_token} + ) + assert response.status_code == status.HTTP_200_OK + footprint = response.json()[0] + + # Check all required fields + required_fields = ["id", "instrumentid", "footprint"] + for field in required_fields: + assert field in footprint + + # Check data types + assert isinstance(footprint["id"], int) + assert isinstance(footprint["instrumentid"], int) + assert isinstance(footprint["footprint"], str) + assert footprint["footprint"].startswith("POLYGON") + + +class TestInstrumentAPIValidation: + """Test validation of instrument API endpoints.""" + + admin_token = "test_token_admin_001" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_invalid_instrument_type(self): + """Test creating instrument with invalid type.""" + invalid_instrument = { + "instrument_name": "Invalid Type Telescope", + "nickname": "ITT", + "instrument_type": 999, # Invalid type + } + response = requests.post( + self.get_url("/instruments"), + json=invalid_instrument, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST # Validation error + assert "Input should be" in response.json()["errors"][0]["message"] + + def test_missing_required_fields(self): + """Test creating instrument with missing required fields.""" + incomplete_instrument = { + "nickname": "ITT" + # Missing instrument_name and instrument_type + } + response = requests.post( + self.get_url("/instruments"), + json=incomplete_instrument, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + error_fields = [field["params"]["field"] for field in response.json()["errors"]] + assert "instrument_name" in str(error_fields) + assert "instrument_type" in str(error_fields) + + def test_invalid_footprint_format(self): + """Test creating footprint with invalid WKT format.""" + invalid_footprint = {"instrumentid": 1, "footprint": "INVALID WKT STRING"} + response = requests.post( + self.get_url("/footprints"), + json=invalid_footprint, + headers={"api_token": self.admin_token}, + ) + # This should fail at the database level + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Invalid WKT format" in response.json()["errors"][0]["message"] + + def test_empty_name_filter(self): + """Test behavior with empty name filter.""" + response = requests.get( + self.get_url("/instruments?name="), headers={"api_token": self.admin_token} + ) + assert response.status_code == status.HTTP_200_OK + # Should return all instruments since empty filter matches all + instruments = response.json() + assert len(instruments) >= 3 # At least our test instruments + + +class TestInstrumentAPIPermissions: + """Test permission-related aspects of instrument API endpoints.""" + + admin_token = "test_token_admin_001" + 
user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_all_users_can_read_instruments(self): + """Test that all users can read instruments.""" + for token in [self.admin_token, self.user_token, self.scientist_token]: + response = requests.get( + self.get_url("/instruments"), headers={"api_token": token} + ) + assert response.status_code == status.HTTP_200_OK + + def test_all_users_can_read_footprints(self): + """Test that all users can read footprints.""" + for token in [self.admin_token, self.user_token, self.scientist_token]: + response = requests.get( + self.get_url("/footprints"), headers={"api_token": token} + ) + assert response.status_code == status.HTTP_200_OK + + def test_all_users_can_create_instruments(self): + """Test that all authenticated users can create instruments.""" + for i, token in enumerate( + [self.admin_token, self.user_token, self.scientist_token] + ): + instrument = { + "instrument_name": f"User{i} Telescope", + "nickname": f"U{i}T", + "instrument_type": InstrumentType.photometric.value, + } + response = requests.post( + self.get_url("/instruments"), + json=instrument, + headers={"api_token": token}, + ) + assert response.status_code == status.HTTP_200_OK + + def test_footprint_creation_permission(self): + """Test that users can only add footprints to their own instruments.""" + # Create an instrument as user + instrument = { + "instrument_name": "User Owned Telescope", + "nickname": "UOT", + "instrument_type": InstrumentType.photometric.value, + } + response = requests.post( + self.get_url("/instruments"), + json=instrument, + headers={"api_token": self.user_token}, + ) + assert response.status_code == status.HTTP_200_OK + user_instrument_id = response.json()["id"] + + # User should be able to add footprint to their own instrument + footprint = { + "instrumentid": user_instrument_id, + "footprint": "POLYGON((-1 -1, 1 -1, 1 1, -1 1, -1 -1))", + } + response = requests.post( + self.get_url("/footprints"), + json=footprint, + headers={"api_token": self.user_token}, + ) + assert response.status_code == status.HTTP_200_OK + + # Admin should NOT be able to add footprint to user's instrument + response = requests.post( + self.get_url("/footprints"), + json=footprint, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestInstrumentAPIIntegration: + """Integration tests that test complete workflows.""" + + admin_token = "test_token_admin_001" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_complete_instrument_creation_workflow(self): + """Test complete workflow of creating instrument and adding footprints.""" + # Step 1: Create instrument + new_instrument = { + "instrument_name": "Integration Test Telescope", + "nickname": "ITT", + "instrument_type": InstrumentType.photometric.value, + } + response = requests.post( + self.get_url("/instruments"), + json=new_instrument, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + instrument = response.json() + instrument_id = instrument["id"] + + # Step 2: Verify instrument was created + response = requests.get( + self.get_url(f"/instruments?id={instrument_id}"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + assert 
len(response.json()) == 1 + assert ( + response.json()[0]["instrument_name"] == new_instrument["instrument_name"] + ) + + # Step 3: Add footprint to instrument + footprint = { + "instrumentid": instrument_id, + "footprint": "POLYGON((-2 -2, 2 -2, 2 2, -2 2, -2 -2))", + } + response = requests.post( + self.get_url("/footprints"), + json=footprint, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + created_footprint = response.json() + + # Step 4: Verify footprint was created + response = requests.get( + self.get_url(f"/footprints?id={instrument_id}"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + footprints = response.json() + assert len(footprints) == 1 + assert footprints[0]["instrumentid"] == instrument_id + assert "POLYGON" in footprints[0]["footprint"] + + def test_query_instruments_by_multiple_criteria(self): + """Test querying instruments with multiple filters.""" + # Create test instruments + test_instruments = [ + { + "instrument_name": "Multi Test Optical 1", + "nickname": "MTO1", + "instrument_type": InstrumentType.photometric.value, + }, + { + "instrument_name": "Multi Test Optical 2", + "nickname": "MTO2", + "instrument_type": InstrumentType.photometric.value, + }, + { + "instrument_name": "Multi Test Spectro", + "nickname": "MTS", + "instrument_type": InstrumentType.spectroscopic.value, + }, + ] + + created_ids = [] + for inst in test_instruments: + response = requests.post( + self.get_url("/instruments"), + json=inst, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + created_ids.append(response.json()["id"]) + + # Query by type + response = requests.get( + self.get_url(f"/instruments?type={InstrumentType.photometric.value}"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + photometric_insts = response.json() + assert ( + len([inst for inst in photometric_insts if inst["id"] in created_ids]) == 2 + ) + + # Query by name pattern + response = requests.get( + self.get_url("/instruments?name=Multi Test"), + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + named_insts = response.json() + assert len([inst for inst in named_insts if inst["id"] in created_ids]) >= 3 + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_pointing.py b/tests/fastapi/test_pointing.py new file mode 100644 index 00000000..7bce5ab1 --- /dev/null +++ b/tests/fastapi/test_pointing.py @@ -0,0 +1,807 @@ +""" +Test pointing endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. 
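+Example invocation against a locally running server:
+    API_BASE_URL=http://localhost:8000 python -m pytest tests/fastapi/test_pointing.py -v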
+""" + +import os + +import requests + +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestPointingEndpoints: + """Test class for pointing-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + # Known GraceIDs from test data + KNOWN_GRACEIDS = ["S190425z", "S190426c", "MS230101a", "GW190521", "MS190425a"] + + # Known instrument IDs from test data + TEST_INSTRUMENTS = { + 1: {"name": "Test Optical Telescope", "nickname": "TOT", "type": "photometric"}, + 2: { + "name": "Test X-ray Observatory", + "nickname": "TXO", + "type": "spectroscopic", + }, + 3: {"name": "Mock Radio Dish", "nickname": "MRD", "type": "photometric"}, + } + + def test_get_pointings_no_params(self): + """Test getting pointings without any parameters.""" + response = requests.get( + self.get_url("/pointings"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should return all pointings the user has access to + assert len(data) >= 5 # We have at least 5 pointings from user 1 + + def test_get_pointings_by_graceid_s190425z(self): + """Test getting pointings filtered by graceid S190425z.""" + response = requests.get( + self.get_url("/pointings"), + params={"graceid": "S190425z"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) >= 2 # Should find pointings 1 and 2 linked to S190425z + + def test_get_pointings_by_multiple_graceids(self): + """Test getting pointings filtered by multiple graceids.""" + response = requests.get( + self.get_url("/pointings"), + params={"graceids": "S190425z,S190426c"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) >= 4 # Should find pointings from both events + + def test_get_pointing_by_id(self): + """Test getting a specific pointing by ID.""" + response = requests.get( + self.get_url("/pointings"), + params={"id": 1}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["id"] == 1 + assert data[0]["status"] == "completed" + assert data[0]["band"] == "r" # band enum value 11 = r + + def test_get_pointings_by_multiple_ids(self): + """Test getting pointings filtered by multiple IDs.""" + response = requests.get( + self.get_url("/pointings"), + params={"ids": "1,2,3"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) >= 2 # Should find pointings 1, 2, and 3 + pointing_ids = [p["id"] for p in data] + assert 1 in pointing_ids + assert 2 in pointing_ids + + def test_get_pointings_by_status_completed(self): + """Test getting pointings with completed status.""" + """Test getting pointings filtered by multiple IDs.""" + response = requests.get( + self.get_url("/pointings"), + 
params={"status": "completed"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # All returned pointings should have completed status + for pointing in data: + assert pointing.get("status") == "completed" + assert len(data) >= 2 # Should find pointings 1 and 3 + + def test_get_pointings_by_status_planned(self): + """Test getting pointings with planned status.""" + response = requests.get( + self.get_url("/pointings"), + params={"status": "planned"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # All returned pointings should have planned status + for pointing in data: + assert pointing.get("status") == "planned" + # Should find pointing 2 and others + + def test_get_pointings_by_status_cancelled(self): + """Test getting pointings with cancelled status.""" + response = requests.get( + self.get_url("/pointings"), + params={"status": "cancelled"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # All returned pointings should have cancelled status + for pointing in data: + assert pointing.get("status") == "cancelled" + # Should find pointing 4 + + def test_get_pointings_by_multiple_statuses(self): + """Test getting pointings filtered by multiple statuses.""" + response = requests.get( + self.get_url("/pointings"), + params={"statuses": "completed, planned"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should include both completed and planned pointings + statuses = [p.get("status") for p in data] + assert "completed" in statuses + assert "planned" in statuses + + def test_get_pointings_by_band_r(self): + """Test getting pointings filtered by r band.""" + response = requests.get( + self.get_url("/pointings"), + params={"band": "r"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should find pointings with r band (enum 11) + for pointing in data: + assert pointing.get("band") == "r" + + def test_get_pointings_by_band_g(self): + """Test getting pointings filtered by g band.""" + response = requests.get( + self.get_url("/pointings"), + params={"band": "g"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should find pointings with g band (enum 10) + for pointing in data: + assert pointing.get("band") == "g" + + def test_get_pointings_by_instrument_id(self): + """Test getting pointings filtered by instrument ID.""" + response = requests.get( + self.get_url("/pointings"), + params={"instrument": 1}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should find pointings using instrument 1 (Test Optical Telescope) + for pointing in data: + assert pointing.get("instrumentid") == 1 + + def test_get_pointings_by_instrument_name(self): + """Test getting pointings filtered by instrument name.""" + response = requests.get( + self.get_url("/pointings"), + params={"instrument": "Test Optical Telescope"}, + 
headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should find pointings using the named instrument + + def test_get_pointings_by_user_id(self): + """Test getting pointings filtered by user ID.""" + response = requests.get( + self.get_url("/pointings"), + params={"user": 1}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should find pointings submitted by user 1 (admin) + for pointing in data: + assert pointing.get("submitterid") == 1 + + def test_get_pointings_by_username(self): + """Test getting pointings filtered by username.""" + response = requests.get( + self.get_url("/pointings"), + params={"user": "admin"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should find pointings submitted by admin user + + def test_get_pointings_by_time_range(self): + """Test getting pointings filtered by time range.""" + response = requests.get( + self.get_url("/pointings"), + params={ + "completed_after": "2019-04-25T08:00:00.000000", + "completed_before": "2019-04-25T15:00:00.000000", + }, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should find pointings completed within this time range + + def test_get_pointings_by_depth_range(self): + """Test getting pointings filtered by depth range.""" + response = requests.get( + self.get_url("/pointings"), + params={ + "depth_gt": 19.0, # Greater than 19 mag + "depth_lt": 22.0, # Less than 22 mag + "depth_unit": "ab_mag", + }, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should find pointings within this depth range + for pointing in data: + depth = pointing.get("depth") + if depth is not None: + assert 19.0 < depth < 22.0 + + def test_post_single_pointing(self): + """Test posting a single pointing.""" + + pointing_data = { + "graceid": "S190425z", + "pointing": { + "ra": 130.456, + "dec": -15.678, + "instrumentid": 1, + "depth": 22.5, + "depth_unit": "ab_mag", + "time": "2019-04-25T12:00:00.000000", + "status": "completed", + "pos_angle": 0.0, + "band": "V", + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "pointing_ids" in data + assert len(data["pointing_ids"]) == 1 + assert isinstance(data["pointing_ids"][0], int) + # Should have no errors + assert len(data.get("ERRORS", [])) == 0 + + def test_post_multiple_pointings(self): + """Test posting multiple pointings.""" + + pointing_data = { + "graceid": "S190425z", + "pointings": [ + { + "ra": 135.123, + "dec": -20.456, + "instrumentid": 1, + "depth": 22.5, + "depth_unit": "ab_mag", + "time": "2019-04-25T12:30:00.000000", + "status": "completed", + "pos_angle": 0.0, + "band": "V", + }, + { + "ra": 140.789, + "dec": -25.123, + "instrumentid": 2, + "depth": 21.0, + "depth_unit": "ab_mag", + "time": "2019-04-25T13:00:00.000000", + "status": "completed", + "pos_angle": 45.0, + "band": "R", + }, + ], + } + + response = requests.post( + self.get_url("/pointings"), + 
json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "pointing_ids" in data + assert len(data["pointing_ids"]) == 2 + # Should have no errors + assert len(data.get("ERRORS", [])) == 0 + + def test_post_planned_pointing(self): + """Test posting a planned pointing.""" + + pointing_data = { + "graceid": "GW190521", # Known test graceid + "pointing": { + "ra": 145.123, + "dec": -30.456, + "instrumentid": 1, + "depth": 23.5, + "depth_unit": "ab_mag", + "time": "2020-05-21T18:00:00.000000", + "status": "planned", + "band": "I", + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "pointing_ids" in data + assert len(data["pointing_ids"]) == 1 + + # Store the planned pointing ID for later tests + self.planned_pointing_id = data["pointing_ids"][0] + + def test_post_pointing_with_doi_request(self): + """Test posting a pointing with DOI request.""" + + pointing_data = { + "graceid": "S190425z", + "pointing": { + "ra": 160.789, + "dec": 30.123, + "instrumentid": 1, + "depth": 21.5, + "depth_unit": "ab_mag", + "time": "2019-04-25T14:00:00.000000", + "status": "completed", + "pos_angle": 90.0, + "band": "g", + }, + "request_doi": True, + "creators": [{"name": "Test Author", "affiliation": "Test Institution"}], + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "pointing_ids" in data + # Note: DOI creation might not be configured in test environment + # assert "DOI" in data + + def test_post_pointing_update_planned(self): + """Test updating a planned pointing to completed.""" + # First, make sure we have a planned pointing + params = {"id": 8, "status": "planned"} # Known planned pointing from test data + response = requests.get( + self.get_url("/pointings"), + json=params, + headers={"api_token": self.admin_token}, + ) + + if response.status_code == status.HTTP_200_OK and len(response.json()) > 0: + # Now update it to completed + update_data = { + "graceid": "GW190521", + "pointing": { + "id": 8, + "time": "2020-05-22T08:30:00.000000", + "pos_angle": 45.0, + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=update_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "pointing_ids" in data + assert len(data["pointing_ids"]) == 1 + + def test_post_pointing_invalid_graceid(self): + """Test posting pointing with invalid graceid.""" + + pointing_data = { + "graceid": "INVALID123", + "pointing": { + "ra": 123.456, + "dec": -45.678, + "instrumentid": 1, + "depth": 22.5, + "depth_unit": "ab_mag", + "time": "2019-04-25T12:00:00.000000", + "status": "completed", + "pos_angle": 0.0, + "band": "V", + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Invalid graceid" in response.json().get("message") + + def test_post_pointing_missing_required_fields(self): + """Test posting pointing with missing required fields.""" + + pointing_data = { + "graceid": "S190425z", + "pointing": { + "ra": 123.456, + "dec": -45.678, + 
# Missing instrumentid and band (required for completed observations) + "depth": 22.5, + "depth_unit": "ab_mag", + "time": "2019-04-25T12:00:00.000000", + "status": "completed", + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + # With improved Pydantic validation, we now get proper HTTP 400 status codes + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert "validation error" in data.get("message", "").lower() + # Should indicate missing required field (band is caught first by Pydantic) + assert "band is required" in str(data) + + def test_update_pointing_cancel(self): + """Test updating pointing status to cancelled.""" + # First create a planned pointing + + pointing_data = { + "graceid": "S190425z", + "pointing": { + "ra": 165.000, + "dec": 35.000, + "instrumentid": 1, + "depth": 24.5, + "depth_unit": "ab_mag", + "time": "2019-04-25T17:00:00.000000", + "status": "planned", + "band": "u", + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + pointing_id = response.json()["pointing_ids"][0] + + # Now cancel it + update_data = {"status": "cancelled", "ids": [pointing_id]} + + response = requests.post( + self.get_url("/update_pointings"), + json=update_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "Updated" in data["message"] + assert "1" in data["message"] # Should update 1 pointing + + def test_cancel_all_pointings(self): + """Test cancelling all pointings for a graceid and instrument.""" + # First create some planned pointings + + for i in range(3): + pointing_data = { + "graceid": "S190425z", # Known test graceid + "pointing": { + "ra": 10.0 + i, + "dec": 5.0 + i, + "instrumentid": 3, # Use Mock Radio Dish + "depth": 24.0, + "depth_unit": "ab_mag", + "time": f"2019-04-25T18:{i:02d}:00.000000", + "status": "planned", + "band": "V", + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + + # Now cancel all for this graceid and instrument + cancel_data = {"graceid": "S190425z", "instrumentid": 3} + + response = requests.post( + self.get_url("/cancel_all"), + json=cancel_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "Updated" in data["message"] + assert "3" in data["message"] # Should cancel 3 pointings + + def test_request_doi_for_pointings(self): + """Test requesting DOI for existing pointings.""" + # First create some completed pointings + + pointing_ids = [] + + for i in range(2): + pointing_data = { + "graceid": "S190425z", + "pointing": { + "ra": 170.0 + i, + "dec": 40.0 + i, + "instrumentid": 1, + "depth": 23.0, + "depth_unit": "ab_mag", + "time": f"2019-04-25T19:{i:02d}:00.000000", + "status": "completed", + "pos_angle": 0.0, + "band": "V", + }, + } + + response = requests.post( + self.get_url("/pointings"), + json=pointing_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + pointing_ids.extend(response.json()["pointing_ids"]) + + # Now request DOI + doi_data = { + "graceid": "S190425z", + "ids": pointing_ids, + "request_doi": True, + 
"creators": [{"name": "Test Researcher", "affiliation": "Test University"}], + } + + # Now create new pointing with DOI request for all created pointings + doi_data = { + "graceid": "S190425z", + # Need a pointing or pointings parameter + "pointing": { + "ra": 175.0, + "dec": 45.0, + "instrumentid": 1, + "depth": 23.0, + "depth_unit": "ab_mag", + "time": "2019-04-25T20:00:00.000000", + "status": "completed", + "pos_angle": 0.0, + "band": "V", + }, + # DOI-related parameters + "request_doi": True, + "creators": [{"name": "Test Researcher", "affiliation": "Test University"}], + } + + response = requests.post( + self.get_url("/pointings"), + json=doi_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "DOI" in data + + def test_pointing_unauthorized_access(self): + """Test that unauthorized requests are rejected.""" + url = self.get_url("/pointings") + + # Request without API token + response = requests.get(url) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + # Request with invalid API token + invalid_headers = {"api_token": "invalid_token"} + response = requests.get(url, headers=invalid_headers) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_pointings_with_existing_api_tokens(self): + """Test with different valid API tokens from test data.""" + url = self.get_url("/pointings") + + # Test with admin token + response = requests.get(url, headers={"api_token": self.admin_token}) + assert response.status_code == status.HTTP_200_OK + + # Test with scientist token + response = requests.get(url, headers={"api_token": self.scientist_token}) + assert response.status_code == status.HTTP_200_OK + + # Test with regular user token + response = requests.get(url, headers={"api_token": self.user_token}) + assert response.status_code == status.HTTP_200_OK + + def test_get_pointings_by_specific_coordinates(self): + """Test getting pointings near specific coordinates from test data.""" + url = self.get_url("/pointings") + + # Test data has pointings around these coordinates + params = {"ids": "[1]"} # pointing 1 is at (123.456, -12.345) + response = requests.get( + url, json=params, headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) > 0 + # Check that we get back the expected position + assert "123.456" in data[0]["position"] + assert "-12.345" in data[0]["position"] + + +# Additional test class for testing with specific test data values +class TestPointingWithSpecificData: + """Test pointing functionality using specific values from test data.""" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + TEST_USER_API_TOKEN = "test_token_user_002" + + @classmethod + def setup_class(cls): + cls.session = requests.Session() + cls.headers = { + "Content-Type": "application/json", + "api_token": cls.TEST_USER_API_TOKEN, + } + + def test_get_specific_pointing_by_id(self): + """Test getting pointing 1 specifically.""" + url = self.get_url("/pointings") + params = {"id": 1} + response = self.session.get(url, params=params, headers=self.headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 1 + pointing = data[0] + + # Verify specific values from test data + assert pointing["id"] == 1 + assert pointing["status"] == "completed" + assert pointing["instrumentid"] == 1 + 
assert pointing["depth"] == 20.5 + assert pointing["depth_err"] == 0.1 + assert pointing["depth_unit"] == "ab_mag" + assert pointing["band"] == "r" # band enum 11 = r + assert "123.456" in pointing["position"] + assert "-12.345" in pointing["position"] + + def test_get_specific_pointing_by_graceid_and_instrument(self): + """Test getting pointings for S190425z with instrument 1.""" + url = self.get_url("/pointings") + params = {"graceid": "S190425z", "instrument": "1"} + response = self.session.get(url, params=params, headers=self.headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + # Should find pointings 1 and 2 which are linked to S190425z + assert len(data) >= 1 + + def test_create_pointing_for_existing_graceid(self): + """Test creating a pointing for an existing graceid.""" + url = self.get_url("/pointings") + + pointing_data = { + "graceid": "MS230101a", # Test graceid from test data + "pointing": { + "ra": 180.0, + "dec": 0.0, + "instrumentid": 1, # Test Optical Telescope + "depth": 24.0, + "depth_unit": "ab_mag", + "time": "2023-01-01T12:00:00.000000", + "status": "completed", + "pos_angle": 0.0, + "band": "V", + }, + } + + response = self.session.post(url, json=pointing_data, headers=self.headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "pointing_ids" in data + assert len(data["pointing_ids"]) == 1 + assert len(data.get("ERRORS", [])) == 0 + + @classmethod + def teardown_class(cls): + cls.session.close() diff --git a/tests/fastapi/test_query_alert.py b/tests/fastapi/test_query_alert.py new file mode 100644 index 00000000..b957d45a --- /dev/null +++ b/tests/fastapi/test_query_alert.py @@ -0,0 +1,533 @@ +""" +Test event endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. 
+""" + +import os +import requests +from datetime import datetime +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_V1_PREFIX = "/api/v1" + + +class TestEventEndpoints: + """Test class for event-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + # Known GraceIDs from test data + KNOWN_GRACEIDS = ["S190425z", "S190426c", "MS230101a", "GW190521", "MS190425a"] + + def test_query_alerts_no_params(self): + """Test querying alerts without any parameters.""" + response = requests.get( + self.get_url("/query_alerts"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Should return alerts that exist in test data + assert len(data) > 0 + for alert in data: + assert "id" in alert + assert "graceid" in alert + + def test_query_alerts_by_graceid(self): + """Test querying alerts filtered by graceid.""" + response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": "S190425z"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) > 0 + # All returned alerts should have the specified graceid + for alert in data: + assert alert["graceid"] == "S190425z" + + def test_query_alerts_by_alert_type(self): + """Test querying alerts filtered by alert type.""" + response = requests.get( + self.get_url("/query_alerts"), + params={"alert_type": "Initial"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # All returned alerts should have the specified alert_type + for alert in data: + assert alert["alert_type"] == "Initial" + + def test_query_alerts_by_graceid_and_alert_type(self): + """Test querying alerts filtered by both graceid and alert type.""" + response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": "S190425z", "alert_type": "Initial"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # All returned alerts should match both filters + for alert in data: + assert alert["graceid"] == "S190425z" + assert alert["alert_type"] == "Initial" + + def test_post_alert(self): + """Test posting a new alert (admin only).""" + alert_data = { + "graceid": "TEST123", + "alert_type": "Initial", + "role": "test", + "observing_run": "O4", + "far": 1.5e-8, + "group": "CBC", + "timesent": "2025-05-01T12:00:00.000Z", + "time_of_signal": "2025-05-01T11:58:20.000Z", + "description": "Test event for API testing", + "distance": 200.0, + "distance_error": 50.0, + "prob_bns": 0.8, + "prob_nsbh": 0.15, + "prob_bbh": 0.03, + "prob_terrestrial": 0.02, + } + + response = requests.post( + self.get_url("/post_alert"), + json=alert_data, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["graceid"] == alert_data["graceid"] + assert data["alert_type"] == alert_data["alert_type"] + assert 
data["prob_bns"] == alert_data["prob_bns"] + + def test_post_alert_unauthorized(self): + """Test that non-admin users cannot post alerts.""" + alert_data = { + "graceid": "TEST456", + "alert_type": "Initial", + "role": "test", + "observing_run": "O4", + } + + response = requests.post( + self.get_url("/post_alert"), + json=alert_data, + headers={"api_token": self.user_token}, + ) + + assert response.status_code == 403 + assert "admin" in response.json()["message"].lower() + + def test_get_gw_skymap(self): + """Test getting a skymap FITS file.""" + response = requests.get( + self.get_url("/gw_skymap"), + params={"graceid": "S190425z"}, + headers={"api_token": self.admin_token}, + ) + + # Even if the file doesn't exist in test data, we should get a valid response + assert response.status_code in [200, 404] + if response.status_code == status.HTTP_200_OK: + assert response.headers["Content-Type"] == "application/fits" + else: + assert "Error in retrieving skymap file" in response.json()["message"] + + def test_get_gw_contour(self): + """Test getting alert contour data.""" + response = requests.get( + self.get_url("/gw_contour"), + params={"graceid": "S190425z"}, + headers={"api_token": self.admin_token}, + ) + + # Even if the file doesn't exist in test data, we should get a valid response + assert response.status_code in [200, 404] + if response.status_code == status.HTTP_200_OK: + assert response.headers["Content-Type"] == "application/json" + else: + assert "Error in retrieving Contour file" in response.json()["message"] + + def test_get_grb_moc_file(self): + """Test getting GRB MOC file.""" + response = requests.get( + self.get_url("/grb_moc_file"), + params={"graceid": "S190425z", "instrument": "gbm"}, + headers={"api_token": self.admin_token}, + ) + + # Even if the file doesn't exist in test data, we should get a valid response + assert response.status_code in [200, 404] + if response.status_code == status.HTTP_200_OK: + assert response.headers["Content-Type"] == "application/json" + else: + assert "MOC file" in response.json()["message"] + + def test_get_grb_moc_file_invalid_instrument(self): + """Test getting GRB MOC file with invalid instrument.""" + response = requests.get( + self.get_url("/grb_moc_file"), + params={"graceid": "S190425z", "instrument": "invalid"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Valid instruments are" in response.json()["message"] + + def test_del_test_alerts(self): + """Test deleting test alerts (admin only).""" + # First create a test alert that should be deleted + alert_data = { + "graceid": f"MS{datetime.now().strftime('%y%m%d')}test", + "alert_type": "Initial", + "role": "test", + "observing_run": "O4", + } + + # Create the test alert + create_response = requests.post( + self.get_url("/post_alert"), + json=alert_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + + # Now try to delete test alerts + response = requests.post( + self.get_url("/del_test_alerts"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_200_OK + assert "Successfully deleted test alerts" in response.json()["message"] + + def test_del_test_alerts_unauthorized(self): + """Test that non-admin users cannot delete test alerts.""" + response = requests.post( + self.get_url("/del_test_alerts"), headers={"api_token": self.user_token} + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert 
"admin" in response.json()["message"].lower() + + def test_event_api_unauthorized_access(self): + """Test that unauthorized requests are rejected.""" + # Request without API token + response = requests.get(self.get_url("/query_alerts")) + assert response.status_code == 401 + + # Request with invalid API token + response = requests.get( + self.get_url("/query_alerts"), headers={"api_token": self.invalid_token} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_event_api_with_different_tokens(self): + """Test access with different valid API tokens.""" + # All authenticated users should be able to query alerts + for token in [self.admin_token, self.user_token, self.scientist_token]: + response = requests.get( + self.get_url("/query_alerts"), headers={"api_token": token} + ) + assert response.status_code == status.HTTP_200_OK + + +class TestEventAPIValidation: + """Test validation of event API endpoints.""" + + admin_token = "test_token_admin_001" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_post_alert_missing_required_fields(self): + """Test creating alert with missing required fields.""" + incomplete_alert = { + "graceid": "TEST789" + # Missing alert_type, role, etc. + } + + response = requests.post( + self.get_url("/post_alert"), + json=incomplete_alert, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + error_detail = response.json()["errors"] + missing_fields = [field["params"]["field"] for field in error_detail] + assert "alert_type" in str(missing_fields) + assert "role" in str(missing_fields) + + def test_post_alert_invalid_values(self): + """Test creating alert with invalid field values.""" + invalid_alert = { + "graceid": "TEST999", + "alert_type": "Initial", + "role": "test", + "observing_run": "O4", + "prob_bns": 1.5, # Invalid probability > 1 + "prob_nsbh": -0.2, # Invalid negative probability + } + + response = requests.post( + self.get_url("/post_alert"), + json=invalid_alert, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + error_detail = response.json()["errors"] + assert "prob_bns" in str(error_detail) or "prob_nsbh" in str(error_detail) + + def test_get_skymap_without_graceid(self): + """Test getting skymap without providing graceid.""" + response = requests.get( + self.get_url("/gw_skymap"), headers={"api_token": self.admin_token} + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "graceid" in str(response.json()["errors"][0]) + + def test_get_grb_moc_without_params(self): + """Test getting GRB MOC file without required parameters.""" + # Missing instrument + response = requests.get( + self.get_url("/grb_moc_file"), + params={"graceid": "S190425z"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "instrument" in str(response.json()["errors"][0]) + + # Missing graceid + response = requests.get( + self.get_url("/grb_moc_file"), + params={"instrument": "gbm"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "graceid" in str(response.json()["errors"][0]) + + +class TestEventAPIIntegration: + """Integration tests for event API endpoints.""" + + admin_token = "test_token_admin_001" + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return 
f"{API_BASE_URL}{API_V1_PREFIX}{endpoint}" + + def test_create_query_alert_workflow(self): + """Test complete workflow: create, query, and fetch data for an alert.""" + # Step 1: Create new alert + event_time = datetime.now().isoformat() + unique_id = f"TEST{datetime.now().strftime('%Y%m%d%H%M%S')}" + + alert_data = { + "graceid": unique_id, + "alert_type": "Initial", + "role": "test", + "observing_run": "O4", + "far": 2.5e-8, + "group": "CBC", + "timesent": event_time, + "time_of_signal": event_time, + "description": "Integration test event", + "distance": 150.0, + "distance_error": 30.0, + "prob_bns": 0.75, + "prob_nsbh": 0.15, + "prob_bbh": 0.05, + "prob_terrestrial": 0.05, + } + + create_response = requests.post( + self.get_url("/post_alert"), + json=alert_data, + headers={"api_token": self.admin_token}, + ) + assert create_response.status_code == status.HTTP_200_OK + created_alert = create_response.json() + assert created_alert["graceid"] == unique_id + + # Step 2: Query the created alert + query_response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": unique_id}, + headers={"api_token": self.admin_token}, + ) + assert query_response.status_code == status.HTTP_200_OK + queried_alerts = query_response.json() + assert len(queried_alerts) == 1 + assert queried_alerts[0]["graceid"] == unique_id + assert queried_alerts[0]["prob_bns"] == alert_data["prob_bns"] + + # Step 3: Clean up - delete the test alert using del_test_alerts + delete_response = requests.post( + self.get_url("/del_test_alerts"), headers={"api_token": self.admin_token} + ) + assert delete_response.status_code == status.HTTP_200_OK + + # Step 4: Verify deletion + verify_response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": unique_id}, + headers={"api_token": self.admin_token}, + ) + assert verify_response.status_code == status.HTTP_200_OK + assert len(verify_response.json()) == 0 # Should be empty + + def test_multiple_alerts_same_event(self): + """Test creating and querying multiple alerts for the same event.""" + # Create base event ID + event_id = f"TEST{datetime.now().strftime('%Y%m%d%H%M%S')}_MULTI" + + # Create initial alert + initial_alert = { + "graceid": event_id, + "alert_type": "Initial", + "role": "test", + "observing_run": "O4", + "description": "Multi-alert test - Initial", + } + response = requests.post( + self.get_url("/post_alert"), + json=initial_alert, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + + # Create update alert for same event + update_alert = { + "graceid": event_id, + "alert_type": "Update", + "role": "test", + "observing_run": "O4", + "description": "Multi-alert test - Update", + } + response = requests.post( + self.get_url("/post_alert"), + json=update_alert, + headers={"api_token": self.admin_token}, + ) + assert response.status_code == status.HTTP_200_OK + + # Query all alerts for this event + query_response = requests.get( + self.get_url("/query_alerts"), + params={"graceid": event_id}, + headers={"api_token": self.admin_token}, + ) + assert query_response.status_code == status.HTTP_200_OK + alerts = query_response.json() + assert len(alerts) == 2 + + # Verify we have both alert types + alert_types = [alert["alert_type"] for alert in alerts] + assert "Initial" in alert_types + assert "Update" in alert_types + + # Clean up + requests.post( + self.get_url("/del_test_alerts"), headers={"api_token": self.admin_token} + ) + + +class TestEventSpecificData: + """Test event endpoints with specific 
test data values.""" + + BASE_URL = f"{API_BASE_URL}{API_V1_PREFIX}" + admin_token = "test_token_admin_001" + + def test_known_event_s190425z(self): + """Test querying the known S190425z event from test data.""" + headers = {"api_token": self.admin_token} + + response = requests.get( + f"{self.BASE_URL}/query_alerts", + params={"graceid": "S190425z"}, + headers=headers, + ) + + assert response.status_code == status.HTTP_200_OK + alerts = response.json() + assert len(alerts) > 0 + + # Check expected fields in first alert + alert = alerts[0] + assert alert["graceid"] == "S190425z" + assert "alert_type" in alert + assert "far" in alert + assert "time_of_signal" in alert + + # Verify classification info is present + has_probs = any( + key in alert + for key in ["prob_bns", "prob_nsbh", "prob_bbh", "prob_terrestrial"] + ) + assert has_probs + + def test_query_by_alert_properties(self): + """Test querying events with BNS classification.""" + headers = {"api_token": self.admin_token} + + # First get all alerts + response = requests.get(f"{self.BASE_URL}/query_alerts", headers=headers) + + assert response.status_code == status.HTTP_200_OK + all_alerts = response.json() + + # Filter for BNS candidates (prob_bns > 0.9) + bns_candidates = [ + alert + for alert in all_alerts + if alert.get("prob_bns") is not None and alert.get("prob_bns") > 0.9 + ] + + if bns_candidates: + # Test querying one specific BNS candidate + sample_bns = bns_candidates[0] + response = requests.get( + f"{self.BASE_URL}/query_alerts", + params={"graceid": sample_bns["graceid"]}, + headers=headers, + ) + + assert response.status_code == status.HTTP_200_OK + result = response.json() + assert len(result) > 0 + assert result[0]["graceid"] == sample_bns["graceid"] + assert result[0]["prob_bns"] > 0.9 + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/fastapi/test_ui.py b/tests/fastapi/test_ui.py new file mode 100644 index 00000000..d183bf4f --- /dev/null +++ b/tests/fastapi/test_ui.py @@ -0,0 +1,291 @@ +""" +Test UI-related endpoints with real requests to the FastAPI application. +Tests use specific data from test-data.sql. 
+""" + +import os +import requests +import json +import pytest +from fastapi import status + +# Test configuration +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000") + + +class TestUIEndpoints: + """Test class for UI-related API endpoints.""" + + # Test API tokens from test data + admin_token = "test_token_admin_001" + user_token = "test_token_user_002" + scientist_token = "test_token_sci_003" + invalid_token = "invalid_token_123" + + # Known GraceIDs from test data + KNOWN_GRACEIDS = ["S190425z", "S190426c", "MS230101a", "GW190521", "MS190425a"] + + def get_url(self, endpoint): + """Get full URL for an endpoint.""" + return f"{API_BASE_URL}{endpoint}" + + def test_ajax_alertinstruments_footprints(self): + """Test getting alert instrument footprints.""" + for graceid in self.KNOWN_GRACEIDS: + response = requests.get( + self.get_url("/ajax_alertinstruments_footprints"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + if response.status_code == status.HTTP_200_OK: + data = response.json() + assert isinstance(data, list) + + # If data is returned, it should have the expected structure + if len(data) > 0: + overlay = data[0] + assert "display" in overlay + assert "name" in overlay + assert "color" in overlay + assert "contours" in overlay + return # Found valid data, test passes + + # If we get here, all GraceIDs failed - this might be valid if test data doesn't have footprints + pytest.skip("No alert instrument footprints found in test data") + + def test_ajax_preview_footprint_circle(self): + """Test previewing a circular footprint.""" + response = requests.get( + self.get_url("/ajax_preview_footprint"), + params={"ra": 123.456, "dec": -12.345, "radius": 0.5, "shape": "circle"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + # The response should be a JSON string containing plotly figure data + assert isinstance(response.text, str) + # Try parsing as JSON to confirm it's valid + try: + json_data = json.loads(response.text) + assert "data" in json_data + except json.JSONDecodeError: + assert False, "Response is not valid JSON" + + def test_ajax_preview_footprint_rectangle(self): + """Test previewing a rectangular footprint.""" + response = requests.get( + self.get_url("/ajax_preview_footprint"), + params={ + "ra": 123.456, + "dec": -12.345, + "height": 0.5, + "width": 1.0, + "shape": "rectangle", + }, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + # The response should be a JSON string containing plotly figure data + assert isinstance(response.text, str) + # Try parsing as JSON to confirm it's valid + try: + json_data = json.loads(response.text) + assert "data" in json_data + except json.JSONDecodeError: + assert False, "Response is not valid JSON" + + def test_ajax_preview_footprint_invalid_shape(self): + """Test previewing a footprint with invalid shape.""" + response = requests.get( + self.get_url("/ajax_preview_footprint"), + params={"ra": 123.456, "dec": -12.345, "shape": "invalid"}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "error" in data + assert "Invalid shape type" in data["error"] + + def test_ajax_coverage_calculator(self): + """Test the coverage calculator endpoint.""" + for graceid in self.KNOWN_GRACEIDS: + data = { + "graceid": graceid, + "inst_cov": "1,2,3", + "band_cov": "r,g,i", + "depth_cov": "20.0", + "depth_unit": 
"ab_mag", + "approx_cov": 1, + } + + response = requests.post( + self.get_url("/ajax_coverage_calculator"), + json=data, + headers={"api_token": self.admin_token}, + ) + + if response.status_code == status.HTTP_200_OK: + result = response.json() + assert "plot_html" in result + assert isinstance(result["plot_html"], str) + assert " 0 + ): + alert = alerts_response.json()[0] + alert_id = alert["id"] + + # Now get galaxies for this alert + response = requests.get( + self.get_url("/ajax_event_galaxies"), + params={"alertid": alert_id}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Note: data might be empty if there are no galaxies for this alert + return # Found an alert, test passes + + # If we get here, no alerts were found + pytest.skip("No alerts found in test data") + + def test_ajax_candidate(self): + """Test getting candidates by graceid.""" + for graceid in self.KNOWN_GRACEIDS: + response = requests.get( + self.get_url("/ajax_candidate"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + # Note: data might be empty if there are no candidates for this graceid + + def test_ajax_alerttype(self): + """Test getting event contour and alert information.""" + for graceid in self.KNOWN_GRACEIDS: + # First get alerts to find a valid alert ID and alert type + alerts_response = requests.get( + self.get_url("/api/v1/query_alerts"), + params={"graceid": graceid}, + headers={"api_token": self.admin_token}, + ) + + if ( + alerts_response.status_code == status.HTTP_200_OK + and len(alerts_response.json()) > 0 + ): + alert = alerts_response.json()[0] + alert_id = alert["id"] + alert_type = alert["alert_type"].lower() + + # Now get contour data + url_id = f"{alert_id}_{alert_type}" + response = requests.get( + self.get_url("/ajax_alerttype"), + params={"urlid": url_id}, + headers={"api_token": self.admin_token}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "hidden_alertid" in data + assert "detection_overlays" in data + return # Found an alert, test passes + + # If we get here, no alerts were found + pytest.skip("No alerts found in test data") + + def test_authentication_required(self): + """Test that authentication is required for protected endpoints.""" + endpoints = ["/ajax_coverage_calculator", "/ajax_request_doi"] + + for endpoint in endpoints: + # POST endpoints + if endpoint == "/ajax_coverage_calculator": + response = requests.post( + self.get_url(endpoint), json={"graceid": self.KNOWN_GRACEIDS[0]} + ) + else: + # GET endpoints + response = requests.get(self.get_url(endpoint)) + + assert response.status_code == 401 + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v"]) diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..104cbddc --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,11 @@ +# Test requirements for GWTM FastAPI tests +# Minimum dependencies needed to run the test suite + +# Core testing framework +pytest>=7.3.1 + +# FastAPI for status codes and test utilities +fastapi>=0.103.0 + +# HTTP requests for API testing +requests>=2.31.0 \ No newline at end of file diff --git a/tests/run-fastapi-tests.sh b/tests/run-fastapi-tests.sh new file mode 100755 index 00000000..38b99100 --- /dev/null +++ 
b/tests/run-fastapi-tests.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Run tests for a specific module +# Test data is loaded automatically by conftest.py +set -e + +# Default value +MODULE=${1:-all} + +if [ "$MODULE" == "all" ]; then + # Run all test modules + echo "===== Running all FastAPI tests =====" + pytest tests/fastapi/ -v --disable-warnings +else + # Run a specific module + if [ -f "tests/fastapi/test_${MODULE}.py" ]; then + echo "===== Running tests for module: $MODULE =====" + pytest tests/fastapi/test_${MODULE}.py -v --disable-warnings + else + echo "Module test file not found: tests/fastapi/test_${MODULE}.py" + exit 1 + fi +fi + +echo "All tests completed successfully!" \ No newline at end of file diff --git a/tests/test-data.sql b/tests/test-data.sql new file mode 100644 index 00000000..b4fec973 --- /dev/null +++ b/tests/test-data.sql @@ -0,0 +1,206 @@ +-- Comprehensive test data for treasuremap database +-- This file contains minimal test data for development and testing +-- All tables are represented with essential test data + +-- Disable foreign key checks temporarily +SET session_replication_role = replica; + +-- Clear existing data in dependency order +TRUNCATE TABLE public.pointing_event CASCADE; +TRUNCATE TABLE public.pointing CASCADE; +TRUNCATE TABLE public.footprint_ccd CASCADE; +TRUNCATE TABLE public.instrument CASCADE; +TRUNCATE TABLE public.gw_candidate CASCADE; +TRUNCATE TABLE public.icecube_notice_coinc_event CASCADE; +TRUNCATE TABLE public.icecube_notice CASCADE; +TRUNCATE TABLE public.gw_galaxy_score CASCADE; +TRUNCATE TABLE public.gw_galaxy_entry CASCADE; +TRUNCATE TABLE public.gw_galaxy_list CASCADE; +TRUNCATE TABLE public.gw_galaxy CASCADE; +TRUNCATE TABLE public.doi_author CASCADE; +TRUNCATE TABLE public.doi_author_group CASCADE; +TRUNCATE TABLE public.useractions CASCADE; +TRUNCATE TABLE public.usergroups CASCADE; +TRUNCATE TABLE public.groups CASCADE; +TRUNCATE TABLE public.users CASCADE; +TRUNCATE TABLE public.gw_alert CASCADE; +TRUNCATE TABLE public.glade_2p3 CASCADE; + +-- Insert test users with working passwords +-- Password hashes generated with werkzeug.security.generate_password_hash() +INSERT INTO public.users (id, username, firstname, lastname, password_hash, datecreated, email, verified, api_token) +VALUES + (1, 'admin', 'Admin', 'User', 'pbkdf2:sha256:260000$RjdGk7VP$5f8a2b1d8e4c3a6f9b2e5c8b1f4a7d0e3c6a9f2b5e8d1c4a7b0e3f6c9a2d5b8e1', NOW(), 'admin@test.com', true, 'test_token_admin_001'), + (2, 'testuser', 'Test', 'User', 'pbkdf2:sha256:260000$SkfHl8WQ$6a9b3c2e5d7f0a4b8e1c5a9d2f6b0c4e7a1d5c8b2e6a0d3f7c1e5b9a3f6d0c4', NOW(), 'test@test.com', true, 'test_token_user_002'), + (3, 'scientist', 'Science', 'User', 'pbkdf2:sha256:260000$TlgIm9XR$7b0c4d3f6e8a1c5b9f2e6d0a3c7a2d5b8e1f4c7a0e3d6b9c2f5e8a1d4c7b0f3', NOW(), 'science@test.com', true, 'test_token_sci_003'); + +-- Insert test groups +INSERT INTO public.groups (id, name, datecreated) +VALUES + (1, 'admin', NOW()), + (2, 'researchers', NOW()), + (3, 'observers', NOW()); + +-- Insert user-group relationships +INSERT INTO public.usergroups (id, userid, groupid, role) +VALUES + (1, 1, 1, 'admin'), + (2, 2, 2, 'member'), + (3, 3, 2, 'member'), + (4, 3, 3, 'lead'); + +-- Insert test GW alerts +INSERT INTO public.gw_alert (id, graceid, alternateid, role, time_of_signal, timesent, datecreated, alert_type, observing_run, far, distance, distance_error, prob_bns, prob_nsbh, prob_bbh, prob_terrestrial, area_50, area_90) +VALUES + (1, 'S190425z', 'G298048', 'observation', '2019-04-25 08:18:05', '2019-04-25 
08:18:26', NOW(), 'Preliminary', 'O3', 9.11e-6, 156.0, 41.0, 0.72, 0.23, 0.05, 0.00, 1131.0, 3818.0), + (2, 'S190426c', 'G298146', 'observation', '2019-04-26 15:21:55', '2019-04-26 15:22:16', NOW(), 'Initial', 'O3', 1.23e-6, 377.0, 100.0, 0.00, 0.56, 0.44, 0.00, 1033.0, 3502.0), + (3, 'MS230101a', NULL, 'test', '2023-01-01 00:00:00', '2023-01-01 00:01:00', NOW(), 'Preliminary', 'O4', 5.55e-8, 200.0, 50.0, 0.90, 0.05, 0.05, 0.00, 500.0, 1500.0), + (4, 'GW190521', 'GW190521_074359', 'observation', '2020-05-21 07:43:59', '2020-05-21 07:43:35', NOW(), 'Initial', 'O3', 2.5e-7, 5300.0, 2600.0, 0.0, 0.0, 0.95, 0.04, 0.01, 0.0), + (5, 'MS190425a', 'MS190425a-v1', 'test', '2019-04-25 15:00:00', '2019-04-25 15:00:00', NOW(), 'Test', 'O3', 1.0e-5, 100.0, 50.0, 0.5, 0.2, 0.1, 0.1, 0.7, 0.3); + +-- Insert test instruments with proper enum values +-- instrument_type: photometric, spectroscopic +INSERT INTO public.instrument (id, instrument_name, nickname, instrument_type, datecreated, submitterid) +VALUES + (1, 'Test Optical Telescope', 'TOT', 'photometric', NOW(), 1), + (2, 'Test X-ray Observatory', 'TXO', 'spectroscopic', NOW(), 1), + (3, 'Mock Radio Dish', 'MRD', 'photometric', NOW(), 2); + +-- Insert test footprint CCDs +INSERT INTO public.footprint_ccd (id, instrumentid, footprint) +VALUES + (1, 1, ST_GeomFromText('POLYGON((-1 -1, 1 -1, 1 1, -1 1, -1 -1))', 4326)), + (2, 2, ST_GeomFromText('POLYGON((-0.5 -0.5, 0.5 -0.5, 0.5 0.5, -0.5 0.5, -0.5 -0.5))', 4326)), + (3, 3, ST_GeomFromText('POLYGON((-2 -2, 2 -2, 2 2, -2 2, -2 -2))', 4326)); + +-- Insert test pointings with proper enum values +-- status: planned, completed, cancelled +-- depth_unit: ab_mag, vega_mag, flux_erg, flux_jy +-- band: U, B, V, R, I, J, H, K, u, g, r, i, z, etc. +INSERT INTO public.pointing (id, status, position, instrumentid, depth, depth_err, depth_unit, time, datecreated, dateupdated, submitterid, pos_angle, band, central_wave, bandwidth) +VALUES + (1, 'completed', ST_GeomFromText('POINT(123.456 -12.345)', 4326), 1, 20.5, 0.1, 'ab_mag', '2019-04-25 09:00:00', NOW(), NULL, 1, 0.0, 'r', 6415.0, 1487.0), + (2, 'planned', ST_GeomFromText('POINT(234.567 -23.456)', 4326), 2, 21.0, 0.2, 'ab_mag', '2019-04-25 10:00:00', NOW(), NULL, 2, 45.0, 'g', 4730.0, 1503.0), + (3, 'completed', ST_GeomFromText('POINT(345.678 34.567)', 4326), 3, 19.8, 0.05, 'ab_mag', '2019-04-26 16:00:00', NOW(), NULL, 3, 90.0, 'i', 7836.0, 1468.0), + (4, 'cancelled', ST_GeomFromText('POINT(456.789 45.678)', 4326), 1, 22.0, 0.1, 'ab_mag', '2019-04-26 18:00:00', NOW(), NOW(), 1, 0.0, 'r', 6415.0, 1487.0), + (5, 'completed', ST_GeomFromText('POINT(150.0 -30.0)', 4326), 1, 20.5, 0.1, 'ab_mag', '2019-04-25 12:00:00', NOW(), NULL, 1, 45.0, 'V', 5338.0, 810.0), + (6, 'completed', ST_GeomFromText('POINT(151.0 -30.0)', 4326), 1, 21.0, 0.1, 'ab_mag', '2019-04-25 12:30:00', NOW(), NULL, 1, 45.0, 'r', 6415.0, 1487.0), + (7, 'completed', ST_GeomFromText('POINT(152.0 -30.0)', 4326), 1, 20.8, 0.1, 'ab_mag', '2019-04-25 13:00:00', NOW(), NULL, 1, 45.0, 'R', 6311.0, 1220.0), + -- Planned pointings for GW190521 + (8, 'planned', ST_GeomFromText('POINT(134.0 35.0)', 4326), 1, 21.5, NULL, 'ab_mag', '2020-05-22 08:00:00', NOW(), NULL, 1, NULL, 'V', 5338.0, 810.0), + (9, 'planned', ST_GeomFromText('POINT(135.0 35.0)', 4326), 1, 21.5, NULL, 'ab_mag', '2020-05-22 09:00:00', NOW(), NULL, 1, NULL, 'r', 6415.0, 1487.0), + -- Cancelled pointing for MS190425a + (10, 'cancelled', ST_GeomFromText('POINT(0.0 0.0)', 4326), 1, 20.0, NULL, 'ab_mag', '2019-04-25 16:00:00', NOW(), NOW(), 1, NULL, 
'V', 5338.0, 810.0); + +-- Insert pointing events (link pointings to GW events) +INSERT INTO public.pointing_event (id, pointingid, graceid) +VALUES + (1, 1, 'S190425z'), + (2, 2, 'S190425z'), + (3, 3, 'S190426c'), + (4, 4, 'S190426c'), + (5, 5, 'S190425z'), + (6, 6, 'S190425z'), + (7, 7, 'S190425z'), + (8, 8, 'GW190521'), + (9, 9, 'GW190521'), + (10, 10, 'MS190425a'); + +-- Insert test GLADE galaxies +INSERT INTO public.glade_2p3 (id, pgc_number, position, gwgc_name, _2mass_name, hyperleda_name, sdssdr12_name, distance, distance_error, redshift, bmag, bmag_err) +VALUES + (1, 1234567, ST_GeomFromText('POINT(120.0 -10.0)', 4326), 'GWGC_TEST_1', '2MASS_J08000000-1000000', 'HyperLEDA_1', 'SDSS_J120000.00-100000.0', 45.2, 2.1, 0.033, 12.5, 0.1), + (2, 2345678, ST_GeomFromText('POINT(230.0 -20.0)', 4326), 'GWGC_TEST_2', '2MASS_J15200000-2000000', 'HyperLEDA_2', 'SDSS_J230000.00-200000.0', 156.8, 5.5, 0.115, 14.2, 0.2), + (3, 3456789, ST_GeomFromText('POINT(340.0 30.0)', 4326), 'GWGC_TEST_3', '2MASS_J22400000+3000000', 'HyperLEDA_3', 'SDSS_J340000.00+300000.0', 89.3, 3.2, 0.065, 13.1, 0.15); + +-- Insert test GW galaxy mappings +INSERT INTO public.gw_galaxy (id, graceid, galaxy_catalog, galaxy_catalogid, reference) +VALUES + (1, 'S190425z', 1, 1, 'GLADE v2.3'), + (2, 'S190425z', 1, 2, 'GLADE v2.3'); + +-- Insert test galaxy scores with proper enum values +-- score_type: default +INSERT INTO public.gw_galaxy_score (id, gw_galaxyid, score_type, score) +VALUES + (1, 1, 'default', 0.85), + (2, 2, 'default', 0.72); + +-- Insert DOI author groups +INSERT INTO public.doi_author_group (id, userid, name) +VALUES + (1, 1, 'LIGO-Virgo Collaboration'), + (2, 2, 'Test Observatory Team'), + (3, 1, 'Test Observatory Group'); + +-- Insert DOI authors +INSERT INTO public.doi_author (id, name, affiliation, orcid, gnd, pos_order, author_groupid) +VALUES + (1, 'Admin User', 'Test University', '0000-0000-0000-0001', NULL, 1, 1), + (2, 'Science User', 'Test Observatory', '0000-0000-0000-0002', NULL, 2, 1), + (3, 'Test User', 'Test Institute', NULL, NULL, 1, 2); + +-- Insert test galaxy lists +INSERT INTO public.gw_galaxy_list (id, graceid, groupname, submitterid, reference, alertid, doi_url, doi_id) +VALUES + (1, 'S190425z', 'Test Group', 2, 'arXiv:2019.12345', '1', NULL, NULL), + (2, 'S190426c', 'Science Team', 3, 'ApJ 2020 000 000', '2', NULL, NULL); + +-- Insert test galaxy entries +INSERT INTO public.gw_galaxy_entry (id, listid, name, score, position, rank, info) +VALUES + (1, 1, 'NGC1234', 0.95, ST_GeomFromText('POINT(121.0 -11.0)', 4326), 1, '{"distance": 45.5, "type": "spiral"}'), + (2, 1, 'NGC5678', 0.87, ST_GeomFromText('POINT(125.0 -15.0)', 4326), 2, '{"distance": 52.1, "type": "elliptical"}'), + (3, 2, 'SDSS J1500', 0.73, ST_GeomFromText('POINT(231.0 -21.0)', 4326), 1, '{"distance": 158.3, "type": "dwarf"}'); + +-- Insert test IceCube notices +INSERT INTO public.icecube_notice (id, ref_id, graceid, alert_datetime, datecreated, observation_start, observation_stop, pval_generic, pval_bayesian, most_probable_direction_ra, most_probable_direction_dec, flux_sens_low, flux_sens_high, sens_energy_range_low, sens_energy_range_high) +VALUES + (1, 'ICECUBE_ASTROTRACK_123456', 'S190425z', '2019-04-25 08:25:00', NOW(), '2019-04-25 08:00:00', '2019-04-25 09:00:00', 0.05, 0.03, 123.5, -12.3, 1.2e-10, 5.5e-9, 1e3, 1e6), + (2, 'ICECUBE_ASTROTRACK_234567', 'S190426c', '2019-04-26 15:30:00', NOW(), '2019-04-26 15:00:00', '2019-04-26 16:00:00', 0.12, 0.08, 234.6, -23.4, 8.9e-11, 3.2e-9, 5e2, 5e5); + +-- Insert test IceCube 
notice events +INSERT INTO public.icecube_notice_coinc_event (id, icecube_notice_id, datecreated, event_dt, ra, dec, containment_probability, event_pval_generic, event_pval_bayesian, ra_uncertainty, uncertainty_shape) +VALUES + (1, 1, NOW(), 0.0, 123.5, -12.3, 0.5, 0.05, 0.03, 0.5, 'circular'), + (2, 2, NOW(), 30.0, 234.6, -23.4, 0.7, 0.12, 0.08, 0.3, 'elliptical'); + +-- Insert test GW candidates with proper enum values +INSERT INTO public.gw_candidate (id, datecreated, submitterid, graceid, candidate_name, tns_name, tns_url, position, discovery_date, discovery_magnitude, magnitude_central_wave, magnitude_bandwidth, magnitude_unit, magnitude_bandpass, associated_galaxy, associated_galaxy_redshift, associated_galaxy_distance) +VALUES + (1, NOW(), 2, 'S190425z', 'AT2019abc', '2019abc', 'https://www.wis-tns.org/object/2019abc', ST_GeomFromText('POINT(122.0 -11.5)', 4326), '2019-04-25 10:30:00', 18.5, 6415.0, 1487.0, 'ab_mag', 'r', 'NGC1234', 0.033, 45.5), + (2, NOW(), 3, 'S190426c', 'AT2019def', '2019def', 'https://www.wis-tns.org/object/2019def', ST_GeomFromText('POINT(232.0 -21.5)', 4326), '2019-04-26 17:15:00', 19.2, 4730.0, 1503.0, 'ab_mag', 'g', NULL, NULL, NULL); + +-- Insert test user actions +INSERT INTO public.useractions (id, userid, ipaddress, url, time, jsonvals, method) +VALUES + (1, 1, '127.0.0.1', '/api/v1/pointings', NOW(), '{"filter": "graceid=S190425z"}', 'GET'), + (2, 2, '192.168.1.100', '/api/v1/candidate', NOW(), '{"graceid": "S190425z"}', 'POST'), + (3, 3, '10.0.0.50', '/alerts', NOW(), NULL, 'GET'); + +-- Re-enable foreign key checks +SET session_replication_role = DEFAULT; + +-- Update sequences to avoid conflicts +SELECT setval('public.users_id_seq', 10); +SELECT setval('public.groups_id_seq', 10); +SELECT setval('public.usergroups_id_seq', 10); +SELECT setval('public.instrument_id_seq', 10); +SELECT setval('public.footprint_ccd_id_seq', 10); +SELECT setval('public.gw_alert_id_seq', 10); +SELECT setval('public.pointing_id_seq', 10); +SELECT setval('public.pointing_event_id_seq', 10); +SELECT setval('public.glade_2p3_id_seq', 10); +SELECT setval('public.gw_galaxy_id_seq', 10); +SELECT setval('public.gw_galaxy_score_id_seq', 10); +SELECT setval('public.doi_author_group_id_seq', 10); +SELECT setval('public.doi_author_id_seq', 10); +SELECT setval('public.gw_galaxy_list_id_seq', 10); +SELECT setval('public.gw_galaxy_entry_id_seq', 10); +SELECT setval('public.icecube_notice_id_seq', 10); +SELECT setval('public.icecube_notice_coinc_event_id_seq', 10); +SELECT setval('public.gw_candidate_id_seq', 10); +SELECT setval('public.useractions_id_seq', 10); + +-- Analyze tables to update statistics +ANALYZE;
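The run-fastapi-tests.sh script above notes that test data is loaded automatically by conftest.py, but that file is not part of this patch. For reference only, a minimal session-scoped fixture that would do the job might look like the sketch below; the `DB_*` environment variable names, the docker-compose defaults (`treasuremap` user/password/database on `localhost:5432`), and the availability of a `psql` client on the PATH are all assumptions, not confirmed by this changeset.

```python
# Hypothetical sketch only: load tests/test-data.sql once per pytest session.
# Assumes docker-compose database defaults and a psql binary on PATH; adjust as needed.
import os
import subprocess

import pytest


@pytest.fixture(scope="session", autouse=True)
def load_test_data():
    """Reset and populate the database from tests/test-data.sql."""
    env = dict(os.environ, PGPASSWORD=os.getenv("DB_PWD", "treasuremap"))
    subprocess.run(
        [
            "psql",
            "-h", os.getenv("DB_HOST", "localhost"),
            "-p", os.getenv("DB_PORT", "5432"),
            "-U", os.getenv("DB_USER", "treasuremap"),
            "-d", os.getenv("DB_NAME", "treasuremap"),
            "-v", "ON_ERROR_STOP=1",
            "-f", "tests/test-data.sql",
        ],
        check=True,
        env=env,
    )
    yield
```

Because test-data.sql truncates each table before inserting, re-running such a fixture would leave the database in the same known state that the token, graceid, and pointing assertions in the tests above depend on.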