Skip to content

Commit 150d727

Browse files
dabladavidblain-infrabel
authored andcommitted
Fixes compat issue HTTPX proxy configuration in KiotaRequestAdapterHook and fixed retry in MSGraphSensor (apache#45746)
--------- Co-authored-by: David Blain <[email protected]>
1 parent 560faea commit 150d727

File tree

5 files changed

+31
-13
lines changed

5 files changed

+31
-13
lines changed

providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import httpx
2929
from azure.identity import ClientSecretCredential
30-
from httpx import Timeout
30+
from httpx import AsyncHTTPTransport, Timeout
3131
from kiota_abstractions.api_error import APIError
3232
from kiota_abstractions.method import Method
3333
from kiota_abstractions.request_information import RequestInformation
@@ -208,9 +208,9 @@ def format_no_proxy_url(url: str) -> str:
208208
def to_httpx_proxies(cls, proxies: dict) -> dict:
209209
proxies = proxies.copy()
210210
if proxies.get("http"):
211-
proxies["http://"] = proxies.pop("http")
211+
proxies["http://"] = AsyncHTTPTransport(proxy=proxies.pop("http"))
212212
if proxies.get("https"):
213-
proxies["https://"] = proxies.pop("https")
213+
proxies["https://"] = AsyncHTTPTransport(proxy=proxies.pop("https"))
214214
if proxies.get("no"):
215215
for url in proxies.pop("no", "").split(","):
216216
proxies[cls.format_no_proxy_url(url.strip())] = None
@@ -288,7 +288,7 @@ def get_conn(self) -> RequestAdapter:
288288
http_client = GraphClientFactory.create_with_default_middleware(
289289
api_version=api_version, # type: ignore
290290
client=httpx.AsyncClient(
291-
proxy=httpx_proxies, # type: ignore
291+
mounts=httpx_proxies,
292292
timeout=Timeout(timeout=self.timeout),
293293
verify=verify,
294294
trust_env=trust_env,

providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py

+1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def execute(self, context: Context):
129129
def retry_execute(
130130
self,
131131
context: Context,
132+
**kwargs,
132133
) -> Any:
133134
self.execute(context=context)
134135

Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "Succeeded"}
1+
[{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "InProgress"},{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "Succeeded"}]

providers/tests/microsoft/azure/sensors/test_msgraph.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import json
20+
from datetime import datetime
2021

2122
import pytest
2223

@@ -31,7 +32,7 @@
3132
class TestMSGraphSensor(Base):
3233
def test_execute(self):
3334
status = load_json("resources", "status.json")
34-
response = mock_json_response(200, status)
35+
response = mock_json_response(200, *status)
3536

3637
with self.patch_hook_and_request_adapter(response):
3738
sensor = MSGraphSensor(
@@ -40,6 +41,7 @@ def test_execute(self):
4041
url="myorg/admin/workspaces/scanStatus/{scanId}",
4142
path_parameters={"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"},
4243
result_processor=lambda context, result: result["id"],
44+
retry_delay=5,
4345
timeout=350.0,
4446
)
4547

@@ -48,16 +50,22 @@ def test_execute(self):
4850
assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}
4951
assert isinstance(results, str)
5052
assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
51-
assert len(events) == 1
53+
assert len(events) == 3
5254
assert isinstance(events[0], TriggerEvent)
5355
assert events[0].payload["status"] == "success"
5456
assert events[0].payload["type"] == "builtins.dict"
55-
assert events[0].payload["response"] == json.dumps(status)
57+
assert events[0].payload["response"] == json.dumps(status[0])
58+
assert isinstance(events[1], TriggerEvent)
59+
assert isinstance(events[1].payload, datetime)
60+
assert isinstance(events[2], TriggerEvent)
61+
assert events[2].payload["status"] == "success"
62+
assert events[2].payload["type"] == "builtins.dict"
63+
assert events[2].payload["response"] == json.dumps(status[1])
5664

5765
@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters works in Airflow >= 2.10.0")
5866
def test_execute_with_lambda_parameter(self):
5967
status = load_json("resources", "status.json")
60-
response = mock_json_response(200, status)
68+
response = mock_json_response(200, *status)
6169

6270
with self.patch_hook_and_request_adapter(response):
6371
sensor = MSGraphSensor(
@@ -66,6 +74,7 @@ def test_execute_with_lambda_parameter(self):
6674
url="myorg/admin/workspaces/scanStatus/{scanId}",
6775
path_parameters=lambda context, jinja_env: {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"},
6876
result_processor=lambda context, result: result["id"],
77+
retry_delay=5,
6978
timeout=350.0,
7079
)
7180

@@ -74,11 +83,17 @@ def test_execute_with_lambda_parameter(self):
7483
assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}
7584
assert isinstance(results, str)
7685
assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
77-
assert len(events) == 1
86+
assert len(events) == 3
7887
assert isinstance(events[0], TriggerEvent)
7988
assert events[0].payload["status"] == "success"
8089
assert events[0].payload["type"] == "builtins.dict"
81-
assert events[0].payload["response"] == json.dumps(status)
90+
assert events[0].payload["response"] == json.dumps(status[0])
91+
assert isinstance(events[1], TriggerEvent)
92+
assert isinstance(events[1].payload, datetime)
93+
assert isinstance(events[2], TriggerEvent)
94+
assert events[2].payload["status"] == "success"
95+
assert events[2].payload["type"] == "builtins.dict"
96+
assert events[2].payload["response"] == json.dumps(status[1])
8297

8398
def test_template_fields(self):
8499
sensor = MSGraphSensor(

providers/tests/microsoft/conftest.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,10 @@ def xcom_pull(
149149
run_id: str | None = None,
150150
) -> Any:
151151
if map_indexes:
152-
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}")
153-
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")
152+
return values.get(
153+
f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default
154+
)
155+
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default)
154156

155157
def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
156158
values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value

0 commit comments

Comments
 (0)