11# copyright (c) 2020, Matthias Dellweg
22# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt)
33
4+ import asyncio
45import base64
56import datetime
67import json
78import os
9+ import ssl
810import typing as t
911from collections import defaultdict
1012from contextlib import suppress
1113from io import BufferedReader
1214from urllib .parse import urljoin
1315
16+ import aiohttp
1417import requests
1518import urllib3
1619
@@ -174,6 +177,9 @@ def __init__(
174177 self ._safe_calls_only : bool = safe_calls_only
175178 self ._headers = headers or {}
176179 self ._verify = verify
180+ # Shall we make that a parameter?
181+ self ._ssl_context : t .Optional [t .Union [ssl .SSLContext , bool ]] = None
182+
177183 self ._auth_provider = auth_provider
178184 self ._cert = cert
179185 self ._key = key
@@ -225,6 +231,22 @@ def base_url(self) -> str:
225231 def cid (self ) -> t .Optional [str ]:
226232 return self ._headers .get ("Correlation-Id" )
227233
234+ @property
235+ def ssl_context (self ) -> t .Union [ssl .SSLContext , bool ]:
236+ if self ._ssl_context is None :
237+ if self ._verify is False :
238+ self ._ssl_context = False
239+ else :
240+ if isinstance (self ._verify , str ):
241+ self ._ssl_context = ssl .create_default_context (cafile = self ._verify )
242+ else :
243+ self ._ssl_context = ssl .create_default_context ()
244+ if self ._cert is not None :
245+ self ._ssl_context .load_cert_chain (self ._cert , self ._key )
246+ # Type inference is failing here.
247+ self ._ssl_context = t .cast (t .Union [ssl .SSLContext , bool ], self ._ssl_context )
248+ return self ._ssl_context
249+
228250 def load_api (self , refresh_cache : bool = False ) -> None :
229251 # TODO: Find a way to invalidate caches on upstream change
230252 xdg_cache_home : str = os .environ .get ("XDG_CACHE_HOME" ) or "~/.cache"
@@ -242,7 +264,7 @@ def load_api(self, refresh_cache: bool = False) -> None:
242264 self ._parse_api (data )
243265 except Exception :
244266 # Try again with a freshly downloaded version
245- data = self ._download_api ()
267+ data = asyncio . run ( self ._download_api () )
246268 self ._parse_api (data )
247269 # Write to cache as it seems to be valid
248270 os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
@@ -262,28 +284,31 @@ def _parse_api(self, data: bytes) -> None:
262284 if method in {"get" , "put" , "post" , "delete" , "options" , "head" , "patch" , "trace" }
263285 }
264286
265- def _download_api (self ) -> bytes :
287+ async def _download_api (self ) -> bytes :
266288 try :
267- response : requests .Response = self ._session .get (urljoin (self ._base_url , self ._doc_path ))
268- except requests .RequestException as e :
289+ connector = aiohttp .TCPConnector (ssl = self .ssl_context )
290+ async with aiohttp .ClientSession (connector = connector , headers = self ._headers ) as session :
291+ async with session .get (urljoin (self ._base_url , self ._doc_path )) as response :
292+ response .raise_for_status ()
293+ data = await response .read ()
294+ if "Correlation-Id" in response .headers :
295+ self ._set_correlation_id (response .headers ["Correlation-Id" ])
296+ except aiohttp .ClientError as e :
269297 raise OpenAPIError (str (e ))
270- response .raise_for_status ()
271- if "Correlation-ID" in response .headers :
272- self ._set_correlation_id (response .headers ["Correlation-ID" ])
273- return response .content
298+ return data
274299
275300 def _set_correlation_id (self , correlation_id : str ) -> None :
276- if "Correlation-ID " in self ._headers :
277- if self ._headers ["Correlation-ID " ] != correlation_id :
301+ if "Correlation-Id " in self ._headers :
302+ if self ._headers ["Correlation-Id " ] != correlation_id :
278303 raise OpenAPIError (
279304 _ ("Correlation ID returned from server did not match. {} != {}" ).format (
280- self ._headers ["Correlation-ID " ], correlation_id
305+ self ._headers ["Correlation-Id " ], correlation_id
281306 )
282307 )
283308 else :
284- self ._headers ["Correlation-ID " ] = correlation_id
309+ self ._headers ["Correlation-Id " ] = correlation_id
285310 # Do it for requests too...
286- self ._session .headers ["Correlation-ID " ] = correlation_id
311+ self ._session .headers ["Correlation-Id " ] = correlation_id
287312
288313 def param_spec (
289314 self , operation_id : str , param_type : str , required : bool = False
@@ -802,7 +827,7 @@ def call(
802827 self ._debug_callback (2 , f" { key } : { value } " )
803828 if response .text :
804829 self ._debug_callback (3 , f"{ response .text } " )
805- if "Correlation-ID " in response .headers :
806- self ._set_correlation_id (response .headers ["Correlation-ID " ])
830+ if "Correlation-Id " in response .headers :
831+ self ._set_correlation_id (response .headers ["Correlation-Id " ])
807832 response .raise_for_status ()
808833 return self .parse_response (method_spec , response )
0 commit comments