Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Add "Get Variable" endpoint for Execution API #43832

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ class ConnectionResponse(BaseModel):
extra: str | None


class VariableResponse(BaseModel):
"""Variable schema for responses with fields that are needed for Runtime."""

model_config = ConfigDict(from_attributes=True)

key: str
val: str | None = Field(alias="value")


# TODO: This is a placeholder for Task Identity Token schema.
class TIToken(BaseModel):
"""Task Identity Token."""
Expand Down
11 changes: 5 additions & 6 deletions airflow/api_fastapi/execution_api/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
from __future__ import annotations

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.routes.connections import connection_router
from airflow.api_fastapi.execution_api.routes.health import health_router
from airflow.api_fastapi.execution_api.routes.task_instance import ti_router
from airflow.api_fastapi.execution_api.routes import connections, health, task_instance, variables

execution_api_router = AirflowRouter()
execution_api_router.include_router(connection_router)
execution_api_router.include_router(health_router)
execution_api_router.include_router(ti_router)
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
execution_api_router.include_router(health.router, tags=["Health"])
execution_api_router.include_router(task_instance.router, prefix="/task_instance", tags=["Task Instance"])
execution_api_router.include_router(variables.router, prefix="/variables", tags=["Variables"])
6 changes: 2 additions & 4 deletions airflow/api_fastapi/execution_api/routes/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
from airflow.models.connection import Connection

# TODO: Add dependency on JWT token
connection_router = AirflowRouter(
prefix="/connection",
tags=["Connection"],
router = AirflowRouter(
responses={status.HTTP_404_NOT_FOUND: {"description": "Connection not found"}},
)

Expand All @@ -42,7 +40,7 @@ def get_task_token() -> datamodels.TIToken:
return datamodels.TIToken(ti_key="test_key")


@connection_router.get(
@router.get(
"/{connection_id}",
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/execution_api/routes/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

from airflow.api_fastapi.common.router import AirflowRouter

health_router = AirflowRouter(tags=["Health"])
router = AirflowRouter()


@health_router.get("/health")
@router.get("/health")
def health() -> dict:
return {"status": "healthy"}
9 changes: 3 additions & 6 deletions airflow/api_fastapi/execution_api/routes/task_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,13 @@
from airflow.utils.state import State

# TODO: Add dependency on JWT token
ti_router = AirflowRouter(
prefix="/task_instance",
tags=["Task Instance"],
)
router = AirflowRouter()


log = logging.getLogger(__name__)


@ti_router.patch(
@router.patch(
"/{task_instance_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
# TODO: Add description to the operation
Expand Down Expand Up @@ -133,7 +130,7 @@ def ti_update_state(
)


@ti_router.put(
@router.put(
"/{task_instance_id}/heartbeat",
status_code=status.HTTP_204_NO_CONTENT,
responses={
Expand Down
87 changes: 87 additions & 0 deletions airflow/api_fastapi/execution_api/routes/variables.py
kaxil marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging

from fastapi import Depends, HTTPException, status
from typing_extensions import Annotated

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels
from airflow.models.variable import Variable

# TODO: Add dependency on JWT token
router = AirflowRouter(
responses={status.HTTP_404_NOT_FOUND: {"description": "Variable not found"}},
)

log = logging.getLogger(__name__)


def get_task_token() -> datamodels.TIToken:
"""TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation."""
return datamodels.TIToken(ti_key="test_key")


@router.get(
"/{variable_key}",
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def get_variable(
variable_key: str,
token: Annotated[datamodels.TIToken, Depends(get_task_token)],
) -> datamodels.VariableResponse:
"""Get an Airflow Variable."""
if not has_variable_access(variable_key, token):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"reason": "access_denied",
"message": f"Task does not have access to variable {variable_key}",
},
)

try:
variable_value = Variable.get(variable_key)
except KeyError:
kaxil marked this conversation as resolved.
Show resolved Hide resolved
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"Variable with key '{variable_key}' not found",
},
)

return datamodels.VariableResponse(key=variable_key, value=variable_value)


def has_variable_access(variable_key: str, token: datamodels.TIToken) -> bool:
"""Check if the task has access to the variable."""
# TODO: Placeholder for actual implementation

ti_key = token.ti_key
log.debug(
"Checking access for task instance with key '%s' to variable '%s'",
ti_key,
variable_key,
)
return True
kaxil marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_connection_get_from_db(self, client, session):
session.add(connection)
session.commit()

response = client.get("/execution/connection/test_conn")
response = client.get("/execution/connections/test_conn")

assert response.status_code == 200
assert response.json() == {
Expand All @@ -66,7 +66,7 @@ def test_connection_get_from_db(self, client, session):
{"AIRFLOW_CONN_TEST_CONN2": '{"uri": "http://root:admin@localhost:8080/https?headers=header"}'},
)
def test_connection_get_from_env_var(self, client, session):
response = client.get("/execution/connection/test_conn2")
response = client.get("/execution/connections/test_conn2")

assert response.status_code == 200
assert response.json() == {
Expand All @@ -81,7 +81,7 @@ def test_connection_get_from_env_var(self, client, session):
}

def test_connection_get_not_found(self, client):
response = client.get("/execution/connection/non_existent_test_conn")
response = client.get("/execution/connections/non_existent_test_conn")

assert response.status_code == 404
assert response.json() == {
Expand All @@ -95,7 +95,7 @@ def test_connection_get_access_denied(self, client):
with mock.patch(
"airflow.api_fastapi.execution_api.routes.connections.has_connection_access", return_value=False
):
response = client.get("/execution/connection/test_conn")
response = client.get("/execution/connections/test_conn")

# Assert response status code and detail for access denied
assert response.status_code == 403
Expand Down
77 changes: 77 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from unittest import mock

import pytest

from airflow.models.variable import Variable

pytestmark = pytest.mark.db_test


class TestGetVariable:
def test_variable_get_from_db(self, client, session):
Variable.set(key="var1", value="value", session=session)
session.commit()

response = client.get("/execution/variables/var1")

assert response.status_code == 200
assert response.json() == {"key": "var1", "value": "value"}

# Remove connection
Variable.delete(key="var1", session=session)
session.commit()

@mock.patch.dict(
"os.environ",
{"AIRFLOW_VAR_KEY1": "VALUE"},
)
def test_variable_get_from_env_var(self, client, session):
response = client.get("/execution/variables/key1")

assert response.status_code == 200
assert response.json() == {"key": "key1", "value": "VALUE"}

def test_variable_get_not_found(self, client):
response = client.get("/execution/variables/non_existent_var")

assert response.status_code == 404
assert response.json() == {
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
}

def test_variable_get_access_denied(self, client):
with mock.patch(
"airflow.api_fastapi.execution_api.routes.variables.has_variable_access", return_value=False
):
response = client.get("/execution/variables/key1")

# Assert response status code and detail for access denied
assert response.status_code == 403
assert response.json() == {
"detail": {
"reason": "access_denied",
"message": "Task does not have access to variable key1",
}
}