diff --git a/.gitignore b/.gitignore index 02268f7..45507c1 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ __pycache__* /.idea/ async_openai/v1* tests/private_* -!tests/ \ No newline at end of file +!tests/ +tests/v2/private_* \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e5d4b7..93005e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelogs +#### v0.0.50 (2024-02-01) + +**Breaking Changes** + +- The `OpenAI` client has been refactored to be a singleton `ProxyObject` vs a `Type` object. + + Currently, this API is accessible with `async_openai.OpenAIManager`, which provides all the existing functionality of the `OpenAI` client, with a few additional features. + + - `OpenAIManager` supports automatic proxy rotation and client selection based on available models. + + - `OpenAIManager` supports automatic retrying of failed requests, as well as enabling automatic healthchecking prior to each request to ensure the endpoint is available with `auto_healthcheck_enabled`, otherwise it will rotate to another endpoint. This is useful for ensuring high availability and reliability of the API. + + Future versions will deprecate the `OpenAI` client in favor of the `OpenAIManager` object. + +- Added new `OpenAIFunctions` class which provides a robust interface for creating and running functions. This class is also a singleton `ProxyObject`. + + This can be accessed through the `OpenAIManager.functions` object + #### v0.0.41 (2023-11-06) diff --git a/async_openai/__init__.py b/async_openai/__init__.py index c89a117..df8df99 100644 --- a/async_openai/__init__.py +++ b/async_openai/__init__.py @@ -45,7 +45,7 @@ from async_openai.routes import ApiRoutes -from async_openai.client import OpenAIClient, OpenAI +from async_openai.client import OpenAIClient, OpenAI, OpenAIManager diff --git a/async_openai/client.py b/async_openai/client.py index b61e9cc..c3b0931 100644 --- a/async_openai/client.py +++ b/async_openai/client.py @@ -1,13 +1,14 @@ import aiohttpx +import contextlib from typing import Optional, Callable, Dict, Union, List from async_openai.schemas import * from async_openai.types.options import ApiType from async_openai.utils.logs import logger -from async_openai.utils.config import get_settings, OpenAISettings, AzureOpenAISettings, OpenAIAuth +from async_openai.utils.config import get_settings, OpenAISettings, AzureOpenAISettings, OpenAIAuth, ProxyObject from async_openai.routes import ApiRoutes from async_openai.meta import OpenAIMetaClass - +from async_openai.manager import OpenAIManager as OpenAISessionManager _update_params = [ 'url', @@ -72,8 +73,42 @@ def __init__( """ Lazily Instantiates the OpenAI Client """ + self.model_rate_limits: Dict[str, Dict[str, int]] = {} + self.client_callbacks: List[Callable] = [] self.configure_params(**kwargs) + def response_event_hook(self, response: aiohttpx.Response): + """ + Monitor the rate limits + """ + url = response.url + headers = response.headers + with contextlib.suppress(Exception): + if self.is_azure: + model_name = str(url).split('deployments/', 1)[-1].split('/', 1)[0].strip() + else: + model_name = headers.get('openai-model') + model_name = model_name.lstrip("https:").strip() + if not model_name: return + if model_name not in self.model_rate_limits: + self.model_rate_limits[model_name] = {} + for key, value in { + 'x-ratelimit-remaining-requests': 'remaining', + 'x-ratelimit-remaining-tokens': 'remaining_tokens', + 'x-ratelimit-limit-tokens': 'limit_tokens', + 
'x-ratelimit-limit-requests': 'limit_requests', + }.items(): + if key in headers: + self.model_rate_limits[model_name][value] = int(headers[key]) + if self.debug_enabled: + logger.info(f"Rate Limits: {self.model_rate_limits}") + + async def aresponse_event_hook(self, response: aiohttpx.Response): + """ + Monitor the rate limits + """ + return self.response_event_hook(response) + @property def client(self) -> aiohttpx.Client: """ @@ -121,6 +156,7 @@ def configure_params( is_azure: Optional[bool] = None, azure_model_mapping: Optional[Dict[str, str]] = None, auth: Optional[OpenAIAuth] = None, + client_callbacks: Optional[List[Callable]] = None, **kwargs ): # sourcery skip: low-code-quality """ @@ -233,6 +269,9 @@ def configure_params( self.log_method = logger.info if self.debug_enabled else logger.debug if not self.debug_enabled: self.settings.disable_httpx_logger() + + if client_callbacks is not None: + self.client_callbacks = client_callbacks # if self.debug_enabled: # logger.info(f"OpenAI Client Configured: {self.client.base_url}") # logger.debug(f"Debug Enabled: {self.debug_enabled}") @@ -243,13 +282,18 @@ def configure_client(self, **kwargs): """ if self._client is not None: return # logger.info(f"OpenAI Client Configured: {self.base_url} [{self.name}]") + extra_kwargs = {} + if self.settings.limit_monitor_enabled: + extra_kwargs['event_hooks'] = {'response': [self.response_event_hook]} + extra_kwargs['async_event_hooks'] = {'response': [self.aresponse_event_hook]} + self._client = aiohttpx.Client( base_url = self.base_url, timeout = self.timeout, limits = self.settings.api_client_limits, auth = self.auth, headers = self.headers, - # auth = self.settings. + **extra_kwargs, ) def configure_routes(self, **kwargs): @@ -273,6 +317,7 @@ def configure_routes(self, **kwargs): azure_model_mapping = self.azure_model_mapping, disable_retries = self.disable_retries, retry_function = self.retry_function, + client_callbacks = self.client_callbacks, **kwargs ) if self.debug_enabled: @@ -369,11 +414,42 @@ async def __aexit__(self, exc_type, exc_value, traceback): await self.async_close() + def ping(self, timeout: Optional[float] = 1.0) -> bool: + """ + Pings the API Endpoint to check if it's alive. + """ + try: + # with contextlib.suppress(Exception): + response = self.client.get('/', timeout = timeout) + data = response.json() + # we should expect a 404 with a json response + # if self.debug_enabled: logger.info(f"API Ping: {data}\n{response.headers}") + if data.get('error'): return True + except Exception as e: + logger.error(f"API Ping Failed: {e}") + return False + + async def aping(self, timeout: Optional[float] = 1.0) -> bool: + """ + Pings the API Endpoint to check if it's alive. 
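+ The endpoint is considered alive when it returns a JSON body containing
+ an `error` key (typically alongside a 404 for the bare `/` path), mirroring
+ the synchronous `ping` above.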
+ """ + with contextlib.suppress(Exception): + response = await self.client.async_get('/', timeout = timeout) + data = response.json() + # we should expect a 404 with a json response + if data.get('error'): return True + return False + + class OpenAI(metaclass = OpenAIMetaClass): """ - Interface for OpenAI + [V1] Interface for OpenAI + + Deprecating this class in future versions """ pass +OpenAIManager: OpenAISessionManager = ProxyObject(OpenAISessionManager) + diff --git a/async_openai/loadbalancer.py b/async_openai/loadbalancer.py new file mode 100644 index 0000000..f27f7fe --- /dev/null +++ b/async_openai/loadbalancer.py @@ -0,0 +1,248 @@ +""" +Client LoadBalancer +""" + +from __future__ import annotations + +import random +from typing import Optional, List, Dict, Union, TYPE_CHECKING + +from async_openai.schemas import * +from async_openai.utils.config import get_settings, OpenAISettings +from async_openai.utils.logs import logger + +if TYPE_CHECKING: + from async_openai.client import OpenAIClient, OpenAISessionManager + + +class ClientLoadBalancer: + """ + Manages a set of clients that can be rotated. + """ + def __init__( + self, + prioritize: Optional[str] = None, + settings: Optional[OpenAISettings] = None, + azure_model_mapping: Optional[Dict[str, str]] = None, + healthcheck: Optional[bool] = True, + manager: Optional['OpenAISessionManager'] = None, + ): + self.settings = settings or get_settings() + self.clients: Dict[str, 'OpenAIClient'] = {} + self.rotate_index: int = 0 + self.rotate_client_names: List[str] = [] + self.azure_model_mapping: Dict[str, str] = azure_model_mapping + self.healthcheck: bool = healthcheck + self.manager: Optional['OpenAISessionManager'] = manager + + assert prioritize in [None, 'azure', 'openai'], f'Invalid `prioritize` value: {prioritize}' + self.prioritize: Optional[str] = prioritize + + @property + def client_names(self) -> List[str]: + """ + Returns the list of client names. + """ + return list(self.clients.keys()) + + def run_client_init(self): + """ + Initializes the Client. + + Can be subclassed to provide custom initialization. + """ + self.init_api_client() + if self.settings.has_valid_azure: + self.init_api_client(client_name = 'az', is_azure = True, set_as_default = self.prioritize == 'azure', set_as_current = self.prioritize == 'azure') + + + @property + def api(self) -> 'OpenAIClient': + """ + Returns the inherited OpenAI client. 
+ """ + if not self.clients: + self.run_client_init() + if not self.rotate_client_names or self.rotate_index < len(self.client_names): + return self.clients[self.client_names[self.rotate_index]] + try: + return self.clients[self.rotate_client_names[self.rotate_index]] + except IndexError as e: + logger.error(f'Index Error: {self.rotate_index} - {self.rotate_client_names}') + raise IndexError(f'Index Error: {self.rotate_index} - {self.rotate_client_names} - {self.client_names} ({len(self.clients)})') from e + + def increase_rotate_index(self): + """ + Increases the rotate index + """ + if self.rotate_index >= len(self.clients) - 1: + self.rotate_index = 0 + else: + self.rotate_index += 1 + + def rotate_client(self, index: Optional[int] = None, require_azure: Optional[bool] = None, verbose: Optional[bool] = False): + """ + Rotates the clients + """ + if index is not None: + self.rotate_index = index + return + self.increase_rotate_index() + if require_azure: + while not self.api.is_azure: + self.increase_rotate_index() + if verbose: + logger.info(f'Rotated Client: {self.api.name} (Azure: {self.api.is_azure} - {self.api.api_version}) [{self.rotate_index+1}/{len(self.clients)}]') + + def set_client(self, client_name: Optional[str] = None, verbose: Optional[bool] = False): + """ + Sets the client + """ + if client_name is None: + raise ValueError('`client_name` is required.') + if client_name not in self.clients: + raise ValueError(f'Client `{client_name}` does not exist.') + self.rotate_index = self.client_names.index(client_name) + if verbose: + logger.info(f'Set Client: {self.api.name} (Azure: {self.api.is_azure} - {self.api.api_version})) [{self.rotate_index+1}/{len(self.clients)}]') + + def current_client_info(self, verbose: Optional[bool] = False) -> Dict[str, Union[str, int]]: + """ + Returns the current client info + """ + data = { + 'name': self.api.name, + 'is_azure': self.api.is_azure, + 'api_version': self.api.api_version, + 'index': self.rotate_index, + 'total': len(self.clients), + } + if verbose: + logger.info(f'Current Client: {self.api.name} (Azure: {self.api.is_azure} - {self.api.api_version}) [{self.rotate_index+1}/{len(self.clients)}]') + return data + + + def configure_client(self, client_name: Optional[str] = None, priority: Optional[int] = None, **kwargs): + """ + Configure a new client + """ + client_name = client_name or 'default' + if client_name not in self.clients: + raise ValueError(f'Client `{client_name}` does not exist.') + self.clients[client_name].reset(**kwargs) + if priority is not None: + if client_name in self.rotate_client_names: + self.rotate_client_names.remove(client_name) + self.rotate_client_names.insert(priority, client_name) + + def init_api_client( + self, + client_name: Optional[str] = None, + set_as_default: Optional[bool] = False, + is_azure: Optional[bool] = None, + priority: Optional[int] = None, + set_as_current: Optional[bool] = False, + **kwargs + ) -> 'OpenAIClient': + """ + Creates a new OpenAI client. 
+ """ + client_name = client_name or 'default' + if client_name in self.clients: + return self.clients[client_name] + + from async_openai.client import OpenAIClient + if is_azure is None and \ + ( + 'az' in client_name and self.settings.has_valid_azure + ): + is_azure = True + if 'client_callbacks' not in kwargs and \ + self.manager and \ + self.manager.client_callbacks: + kwargs['client_callbacks'] = self.manager.client_callbacks + client = OpenAIClient( + name = client_name, + settings = self.settings, + is_azure = is_azure, + azure_model_mapping = self.azure_model_mapping, + **kwargs + ) + self.clients[client_name] = client + if set_as_default: + self.rotate_client_names.insert(0, client_name) + elif priority is not None: + if client_name in self.rotate_client_names: + self.rotate_client_names.remove(client_name) + self.rotate_client_names.insert(priority, client_name) + elif self.prioritize: + if ( + self.prioritize == 'azure' + and is_azure + or self.prioritize != 'azure' + and self.prioritize == 'openai' + and not is_azure + ): + self.rotate_client_names.insert(0, client_name) + elif self.prioritize in ['azure', 'openai']: + self.rotate_client_names.append(client_name) + if set_as_current: + self.rotate_index = self.rotate_client_names.index(client_name) + return client + + def get_api_client(self, client_name: Optional[str] = None, require_azure: Optional[bool] = None, **kwargs) -> 'OpenAIClient': + """ + Initializes a new OpenAI client or Returns an existing one. + """ + if not client_name and not self.clients: + client_name = 'default' + if client_name and client_name not in self.clients: + self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) + + if not client_name and require_azure: + while not self.api.is_azure: + self.increase_rotate_index() + return self.api + return self.clients[client_name] if client_name else self.api + + def get_api_client_from_list(self, client_names: List[str], require_azure: Optional[bool] = None, **kwargs) -> 'OpenAIClient': + """ + Initializes a new OpenAI client or Returns an existing one from a list of client names. + """ + if not self.healthcheck: + name = random.choice(client_names) + return self.get_api_client(client_name = name, require_azure = require_azure, **kwargs) + for client_name in client_names: + if client_name not in self.clients: + self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) + if require_azure and not self.clients[client_name].is_azure: + continue + if not self.clients[client_name].ping(): + continue + return self.clients[client_name] + raise ValueError(f'No healthy client found from: {client_names}') + + async def aget_api_client_from_list(self, client_names: List[str], require_azure: Optional[bool] = None, **kwargs) -> 'OpenAIClient': + """ + Initializes a new OpenAI client or Returns an existing one from a list of client names. 
+ """ + if not self.healthcheck: + name = random.choice(client_names) + return self.get_api_client(client_name = name, require_azure = require_azure, **kwargs) + for client_name in client_names: + if client_name not in self.clients: + self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) + if require_azure and not self.clients[client_name].is_azure: + continue + if not await self.clients[client_name].aping(): + continue + return self.clients[client_name] + raise ValueError(f'No healthy client found from: {client_names}') + + def __getitem__(self, key: Union[str, int]) -> 'OpenAIClient': + """ + Returns a client by name. + """ + if isinstance(key, int): + key = self.rotate_client_names[key] if self.rotate_client_names else self.client_names[key] + return self.clients[key] \ No newline at end of file diff --git a/async_openai/manager.py b/async_openai/manager.py new file mode 100644 index 0000000..dcbf714 --- /dev/null +++ b/async_openai/manager.py @@ -0,0 +1,1312 @@ +from __future__ import annotations + +""" +OpenAI Session Manager +""" + +import abc +import copy +import pathlib +import random +from typing import Optional, List, Callable, Dict, Union, Any, overload, TYPE_CHECKING + +from async_openai.schemas import * +from async_openai.types.options import ApiType +from async_openai.types.context import ModelContextHandler +from async_openai.utils.config import get_settings, OpenAISettings +from async_openai.types.functions import FunctionManager, OpenAIFunctions +from async_openai.utils.logs import logger + +from .loadbalancer import ClientLoadBalancer + +if TYPE_CHECKING: + from async_openai.client import OpenAIClient + from lazyops.libs.pooler import ThreadPool + + + +# Model Mapping for Azure +DefaultModelMapping = { + 'gpt-3.5-turbo': 'gpt-35-turbo', + 'gpt-3.5-turbo-16k': 'gpt-35-turbo-16k', + 'gpt-3.5-turbo-instruct': 'gpt-35-turbo-instruct', + 'gpt-3.5-turbo-0301': 'gpt-35-turbo-0301', + 'gpt-3.5-turbo-0613': 'gpt-35-turbo-0613', + 'gpt-3.5-turbo-1106': 'gpt-35-turbo-1106', +} + +class OpenAIManager(abc.ABC): + name: Optional[str] = "openai" + on_error: Optional[Callable] = None + prioritize: Optional[str] = None + auto_healthcheck: Optional[bool] = None + auto_loadbalance_clients: Optional[bool] = None + azure_model_mapping: Optional[Dict[str, str]] = DefaultModelMapping + + _api: Optional['OpenAIClient'] = None + _apis: Optional['ClientLoadBalancer'] = None + _clients: Optional[Dict[str, 'OpenAIClient']] = {} + _settings: Optional[OpenAISettings] = None + + _pooler: Optional['ThreadPool'] = None + + """ + The Global Session Manager for OpenAI API. + """ + + def __init__(self, **kwargs): + """ + Initializes the OpenAI API Client + """ + self.client_model_exclusions: Optional[Dict[str, Dict[str, Union[bool, List[str]]]]] = {} + self.no_proxy_client_names: Optional[List[str]] = [] + self.client_callbacks: Optional[List[Callable]] = [] + self.functions: FunctionManager = OpenAIFunctions + if self.auto_loadbalance_clients is None: self.auto_loadbalance_clients = self.settings.auto_loadbalance_clients + if self.auto_healthcheck is None: self.auto_healthcheck = self.settings.auto_healthcheck + + def add_callback(self, callback: Callable): + """ + Adds a callback to the client + """ + self.client_callbacks.append(callback) + + @property + def settings(self) -> OpenAISettings: + """ + Returns the global settings for the OpenAI API. 
+ """ + if self._settings is None: + self._settings = get_settings() + return self._settings + + # Changing the behavior to become proxied through settings + + @property + def api_key(self) -> Optional[str]: + """ + Returns the global API Key. + """ + return self.settings.api_key + + @property + def url(self) -> Optional[str]: + """ + Returns the global URL. + """ + return self.settings.url + + @property + def scheme(self) -> Optional[str]: + """ + Returns the global Scheme. + """ + return self.settings.scheme + + @property + def host(self) -> Optional[str]: + """ + Returns the global Host. + """ + return self.settings.host + + @property + def port(self) -> Optional[int]: + """ + Returns the global Port. + """ + return self.settings.port + + @property + def api_base(self) -> Optional[str]: + """ + Returns the global API Base. + """ + return self.settings.api_base + + @property + def api_path(self) -> Optional[str]: + """ + Returns the global API Path. + """ + return self.settings.api_path + + @property + def api_type(self) -> Optional[ApiType]: + """ + Returns the global API Type. + """ + return self.settings.api_type + + @property + def api_version(self) -> Optional[str]: + """ + Returns the global API Version. + """ + return self.settings.api_version + + @property + def api_key_path(self) -> Optional[pathlib.Path]: + """ + Returns the global API Key Path. + """ + return self.settings.api_key_path + + @property + def organization(self) -> Optional[str]: + """ + Returns the global Organization. + """ + return self.settings.organization + + @property + def proxies(self) -> Optional[Union[str, Dict]]: + """ + Returns the global Proxies. + """ + return self.settings.proxies + + @property + def timeout(self) -> Optional[int]: + """ + Returns the global Timeout. + """ + return self.settings.timeout + + @property + def max_retries(self) -> Optional[int]: + """ + Returns the global Max Retries. + """ + return self.settings.max_retries + + @property + def app_info(self) -> Optional[Dict[str, str]]: + """ + Returns the global App Info. + """ + return self.settings.app_info + + @property + def debug_enabled(self) -> Optional[bool]: + """ + Returns the global Debug Enabled. + """ + return self.settings.debug_enabled + + @property + def ignore_errors(self) -> Optional[bool]: + """ + Returns the global Ignore Errors. + """ + return self.settings.ignore_errors + + @property + def timeout(self) -> Optional[int]: + """ + Returns the global Timeout. + """ + return self.settings.timeout + + @property + def pooler(self) -> Optional['ThreadPool']: + """ + Returns the global ThreadPool. + """ + if self._pooler is None: + from lazyops.libs.pooler import ThreadPooler + self._pooler = ThreadPooler + return self._pooler + + + def configure_client( + self, + client_name: Optional[str] = None, + **kwargs, + ): + """ + Configure a specific client. + """ + if self.auto_loadbalance_clients: + return self.apis.configure_client(client_name = client_name, **kwargs) + client_name = client_name or 'default' + if client_name not in self._clients: + raise ValueError(f'Client `{client_name}` does not exist.') + self._clients[client_name].reset(**kwargs) + + def get_api_client( + self, + client_name: Optional[str] = None, + **kwargs, + ) -> 'OpenAIClient': + """ + Initializes a new OpenAI client or Returns an existing one. 
+ """ + if self.auto_loadbalance_clients: + return self.apis.get_api_client(client_name = client_name, **kwargs) + client_name = client_name or 'default' + if client_name not in self._clients: + self._clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) + return self._clients[client_name] + + + def get_api_client_from_list( + self, + client_names: Optional[List[str]] = None, + **kwargs, + ) -> 'OpenAIClient': + """ + Initializes a new OpenAI client or Returns an existing one. + """ + if self.auto_loadbalance_clients: + if not client_names: return self.apis.get_api_client(**kwargs) + return self.apis.get_api_client_from_list(client_names = client_names, **kwargs) + if not client_names: return self.get_api_client(**kwargs) + if not self.auto_healthcheck: + name = random.choice(client_names) + return self.get_api_client(client_name = name, **kwargs) + + for client_name in client_names: + if client_name not in self._clients: + self._clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) + if not self._clients[client_name].ping(): + continue + return self._clients[client_name] + raise ValueError(f'No healthy client found from: {client_names}') + + async def aget_api_client_from_list( + self, + client_names: Optional[List[str]] = None, + **kwargs, + ) -> 'OpenAIClient': + """ + Initializes a new OpenAI client or Returns an existing one. + """ + if self.auto_loadbalance_clients: + if not client_names: return self.apis.get_api_client(**kwargs) + return await self.apis.aget_api_client_from_list(client_name = client_name, **kwargs) + if not client_names: return self.get_api_client(**kwargs) + if not self.auto_healthcheck: + name = random.choice(client_names) + return self.get_api_client(client_name = name, **kwargs) + + for client_name in client_names: + if client_name not in self._clients: + self._clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) + if not await self._clients[client_name].aping(): + continue + return self._clients[client_name] + raise ValueError(f'No healthy client found from: {client_names}') + + + def init_api_client( + self, + client_name: Optional[str] = None, + set_as_default: Optional[bool] = False, + is_azure: Optional[bool] = None, + **kwargs + ) -> 'OpenAIClient': + """ + Creates a new OpenAI client. 
+ """ + if self.auto_loadbalance_clients: + return self.apis.init_api_client(client_name = client_name, set_as_default = set_as_default, is_azure = is_azure, **kwargs) + client_name = client_name or 'default' + if client_name in self._clients: + return self._clients[client_name] + + from async_openai.client import OpenAIClient + if is_azure is None and \ + ( + # (client_name == 'default' or 'az' in client_name) and + 'az' in client_name and self.settings.has_valid_azure + ): + is_azure = True + if 'client_callbacks' not in kwargs and self.client_callbacks: + kwargs['client_callbacks'] = self.client_callbacks + client = OpenAIClient( + name = client_name, + settings = self.settings, + is_azure = is_azure, + azure_model_mapping = self.azure_model_mapping, + **kwargs + ) + self._clients[client_name] = client + if set_as_default or not self._api: + self._api = client + return client + + def rotate_client(self, index: Optional[int] = None, verbose: Optional[bool] = False, **kwargs): + """ + Rotates the clients + """ + if not self.auto_loadbalance_clients: + raise ValueError('Rotating Clients is not enabled.') + self.apis.rotate_client(index = index, verbose = verbose, **kwargs) + + def set_client(self, client_name: Optional[str] = None, verbose: Optional[bool] = False): + """ + Sets the client + """ + if self.auto_loadbalance_clients: + self.apis.set_client(client_name = client_name, verbose = verbose) + else: + self._api = self._clients[client_name] + if verbose: + logger.info(f'Set Client: {self.api.name} ({self.api.is_azure})') + + def get_current_client_info(self, verbose: Optional[bool] = False) -> Dict[str, Union[str, int]]: + """ + Returns the current client info + """ + if self.auto_loadbalance_clients: + return self.apis.current_client_info(verbose = verbose) + data = { + 'name': self.api.name, + 'is_azure': self.api.is_azure, + 'api_version': self.api.api_version, + } + if verbose: + logger.info(f'Current Client: {self.api.name} (Azure: {self.api.is_azure} - {self.api.api_version})') + return data + + + @property + def apis(self) -> ClientLoadBalancer: + """ + Returns the global Rotating Clients. + """ + if self._apis is None: + self._apis = ClientLoadBalancer( + prioritize=self.prioritize, + settings=self.settings, + azure_model_mapping=self.azure_model_mapping, + healthcheck=self.auto_healthcheck, + manager = self, + ) + if self.settings.client_configurations: + self.register_client_endpoints() + else: + self.register_default_endpoints() + return self._apis + + @property + def api(self) -> 'OpenAIClient': + """ + Returns the inherited OpenAI client. + """ + if self.auto_loadbalance_clients: return self.apis.api + if self._api is None: + self.init_api_client() + return self._api + + def configure_internal_apis(self): + """ + Helper method to ensure that the APIs are initialized + """ + if self._api is not None: return + # Invoke it to ensure that it is initialized + if self.auto_loadbalance_clients: + self.apis + else: + self.init_api_client() + + def _ensure_api(self): + """ + Ensures that the API is initialized + """ + if self._api is None: self.configure_internal_apis() + + + + + """ + API Routes + """ + + @property + def completions(self) -> CompletionRoute: + """ + Returns the `CompletionRoute` class for interacting with `Completions`. + + Doc: `https://beta.openai.com/docs/api-reference/completions` + """ + return self.api.completions + + @property + def Completions(self) -> CompletionRoute: + """ + Returns the `CompletionRoute` class for interacting with `Completions`. 
+ + Doc: `https://beta.openai.com/docs/api-reference/completions` + """ + return self.api.completions + + + @property + def chat(self) -> ChatRoute: + """ + Returns the `ChatRoute` class for interacting with `Chat`. + + Doc: `https://beta.openai.com/docs/api-reference/chat` + """ + return self.api.chat + + @property + def Chat(self) -> ChatRoute: + """ + Returns the `ChatRoute` class for interacting with `Chat`. + + Doc: `https://beta.openai.com/docs/api-reference/chat` + """ + return self.api.chat + + @property + def edits(self) -> EditRoute: + """ + Returns the `EditRoute` class for interacting with `Edits`. + + Doc: `https://beta.openai.com/docs/api-reference/edits` + """ + return self.api.edits + + @property + def embeddings(self) -> EmbeddingRoute: + """ + Returns the `EmbeddingRoute` class for interacting with `Embeddings`. + + Doc: `https://beta.openai.com/docs/api-reference/embeddings` + """ + return self.api.embeddings + + @property + def images(self) -> ImageRoute: + """ + Returns the `ImageRoute` class for interacting with `Images`. + + Doc: `https://beta.openai.com/docs/api-reference/images` + """ + return self.api.images + + @property + def models(self) -> ModelRoute: + """ + Returns the `ModelRoute` class for interacting with `models`. + + Doc: `https://beta.openai.com/docs/api-reference/models` + """ + return self.api.models + + + """ + V2 Endpoint Registration with Proxy Support + """ + + def register_default_endpoints(self): + """ + Register the default clients + """ + if self.settings.proxy.enabled: + api_base = self.settings.proxy.endpoint + az_custom_headers = { + "Helicone-OpenAI-Api-Base": self.settings.azure.api_base + } + self.configure( + api_base = api_base, + azure_api_base = api_base, + azure_custom_headers = az_custom_headers, + enable_rotating_clients = True, + prioritize = "azure", + ) + + self.init_api_client('openai', is_azure = False) + if self.settings.has_valid_azure: + self.init_api_client('azure', is_azure = True) + + + def register_client_endpoints(self): + """ + Register the Client Endpoints + """ + client_configs = copy.deepcopy(self.settings.client_configurations) + for name, config in client_configs.items(): + is_enabled = config.pop('enabled', False) + if not is_enabled: continue + is_azure = 'azure' in name or 'az' in name or config.get('is_azure', False) + is_default = config.pop('default', False) + proxy_disabled = config.pop('proxy_disabled', False) + source_endpoint = config.get('api_base') + if self.debug_enabled is not None: config['debug_enabled'] = self.debug_enabled + if excluded_models := config.pop('excluded_models', None): + self.client_model_exclusions[name] = { + 'models': excluded_models, 'is_azure': is_azure, + } + else: + self.client_model_exclusions[name] = { + 'models': None, 'is_azure': is_azure, + } + + if (self.settings.proxy.enabled and not proxy_disabled) and config.get('api_base'): + # Initialize a non-proxy version of the client + config['api_base'] = source_endpoint + non_proxy_name = f'{name}_noproxy' + self.client_model_exclusions[non_proxy_name] = self.client_model_exclusions[name].copy() + self.no_proxy_client_names.append(non_proxy_name) + self.init_api_client(non_proxy_name, is_azure = is_azure, set_as_default = False, **config) + config['headers'] = self.settings.proxy.create_proxy_headers( + name = name, + config = config, + ) + config['api_base'] = self.settings.proxy.endpoint + c = self.init_api_client(name, is_azure = is_azure, set_as_default = is_default, **config) + logger.info(f'Registered: 
`|g|{c.name}|e|` @ `{source_endpoint or c.base_url}` (Azure: {c.is_azure})', colored = True) + + + def select_client_names( + self, + client_name: Optional[str] = None, + azure_required: Optional[bool] = None, + openai_required: Optional[bool] = None, + model: Optional[str] = None, + noproxy_required: Optional[bool] = None, + excluded_clients: Optional[List[str]] = None, + ) -> Optional[List[str]]: + """ + Select Client based on the client name, azure_required, and model + """ + if client_name is not None: return client_name + oai_name = 'openai_noproxy' if noproxy_required else 'openai' + if openai_required: return [oai_name] + if model is not None: + available_clients = [] + for name, values in self.client_model_exclusions.items(): + if (noproxy_required and 'noproxy' not in name) or (not noproxy_required and 'noproxy' in name): continue + if excluded_clients and name in excluded_clients: continue + + # Prioritize Azure Clients + if ( + azure_required and not values['is_azure'] + ) or ( + not azure_required and not values['is_azure'] + ): + continue + if not values['models'] or model not in values['models']: + available_clients.append(name) + # return name + + if not available_clients: available_clients.append(oai_name) + return available_clients + if azure_required: + return [ + k for k, v in self.client_model_exclusions.items() if v['is_azure'] and \ + ('noproxy' in k if noproxy_required else 'noproxy' not in k) + ] + # return [k for k, v in self.client_model_exclusions.items() if v['is_azure']] + return None + + def get_client( + self, + client_name: Optional[str] = None, + azure_required: Optional[bool] = None, + openai_required: Optional[bool] = None, + model: Optional[str] = None, + noproxy_required: Optional[bool] = None, + excluded_clients: Optional[List[str]] = None, + **kwargs, + ) -> "OpenAIClient": + """ + Gets the OpenAI client + + Args: + client_name (str, optional): The name of the client to use. If not provided, it will be selected based on the other parameters. + azure_required (bool, optional): Whether the client must be an Azure client. + openai_required (bool, optional): Whether the client must be an OpenAI client. + model (str, optional): The model to use. If provided, the client will be selected based on the model. + noproxy_required (bool, optional): Whether the client must be a non-proxy client. + excluded_clients (List[str], optional): A list of client names to exclude from selection. 
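+ Usage (illustrative sketch):
+ ```python
+ >>> client = OpenAIManager.get_client(model = 'gpt-3.5-turbo', azure_required = True)
+ >>> result = client.chat.create(messages = [{'content': 'say this is a test'}])
+ ```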
+ """ + self._ensure_api() + client_names = self.select_client_names( + client_name = client_name, + azure_required = azure_required, + openai_required = openai_required, + model = model, + noproxy_required = noproxy_required, + excluded_clients = excluded_clients + ) + client = self.get_api_client_from_list( + client_names = client_names, + azure_required = azure_required, + **kwargs + ) + if self.debug_enabled: + logger.info(f'Available Clients: {client_names} - Selected: {client.name}') + if not client_name and self.auto_loadbalance_clients: + self.apis.increase_rotate_index() + return client + + + def get_chat_client( + self, + client_name: Optional[str] = None, + azure_required: Optional[bool] = None, + openai_required: Optional[bool] = None, + model: Optional[str] = None, + noproxy_required: Optional[bool] = None, + excluded_clients: Optional[List[str]] = None, + **kwargs, + ) -> ChatRoute: + """ + Gets the chat client + """ + return self.get_client(client_name = client_name, azure_required = azure_required, openai_required = openai_required, model = model, noproxy_required = noproxy_required, excluded_clients = excluded_clients, **kwargs).chat + + def get_completion_client( + self, + client_name: Optional[str] = None, + azure_required: Optional[bool] = None, + openai_required: Optional[bool] = None, + model: Optional[str] = None, + noproxy_required: Optional[bool] = None, + excluded_clients: Optional[List[str]] = None, + **kwargs + ) -> CompletionRoute: + """ + Gets the chat client + """ + return self.get_client(client_name = client_name, azure_required = azure_required, openai_required = openai_required, model = model, noproxy_required = noproxy_required, excluded_clients = excluded_clients, **kwargs).completions + + def get_embedding_client( + self, + client_name: Optional[str] = None, + azure_required: Optional[bool] = None, + openai_required: Optional[bool] = None, + model: Optional[str] = None, + noproxy_required: Optional[bool] = None, + excluded_clients: Optional[List[str]] = None, + **kwargs + ) -> EmbeddingRoute: + """ + Gets the chat client + """ + return self.get_client(client_name = client_name, azure_required = azure_required, openai_required = openai_required, model = model, noproxy_required = noproxy_required, excluded_clients = excluded_clients, **kwargs).embeddings + + + """ + V2 Utilities + """ + def truncate_to_max_length( + self, + text: str, + model: str, + max_length: Optional[int] = None, + buffer_length: Optional[int] = None, + ) -> str: + """ + Truncates the text to the max length + """ + if max_length is None: + model_ctx = ModelContextHandler.get(model) + max_length = model_ctx.context_length + if buffer_length is not None: max_length -= buffer_length + + encoder = ModelContextHandler.get_tokenizer(model) + tokens = encoder.encode(text) + if len(tokens) > max_length: + tokens = tokens[-max_length:] + decoded = encoder.decode(tokens) + text = text[-len(decoded):] + return text + + def truncate_batch_to_max_length( + self, + texts: List[str], + model: str, + max_length: Optional[int] = None, + buffer_length: Optional[int] = None, + ) -> List[str]: + """ + Truncates the text to the max length + """ + if max_length is None: + model_ctx = ModelContextHandler.get(model) + max_length = model_ctx.context_length + if buffer_length is not None: max_length -= buffer_length + encoder = ModelContextHandler.get_tokenizer(model) + truncated_texts = [] + for text in texts: + tokens = encoder.encode(text) + if len(tokens) > max_length: + tokens = tokens[-max_length:] + 
decoded = encoder.decode(tokens) + text = text[-len(decoded):] + truncated_texts.append(text) + return truncated_texts + + async def atruncate_to_max_length( + self, + text: str, + model: str, + max_length: Optional[int] = None, + buffer_length: Optional[int] = None, + ) -> str: + """ + Truncates the text to the max length + """ + return await self.pooler.arun( + self.truncate_to_max_length, + text = text, + model = model, + max_length = max_length, + buffer_length = buffer_length, + ) + + async def atruncate_batch_to_max_length( + self, + texts: List[str], + model: str, + max_length: Optional[int] = None, + buffer_length: Optional[int] = None, + ) -> List[str]: + """ + Truncates the text to the max length + """ + return await self.pooler.arun( + self.truncate_batch_to_max_length, + texts = texts, + model = model, + max_length = max_length, + buffer_length = buffer_length, + ) + + + + """ + Context Managers + """ + + async def async_close(self): + """ + Closes the OpenAI API Client. + """ + for client in self._clients.values(): + await client.async_close() + + + def close(self): + """ + Closes the OpenAI API Client. + """ + for client in self._clients.values(): + client.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.async_close() + + def __getitem__(self, key: Union[int, str]) -> 'OpenAIClient': + """ + Returns the OpenAI API Client. + """ + if self.auto_loadbalance_clients: + return self.apis[key] + if isinstance(key, int): + key = self.client_names[key] + return self._clients[key] + + @property + def client_names(self) -> List[str]: + """ + Returns the list of client names. + """ + return list(self._clients.keys()) + + + """ + Auto Rotating Functions + """ + + def chat_create( + self, + input_object: Optional[ChatObject] = None, + parse_stream: Optional[bool] = True, + auto_retry: Optional[bool] = False, + auto_retry_limit: Optional[int] = None, + verbose: Optional[bool] = False, + **kwargs + ) -> ChatResponse: + """ + Creates a chat response for the provided prompt and parameters + + Usage: + + ```python + >>> result = OpenAI.chat_create( + >>> messages = [{'content': 'say this is a test'}], + >>> max_tokens = 4, + >>> stream = True + >>> ) + ``` + + **Parameters:** + + :model (required): ID of the model to use. You can use the List models API + to see all of your available models, or see our Model overview for descriptions of them. + Default: `gpt-3.5-turbo` + + :messages: The messages to generate chat completions for, in the chat format. + + :max_tokens (optional): The maximum number of tokens to generate in the completion. + The token count of your prompt plus `max_tokens` cannot exceed the model's context length. + Most models have a context length of 2048 tokens (except for the newest models, which + support 4096 / 8182 / 32,768). If max_tokens is not provided, the model will use the maximum number of tokens + Default: None + + :temperature (optional): What sampling temperature to use. Higher values means + the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) + for ones with a well-defined answer. We generally recommend altering this or `top_p` but not both. 
+ Default: `1.0` + + :top_p (optional): An alternative to sampling with `temperature`, called nucleus + sampling, where the model considers the results of the tokens with `top_p` probability mass. + So `0.1` means only the tokens comprising the top 10% probability mass are considered. + We generally recommend altering this or `temperature` but not both + Default: `1.0` + + :n (optional): How many completions to generate for each prompt. + Note: Because this parameter generates many completions, it can quickly + consume your token quota. Use carefully and ensure that you have reasonable + settings for `max_tokens` and stop. + Default: `1` + + :stream (optional): CURRENTLY NOT SUPPORTED + Whether to stream back partial progress. + If set, tokens will be sent as data-only server-sent events as they become + available, with the stream terminated by a `data: [DONE]` message. This is + handled automatically by the Client and enables faster response processing. + Default: `False` + + :logprobs (optional): Include the log probabilities on the `logprobs` + most likely tokens, as well the chosen tokens. For example, if `logprobs` is 5, + the API will return a list of the 5 most likely tokens. The API will always + return the logprob of the sampled token, so there may be up to `logprobs+1` + elements in the response. The maximum value for `logprobs` is 5. + Default: `None` + + :stop (optional): Up to 4 sequences where the API will stop generating + further tokens. The returned text will not contain the stop sequence. + Default: `None` + + :presence_penalty (optional): Number between `-2.0` and `2.0`. Positive values + penalize new tokens based on whether they appear in the text so far, increasing the + model's likelihood to talk about new topics + Default: `0.0` + + :frequency_penalty (optional): Number between `-2.0` and `2.0`. Positive + values penalize new tokens based on their existing frequency in the text so + far, decreasing the model's likelihood to repeat the same line verbatim. + Default: `0.0` + + :logit_bias (optional): Modify the likelihood of specified tokens appearing in the completion. + Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated + bias value from -100 to 100. You can use this tokenizer tool (which works for both GPT-2 and GPT-3) to + convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior + to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase + likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the + relevant token. + As an example, you can pass `{"50256": -100}` to prevent the `<|endoftext|>` token from being generated. + Default: `None` + + :user (optional): A unique identifier representing your end-user, which can help OpenAI to + monitor and detect abuse. + Default: `None` + + :functions (optional): A list of dictionaries representing the functions to call + + :function_call (optional): The name of the function to call. Default: `auto` if functions are provided + + :auto_retry (optional): Whether to automatically retry the request if it fails due to a rate limit error. + + :auto_retry_limit (optional): The maximum number of times to retry the request if it fails due to a rate limit error. 
+ + Returns: `ChatResponse` + """ + + try: + return self.api.chat.create(input_object = input_object, parse_stream = parse_stream, auto_retry = auto_retry, auto_retry_limit = auto_retry_limit, **kwargs) + + except Exception as e: + if not self.auto_loadbalance_clients: raise e + self.rotate_client(verbose=verbose) + return self.chat_create(input_object = input_object, parse_stream = parse_stream, auto_retry = auto_retry, auto_retry_limit = auto_retry_limit, **kwargs) + + + async def async_chat_create( + self, + input_object: Optional[ChatObject] = None, + parse_stream: Optional[bool] = True, + auto_retry: Optional[bool] = False, + auto_retry_limit: Optional[int] = None, + verbose: Optional[bool] = False, + **kwargs + ) -> ChatResponse: + """ + Creates a chat response for the provided prompt and parameters + + Usage: + + ```python + >>> result = await OpenAI.async_chat_create( + >>> messages = [{'content': 'say this is a test'}], + >>> max_tokens = 4, + >>> stream = True + >>> ) + ``` + + **Parameters:** + + :model (required): ID of the model to use. You can use the List models API + to see all of your available models, or see our Model overview for descriptions of them. + Default: `gpt-3.5-turbo` + + :messages: The messages to generate chat completions for, in the chat format. + + :max_tokens (optional): The maximum number of tokens to generate in the completion. + The token count of your prompt plus `max_tokens` cannot exceed the model's context length. + Most models have a context length of 2048 tokens (except for the newest models, which + support 4096). + Default: `16` + + :temperature (optional): What sampling temperature to use. Higher values means + the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) + for ones with a well-defined answer. We generally recommend altering this or `top_p` but not both. + Default: `1.0` + + :top_p (optional): An alternative to sampling with `temperature`, called nucleus + sampling, where the model considers the results of the tokens with `top_p` probability mass. + So `0.1` means only the tokens comprising the top 10% probability mass are considered. + We generally recommend altering this or `temperature` but not both + Default: `1.0` + + :n (optional): How many completions to generate for each prompt. + Note: Because this parameter generates many completions, it can quickly + consume your token quota. Use carefully and ensure that you have reasonable + settings for `max_tokens` and stop. + Default: `1` + + :stream (optional): CURRENTLY NOT SUPPORTED + Whether to stream back partial progress. + If set, tokens will be sent as data-only server-sent events as they become + available, with the stream terminated by a `data: [DONE]` message. This is + handled automatically by the Client and enables faster response processing. + Default: `False` + + :logprobs (optional): Include the log probabilities on the `logprobs` + most likely tokens, as well the chosen tokens. For example, if `logprobs` is 5, + the API will return a list of the 5 most likely tokens. The API will always + return the logprob of the sampled token, so there may be up to `logprobs+1` + elements in the response. The maximum value for `logprobs` is 5. + Default: `None` + + :stop (optional): Up to 4 sequences where the API will stop generating + further tokens. The returned text will not contain the stop sequence. + Default: `None` + + :presence_penalty (optional): Number between `-2.0` and `2.0`. 
Positive values + penalize new tokens based on whether they appear in the text so far, increasing the + model's likelihood to talk about new topics + Default: `0.0` + + :frequency_penalty (optional): Number between `-2.0` and `2.0`. Positive + values penalize new tokens based on their existing frequency in the text so + far, decreasing the model's likelihood to repeat the same line verbatim. + Default: `0.0` + + :logit_bias (optional): Modify the likelihood of specified tokens appearing in the completion. + Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated + bias value from -100 to 100. You can use this tokenizer tool (which works for both GPT-2 and GPT-3) to + convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior + to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase + likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the + relevant token. + As an example, you can pass `{"50256": -100}` to prevent the `<|endoftext|>` token from being generated. + Default: `None` + + :user (optional): A unique identifier representing your end-user, which can help OpenAI to + monitor and detect abuse. + + :functions (optional): A list of dictionaries representing the functions to call + + :function_call (optional): The name of the function to call. Default: `auto` if functions are provided + + :auto_retry (optional): Whether to automatically retry the request if it fails due to a rate limit error. + + :auto_retry_limit (optional): The maximum number of times to retry the request if it fails due to a rate limit error. + + Default: `None` + + Returns: `ChatResponse` + """ + + try: + return await self.api.chat.async_create(input_object = input_object, parse_stream = parse_stream, auto_retry = auto_retry, auto_retry_limit = auto_retry_limit, **kwargs) + + except Exception as e: + if not self.auto_loadbalance_clients: raise e + self.rotate_client(verbose=verbose) + return await self.async_chat_create(input_object = input_object, parse_stream = parse_stream, auto_retry = auto_retry, auto_retry_limit = auto_retry_limit, **kwargs) + + achat_create = async_chat_create + + def create_embeddings( + self, + inputs: Union[str, List[str]], + model: Optional[str] = None, + auto_retry: Optional[bool] = True, + strip_newlines: Optional[bool] = False, + **kwargs, + ) -> List[List[float]]: + """ + Creates the embeddings + + Args: + inputs (Union[str, List[str]]): The input text or list of input texts. + model (str, optional): The model to use. Defaults to None. + auto_retry (bool, optional): Whether to automatically retry the request. Defaults to True. + strip_newlines (bool, optional): Whether to strip newlines from the input. Defaults to False. 
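+ Note: inputs are first truncated to the model's context length. For Azure
+ clients, requests are sent in batches of 5 inputs per call, re-selecting an
+ Azure client between batches to spread load.
+ 
+ Returns:
+ List[List[float]]: One embedding vector per input.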
+ """ + from lazyops.utils.helpers import split_into_batches + model = model or 'text-embedding-ada-002' + inputs = [inputs] if isinstance(inputs, str) else inputs + inputs = self.truncate_batch_to_max_length( + inputs, + model = model, + **kwargs + ) + if strip_newlines: inputs = [i.replace('\n', ' ') for i in inputs] + client = self.get_client(model = model, **kwargs) + if not client.is_azure: + response = client.embeddings.create(input = inputs, auto_retry = auto_retry, **kwargs) + return response.embeddings + + embeddings = [] + # We need to split into batches of 5 for Azure + # Azure has a limit of 5 inputs per request + batches = split_into_batches(inputs, 5) + for batch in batches: + response = client.embeddings.create(input = batch, auto_retry = auto_retry, **kwargs) + embeddings.extend(response.embeddings) + # Shuffle the clients to load balance + client = self.get_client(model = model, azure_required = True, **kwargs) + return embeddings + + + async def async_create_embeddings( + self, + inputs: Union[str, List[str]], + model: Optional[str] = None, + auto_retry: Optional[bool] = True, + strip_newlines: Optional[bool] = False, + **kwargs, + ) -> List[List[float]]: + """ + Creates the embeddings + + Args: + inputs (Union[str, List[str]]): The input text or list of input texts. + model (str, optional): The model to use. Defaults to None. + auto_retry (bool, optional): Whether to automatically retry the request. Defaults to True. + strip_newlines (bool, optional): Whether to strip newlines from the input. Defaults to False. + """ + from lazyops.utils.helpers import split_into_batches + model = model or 'text-embedding-ada-002' + inputs = [inputs] if isinstance(inputs, str) else inputs + inputs = await self.atruncate_batch_to_max_length( + inputs, + model = model, + **kwargs + ) + if strip_newlines: inputs = [i.replace('\n', ' ') for i in inputs] + client = self.get_client(model = model, **kwargs) + if not client.is_azure: + response = await client.embeddings.async_create(input = inputs, auto_retry = auto_retry, **kwargs) + return response.embeddings + + embeddings = [] + # We need to split into batches of 5 for Azure + # Azure has a limit of 5 inputs per request + batches = split_into_batches(inputs, 5) + for batch in batches: + response = await client.embeddings.async_create(input = batch, auto_retry = auto_retry, **kwargs) + embeddings.extend(response.embeddings) + # Shuffle the clients to load balance + client = self.get_client(model = model, azure_required = True, **kwargs) + return embeddings + + acreate_embeddings = async_create_embeddings + + + @overload + def configure( + self, + api_key: Optional[str] = None, + url: Optional[str] = None, + scheme: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + + api_base: Optional[str] = None, + api_path: Optional[str] = None, + api_type: Optional[ApiType] = None, + api_version: Optional[str] = None, + api_key_path: Optional[pathlib.Path] = None, + + organization: Optional[str] = None, + proxies: Optional[Union[str, Dict]] = None, + timeout: Optional[int] = None, + max_retries: Optional[int] = None, + app_info: Optional[Dict[str, str]] = None, + debug_enabled: Optional[bool] = None, + ignore_errors: Optional[bool] = None, + disable_retries: Optional[bool] = None, + max_connections: Optional[int] = None, + max_keepalive_connections: Optional[int] = None, + keepalive_expiry: Optional[int] = None, + custom_headers: Optional[Dict[str, str]] = None, + + on_error: Optional[Callable] = None, + reset: 
Optional[bool] = None, + prioritize: Optional[str] = None, + # enable_rotating_clients: Optional[bool] = None, + azure_model_mapping: Optional[Dict[str, str]] = None, + + auto_healthcheck: Optional[bool] = None, + auto_loadbalance_clients: Optional[bool] = None, + proxy_config: Optional[Union[Dict[str, Any], pathlib.Path]] = None, + client_configurations: Optional[Union[Dict[str, Dict[str, Any]], pathlib.Path]] = None, + **kwargs + ): + """ + Configure the global OpenAI client. + + :param url: The OpenAI API URL | Env: [`OPENAI_API_URL`] + :param scheme: The OpenAI API Scheme | Env: [`OPENAI_API_SCHEME`] + :param host: The OpenAI API Host | Env: [`OPENAI_API_HOST`] + :param port: The OpenAI API Port | Env: [`OPENAI_API_PORT`] + :param api_base: The OpenAI API Base | Env: [`OPENAI_API_BASE`] + :param api_key: The OpenAI API Key | Env: [`OPENAI_API_KEY`] + :param api_path: The OpenAI API Path | Env: [`OPENAI_API_PATH`] + :param api_type: The OpenAI API Type | Env: [`OPENAI_API_TYPE`] + :param api_version: The OpenAI API Version | Env: [`OPENAI_API_VERSION`] + :param api_key_path: The API Key Path | Env: [`OPENAI_API_KEY_PATH`] + :param organization: Organization | Env: [`OPENAI_ORGANIZATION`] + :param proxies: The OpenAI Proxies | Env: [`OPENAI_PROXIES`] + :param timeout: Timeout in Seconds | Env: [`OPENAI_TIMEOUT`] + :param max_retries: The OpenAI Max Retries | Env: [`OPENAI_MAX_RETRIES`] + :param ignore_errors: Ignore Errors | Env: [`OPENAI_IGNORE_ERRORS`] + :param disable_retries: Disable Retries | Env: [`OPENAI_DISABLE_RETRIES`] + :param max_connections: Max Connections | Env: [`OPENAI_MAX_CONNECTIONS`] + :param max_keepalive_connections: Max Keepalive Connections | Env: [`OPENAI_MAX_KEEPALIVE_CONNECTIONS`] + :param keepalive_expiry: Keepalive Expiry | Env: [`OPENAI_KEEPALIVE_EXPIRY`] + :param custom_headers: Custom Headers | Env: [`OPENAI_CUSTOM_HEADERS`] + + :param on_error: On Error Callback + :param kwargs: Additional Keyword Arguments + """ + ... 
+ + @overload + def configure( + azure_api_key: Optional[str] = None, + azure_url: Optional[str] = None, + azure_scheme: Optional[str] = None, + azure_host: Optional[str] = None, + azure_port: Optional[int] = None, + + azure_api_base: Optional[str] = None, + azure_api_path: Optional[str] = None, + azure_api_type: Optional[ApiType] = None, + azure_api_version: Optional[str] = None, + azure_api_key_path: Optional[pathlib.Path] = None, + + azure_organization: Optional[str] = None, + azure_proxies: Optional[Union[str, Dict]] = None, + azure_timeout: Optional[int] = None, + azure_max_retries: Optional[int] = None, + azure_app_info: Optional[Dict[str, str]] = None, + azure_debug_enabled: Optional[bool] = None, + azure_ignore_errors: Optional[bool] = None, + azure_max_connections: Optional[int] = None, + azure_max_keepalive_connections: Optional[int] = None, + azure_keepalive_expiry: Optional[int] = None, + azure_custom_headers: Optional[Dict[str, str]] = None, + + on_error: Optional[Callable] = None, + reset: Optional[bool] = None, + prioritize: Optional[str] = None, + # enable_rotating_clients: Optional[bool] = None, + azure_model_mapping: Optional[Dict[str, str]] = None, + debug_enabled: Optional[bool] = None, + + auto_healthcheck: Optional[bool] = None, + auto_loadbalance_clients: Optional[bool] = None, + proxy_config: Optional[Union[Dict[str, Any], pathlib.Path]] = None, + client_configurations: Optional[Union[Dict[str, Dict[str, Any]], pathlib.Path]] = None, + **kwargs + ): + """ + Configure the global OpenAI client for Azure + + :param azure_url: The OpenAI API URL | Env: [`AZURE_OPENAI_API_URL`] + :param azure_scheme: The OpenAI API Scheme | Env: [`AZURE_OPENAI_API_SCHEME`] + :param azure_host: The OpenAI API Host | Env: [`AZURE_OPENAI_API_HOST`] + :param azure_port: The OpenAI API Port | Env: [`AZURE_OPENAI_API_PORT`] + :param azure_api_key: The OpenAI API Key | Env: [`AZURE_OPENAI_API_KEY`] + :param azure_api_base: The OpenAI API Base | Env: [`AZURE_OPENAI_API_BASE`] + :param azure_api_path: The OpenAI API Path | Env: [`AZURE_OPENAI_API_PATH`] + :param azure_api_type: The OpenAI API Type | Env: [`AZURE_OPENAI_API_TYPE`] + :param azure_api_version: The OpenAI API Version | Env: [`AZURE_OPENAI_API_VERSION`] + :param azure_api_key_path: The API Key Path | Env: [`AZURE_OPENAI_API_KEY_PATH`] + :param azure_organization: Organization | Env: [`AZURE_OPENAI_ORGANIZATION`] + :param azure_proxies: The OpenAI Proxies | Env: [`AZURE_OPENAI_PROXIES`] + :param azure_timeout: Timeout in Seconds | Env: [`AZURE_OPENAI_TIMEOUT`] + :param azure_max_retries: The OpenAI Max Retries | Env: [`AZURE_OPENAI_MAX_RETRIES`] + :param kwargs: Additional Keyword Arguments + """ + ... + + + def configure( + self, + on_error: Optional[Callable] = None, + prioritize: Optional[str] = None, + # enable_rotating_clients: Optional[bool] = None, + azure_model_mapping: Optional[Dict[str, str]] = None, + debug_enabled: Optional[bool] = None, + + auto_healthcheck: Optional[bool] = None, + auto_loadbalance_clients: Optional[bool] = None, + **kwargs + ): + """ + Configure the global OpenAI client. 
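+ Usage (illustrative sketch):
+ ```python
+ >>> from async_openai import OpenAIManager
+ >>> OpenAIManager.configure(api_key = 'sk-...', debug_enabled = True, auto_loadbalance_clients = True)
+ ```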
+ """ + if on_error is not None: self.on_error = on_error + if prioritize is not None: self.prioritize = prioritize + if debug_enabled is not None: self.settings.debug_enabled = debug_enabled + # if enable_rotating_clients is not None: self.enable_rotating_clients = enable_rotating_clients + if auto_loadbalance_clients is not None: self.auto_loadbalance_clients = auto_loadbalance_clients + if auto_healthcheck is not None: self.auto_healthcheck = auto_healthcheck + if azure_model_mapping is not None: + self.azure_model_mapping = azure_model_mapping + for key, val in azure_model_mapping.items(): + ModelContextHandler.add_model(key, val) + self.settings.configure(auto_loadbalance_clients = auto_loadbalance_clients, auto_healthcheck = auto_healthcheck, **kwargs) diff --git a/async_openai/routes.py b/async_openai/routes.py index 6fccabd..43e00ca 100644 --- a/async_openai/routes.py +++ b/async_openai/routes.py @@ -1,6 +1,6 @@ import aiohttpx -from typing import Optional, Dict, Callable, TYPE_CHECKING +from typing import Optional, Dict, Callable, List, TYPE_CHECKING from async_openai.schemas import * from async_openai.utils.config import get_settings, OpenAISettings, AzureOpenAISettings from async_openai.utils.logs import logger @@ -46,6 +46,7 @@ def __init__( max_retries: Optional[int] = None, settings: Optional[OpenAISettings] = None, is_azure: Optional[bool] = None, + client_callbacks: Optional[List[Callable]] = None, **kwargs ): @@ -64,6 +65,8 @@ def __init__( self.is_azure = is_azure if is_azure is not None else \ isinstance(self.settings, AzureOpenAISettings) self.kwargs = kwargs or {} + if client_callbacks: + self.kwargs['client_callbacks'] = client_callbacks self.init_routes() diff --git a/async_openai/types/context.py b/async_openai/types/context.py index 35e2ab4..9af00e3 100644 --- a/async_openai/types/context.py +++ b/async_openai/types/context.py @@ -101,10 +101,25 @@ def model_aliases(cls) -> Dict[str, str]: cls._model_aliases = {alias: model for model, item in cls.models.items() for alias in item.aliases or []} return cls._model_aliases + def resolve_model_name(cls, model_name: str) -> str: + """ + Resolves the Model Name from the model aliases + """ + # Try to remove the version number + key = model_name.rsplit('-', 1)[0].strip() + if key in cls.model_aliases: + cls.model_aliases[model_name] = cls.model_aliases[key] + if key in cls.models: + cls.model_aliases[model_name] = key + return key + raise KeyError(f"Model {model_name} not found") + def __getitem__(cls, key: str) -> ModelCostItem: """ Gets a model by name """ + if key not in cls.model_aliases and key not in cls.models: + return cls.resolve_model_name(key) if key in cls.model_aliases: key = cls.model_aliases[key] return cls.models[key] diff --git a/async_openai/types/errors.py b/async_openai/types/errors.py index 49f8b6a..36256ed 100644 --- a/async_openai/types/errors.py +++ b/async_openai/types/errors.py @@ -139,6 +139,28 @@ def __repr__(self): +class MaxRetriesExhausted(Exception): + """ + Max Retries Exhausted + """ + + def __init__(self, name: str, func_name: str, model: str, attempts: int, max_attempts: int): + self.name = name + self.func_name = func_name + self.model = model + self.attempts = attempts + self.max_attempts = max_attempts + + def __str__(self): + return f"[{self.name} - {self.model}] All retries exhausted for {self.func_name}. ({self.attempts}/{self.max_attempts})" + + def __repr__(self): + """ + Returns the string representation of the error. 
+ """ + return f"[{self.name} - {self.model}] (func_name={self.func_name}, attempts={self.attempts}, max_attempts={self.max_attempts})" + + class APIError(OpenAIError): pass @@ -221,7 +243,7 @@ def fatal_exception(exc) -> bool: # with 400, 404, 415 status codes (invalid request), # 400 can include invalid parameters, such as invalid `max_tokens` # don't retry on other client errors - if isinstance(exc, (InvalidMaxTokens, InvalidRequestError)): + if isinstance(exc, (InvalidMaxTokens, InvalidRequestError, MaxRetriesExhausted)): return True return (400 <= exc.status < 500) and exc.status not in [429, 400, 404, 415, 524] # [429, 400, 404, 415] diff --git a/async_openai/types/functions.py b/async_openai/types/functions.py new file mode 100644 index 0000000..e5930aa --- /dev/null +++ b/async_openai/types/functions.py @@ -0,0 +1,945 @@ +from __future__ import annotations + +""" +OpenAI Functions Base Class +""" + +import jinja2 +from abc import ABC +from pydantic import PrivateAttr, BaseModel +# from lazyops.types import BaseModel +from lazyops.utils.times import Timer +from lazyops.libs.proxyobj import ProxyObject +from async_openai.utils.fixjson import resolve_json +from . import errors + +from typing import Optional, Any, Dict, List, Union, Type, Tuple, Awaitable, TypeVar, TYPE_CHECKING + +if TYPE_CHECKING: + from async_openai import ChatResponse, ChatRoute + from async_openai.manager import OpenAIManager as OpenAISessionManager + from lazyops.utils.logs import Logger + from lazyops.libs.persistence import PersistentDict + + +FT = TypeVar('FT', bound = BaseModel) +SchemaT = TypeVar('SchemaT', bound = BaseModel) + + +class BaseFunctionModel(BaseModel): + _name: Optional[str] = PrivateAttr(None) + + def update( + self, + values: 'BaseFunctionModel', + ): + """ + Updates the values + """ + pass + + def _setup_item( + self, + item: 'SchemaT', + **kwargs + ) -> 'SchemaT': + """ + Updates the Reference Item + """ + return item + + + def update_values( + self, + item: 'SchemaT', + **kwargs + ) -> 'SchemaT': + """ + Updates the Reference Item with the values + """ + return item + + + def update_data( + self, + item: 'SchemaT', + **kwargs + ) -> 'SchemaT': + """ + Updates the Reference Item with the values + """ + item = self._setup_item(item = item, **kwargs) + item = self.update_values(item = item, **kwargs) + return item + + + def is_valid(self) -> bool: + """ + Returns whether the function data is valid + """ + return True + +FunctionSchemaT = TypeVar('FunctionSchemaT', bound = BaseFunctionModel) +FunctionResultT = TypeVar('FunctionResultT', bound = BaseFunctionModel) + +class BaseFunction(ABC): + """ + Base Class for OpenAI Functions + """ + + name: Optional[str] = None + function_name: Optional[str] = None + description: Optional[str] = None + schema: Optional[Type[FunctionSchemaT]] = None + schemas: Optional[Dict[str, Dict[str, Union[str, Type[FunctionSchemaT]]]]] = None + + prompt_template: Optional[str] = None + system_template: Optional[str] = None + + default_model: Optional[str] = 'gpt-35-turbo' + default_larger_model: Optional[bool] = None + cachable: Optional[bool] = True + result_buffer: Optional[int] = 1000 + retry_limit: Optional[int] = 5 + max_attempts: Optional[int] = 2 + + default_model_local: Optional[str] = None + default_model_develop: Optional[str] = None + default_model_production: Optional[str] = None + + + def __init__( + self, + api: Optional['OpenAISessionManager'] = None, + debug_enabled: Optional[bool] = None, + **kwargs + ): + """ + This gets initialized from the 
Enrichment Handler + """ + from async_openai.manager import ModelContextHandler + from async_openai.utils.logs import logger, null_logger + self.ctx: Type['ModelContextHandler'] = ModelContextHandler + if api is None: + from async_openai.client import OpenAIManager + api = OpenAIManager + + self.api: 'OpenAISessionManager' = api + self.pool = self.api.pooler + self.kwargs = kwargs + self.logger = logger + self.null_logger = null_logger + self.settings = self.api.settings + if debug_enabled is not None: + self.debug_enabled = debug_enabled + else: + self.debug_enabled = self.settings.debug_enabled + self.build_funcs(**kwargs) + self.build_templates(**kwargs) + self.post_init(**kwargs) + + @property + def default_model_func(self) -> str: + """ + Returns the default model + """ + if self.settings.is_local_env: + return self.default_model_local or self.default_model + if self.settings.is_development_env: + return self.default_model_develop or self.default_model + return self.default_model_production or self.default_model + + @property + def autologger(self) -> 'Logger': + """ + Returns the logger + """ + return self.logger if \ + (self.debug_enabled or self.settings.is_development_env) else self.null_logger + + + @property + def has_diff_model_than_default(self) -> bool: + """ + Returns True if the default model is different than the default model + """ + return self.default_model_func != self.default_model + + + def build_templates(self, **kwargs): + """ + Construct the templates + """ + self.template = self.create_template(self.prompt_template) + # Only create the system template if it's a jinja template + if self.system_template and '{%' in self.system_template: + self.system_template = self.create_template(self.system_template) + + def build_funcs(self, **kwargs): + """ + Builds the functions + """ + # Handles multi functions + if self.schemas: + self.functions = [] + self.functions.extend( + { + "name": name, + "description": data.get('description', self.description), + "parameters": data.get('schema', self.schema), + } + for name, data in self.schemas.items() + ) + else: + self.functions = [ + { + "name": self.function_name or self.name, + "description": self.description, + "parameters": self.schema, + } + ] + + + def post_init(self, **kwargs): + """ + Post Init Hook + """ + pass + + def pre_call_hook(self, *args, **kwargs): + """ + Pre Call Hook + """ + pass + + async def apre_call_hook(self, *args, **kwargs): + """ + Pre Call Hook + """ + self.pre_call_hook(*args, **kwargs) + + def pre_validate(self, *args, **kwargs) -> bool: + """ + Validate the input before running + """ + return True + + async def apre_validate(self, *args, **kwargs) -> bool: + """ + Validate the input before running + """ + return self.pre_validate(*args, **kwargs) + + + def pre_validate_model(self, prompt: str, model: str, *args, **kwargs) -> str: + """ + Validates the model before running + """ + return model + + async def apre_validate_model(self, prompt: str, model: str, *args, **kwargs) -> str: + """ + Validates the model before running + """ + return self.pre_validate_model(prompt = prompt, model = model, *args, **kwargs) + + + + def call( + self, + *args, + model: Optional[str] = None, + **kwargs + ) -> Optional[FunctionSchemaT]: + """ + Call the function + """ + if not self.pre_validate(*args, **kwargs): + return None + self.pre_call_hook(*args, **kwargs) + return self.run_function(*args, model = model, **kwargs) + + async def acall( + self, + *args, + model: Optional[str] = None, + **kwargs + ) -> 
Optional[FunctionSchemaT]: + """ + Call the function + """ + if not await self.apre_validate(*args, **kwargs): + return None + await self.apre_call_hook(*args, **kwargs) + return await self.arun_function(*args, model = model, **kwargs) + + def __call__( + self, + *args, + model: Optional[str] = None, + is_async: Optional[bool] = True, + **kwargs + ) -> Optional[FunctionSchemaT]: + """ + Call the function + """ + if is_async: return self.acall(*args, model = model, **kwargs) + return self.call(*args, model = model, **kwargs) + + def get_chat_client(self, model: str, **kwargs) -> 'ChatRoute': + """ + Gets the chat client + """ + return self.api.get_chat_client(model = model, **kwargs) + + def get_completion_client(self, model: str, **kwargs) -> 'ChatRoute': + """ + Gets the chat client + """ + return self.api.get_chat_client(model = model, **kwargs) + + async def arun_chat_function( + self, + messages: List[Dict[str, Any]], + chat: Optional['ChatRoute'] = None, + cachable: Optional[bool] = None, + functions: Optional[List[Dict[str, Any]]] = None, + function_name: Optional[str] = None, + property_meta: Optional[Dict[str, Any]] = None, + model: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + excluded_clients: Optional[List[str]] = None, + **kwargs, + ) -> ChatResponse: # sourcery skip: low-code-quality + """ + Runs the chat function + """ + current_attempt = kwargs.pop('_current_attempt', 0) + last_chat_name = kwargs.pop('_last_chat_name', None) + if current_attempt and current_attempt > self.retry_limit: + raise errors.MaxRetriesExhausted( + name = last_chat_name, + func_name = function_name or self.name, + model = model, + attempts = current_attempt, + max_attempts = self.retry_limit, + ) + + disable_cache = not cachable if cachable is not None else not self.cachable + if not chat: + if last_chat_name: + if not excluded_clients: excluded_clients = [] + excluded_clients.append(last_chat_name) + chat = self.get_chat_client(model = model, excluded_clients = excluded_clients, **kwargs) + if not headers and 'noproxy' not in chat.name: + headers = { + 'Helicone-Cache-Enabled': 'false' if disable_cache else 'true', + 'Helicone-Property-FunctionName': function_name or self.name, + } + if property_meta: + property_meta = {f'Helicone-Property-{k}': str(v) for k, v in property_meta.items()} + headers.update(property_meta) + + elif headers and 'noproxy' in chat.name: + headers = None + functions = functions or self.functions + try: + if headers: chat.client.headers.update(headers) + return await chat.async_create( + model = model, + messages = messages, + functions = functions, + headers = headers, + auto_retry = True, + auto_retry_limit = 2, + function_call = {'name': function_name or self.name}, + header_cache_keys = ['Helicone-Cache-Enabled'], + **kwargs, + ) + except errors.InvalidRequestError as e: + self.logger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {model}] Invalid Request Error. 
|r|{e}|e|", colored=True) + raise e + except errors.MaxRetriesExceeded as e: + self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {model}] Retrying...", colored=True) + return await self.arun_chat_function( + messages = messages, + cachable = cachable, + functions = functions, + function_name = function_name, + property_meta = property_meta, + model = model, + headers = headers, + excluded_clients = excluded_clients, + _current_attempt = current_attempt + 1, + _last_chat_name = chat.name, + **kwargs, + ) + except Exception as e: + self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {model}] Unknown Error Trying to run chat function: |r|{e}|e|", colored=True) + return await self.arun_chat_function( + messages = messages, + cachable = cachable, + functions = functions, + function_name = function_name, + property_meta = property_meta, + model = model, + headers = headers, + excluded_clients = excluded_clients, + _current_attempt = current_attempt + 1, + _last_chat_name = chat.name, + **kwargs, + ) + + def run_chat_function( + self, + chat: 'ChatRoute', + messages: List[Dict[str, Any]], + cachable: Optional[bool] = None, + functions: Optional[List[Dict[str, Any]]] = None, + function_name: Optional[str] = None, + property_meta: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> ChatResponse: + """ + Runs the chat function + """ + disable_cache = not cachable if cachable is not None else not self.cachable + headers = None + if 'noproxy' not in chat.name: + headers = { + 'Helicone-Cache-Enabled': 'false' if disable_cache else 'true', + 'Helicone-Property-FunctionName': function_name or self.name, + } + if property_meta: + property_meta = {f'Helicone-Property-{k.strip()}': str(v).strip() for k, v in property_meta.items()} + headers.update(property_meta) + if headers: chat.client.headers.update(headers) + functions = functions or self.functions + return chat.create( + messages = messages, + functions = functions, + headers = headers, + auto_retry = True, + auto_retry_limit = self.retry_limit, + function_call = {'name': function_name or self.name}, + header_cache_keys=['Helicone-Cache-Enabled'], + **kwargs, + ) + + def parse_response( + self, + response: 'ChatResponse', + schema: Optional[Type[FunctionSchemaT]] = None, + include_name: Optional[bool] = True, + ) -> Optional[FunctionSchemaT]: + """ + Parses the response + """ + schema = schema or self.schema + try: + result = schema.model_validate(response.function_results[0].arguments, from_attributes = True) + if include_name: + result._name = self.name + return result + except Exception as e: + self.autologger.error(f"[{self.name} - {response.model} - {response.usage}] Failed to parse object: {e}\n{response.text}\n{response.function_results[0].arguments}") + try: + result = schema.model_validate(resolve_json(response.function_results[0].arguments), from_attributes = True) + if include_name: + result._name = self.name + return result + except Exception as e: + self.autologger.error(f"[{self.name} - {response.model} - {response.usage}] Failed to parse object after fixing") + return None + + + def is_valid_response(self, response: FT) -> bool: + """ + Returns True if the response is valid + """ + return True + + def apply_text_cleaning(self, text: str) -> str: + """ + Applies text cleaning + """ + from lazyops.utils.format_utils import clean_html, clean_text, cleanup_dots + if "..." 
in text: text = cleanup_dots(text) + return clean_html(clean_text(text)) + + @staticmethod + def create_template(template: str, enable_async: Optional[bool] = False, **kwargs) -> jinja2.Template: + """ + Creates the template + """ + return jinja2.Template(template, enable_async = enable_async, **kwargs) + + def truncate_documents( + self, + documents: Dict[str, str], + max_length: Optional[int] = None, + buffer_size: Optional[int] = None, + model: Optional[str] = None, + truncation_length: Optional[int] = None, + ) -> Dict[str, str]: + """ + Helper Function to truncate supporting docs + """ + current_length = 0 + if max_length is None: + model = model or self.default_model_func + max_length = self.ctx.get(model).context_length + if buffer_size is None: buffer_size = self.result_buffer + max_length -= buffer_size + + truncation_length = truncation_length or (max_length // len(documents)) + new_documents = {} + for file_name, file_text in documents.items(): + if not file_text: continue + file_text = self.apply_text_cleaning(file_text)[:truncation_length] + current_length += len(file_text) + new_documents[file_name] = file_text + if current_length > max_length: break + return new_documents + + """ + Function Handlers + """ + + def prepare_function_inputs( + self, + model: Optional[str] = None, + **kwargs + ) -> Tuple[List[Dict[str, Any]], str]: + """ + Prepare the Function Inputs for the function + """ + model = model or self.default_model_func + prompt = self.template.render(**kwargs) + prompt = self.api.truncate_to_max_length(prompt, model = model, buffer_length = self.result_buffer) + messages = [] + if self.system_template: + if isinstance(self.system_template, jinja2.Template): + system_template = self.system_template.render(**kwargs) + else: + system_template = self.system_template + messages.append({ + "role": "system", + "content": system_template, + }) + messages.append({ + "role": "user", + "content": prompt, + }) + return messages, model + + async def aprepare_function_inputs( + self, + model: Optional[str] = None, + **kwargs + ) -> Tuple[List[Dict[str, Any]], str]: + """ + Prepare the Function Inputs for the function + """ + model = model or self.default_model_func + prompt = self.template.render(**kwargs) + prompt = await self.api.atruncate_to_max_length(prompt, model = model, buffer_length = self.result_buffer) + messages = [] + if self.system_template: + if isinstance(self.system_template, jinja2.Template): + system_template = self.system_template.render(**kwargs) + else: + system_template = self.system_template + messages.append({ + "role": "system", + "content": system_template, + }) + messages.append({ + "role": "user", + "content": prompt, + }) + return messages, model + + def run_function( + self, + *args, + model: Optional[str] = None, + **kwargs + ) -> Optional[FunctionSchemaT]: + """ + Returns the Function Result + """ + messages, model = self.prepare_function_inputs(model = model, **kwargs) + return self.run_function_loop(messages = messages, model = model, **kwargs) + + async def arun_function( + self, + *args, + model: Optional[str] = None, + **kwargs + ) -> Optional[FunctionSchemaT]: + """ + Returns the Function Result + """ + messages, model = await self.aprepare_function_inputs(model = model, **kwargs) + return await self.arun_function_loop(messages = messages, model = model, **kwargs) + + """ + Handle a Loop + """ + + def run_function_loop( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + raise_errors: Optional[bool] = True, + **kwargs, 
+ ) -> Optional[FunctionSchemaT]: + """ + Runs the function loop + """ + chat = self.get_chat_client(model = model, **kwargs) + response = self.run_chat_function( + chat = chat, + messages = messages, + model = model, + **kwargs, + ) + + result = self.parse_response(response, include_name = True) + if result is not None: return result + + # Try Again + attempts = 1 + _ = kwargs.pop('cachable', None) + while attempts < self.max_attempts: + chat = self.get_chat_client(model = model, **kwargs) + response = self.run_chat_function( + chat = chat, + messages = messages, + model = model, + **kwargs, + ) + result = self.parse_response(response, include_name = True) + if result is not None: return result + attempts += 1 + self.autologger.error(f"Unable to parse the response for {self.name} after {self.max_attempts} attempts.") + if raise_errors: raise errors.MaxRetriesExhausted(name = self.name, func_name = self.name, model = model, attempts = attempts, max_attempts = self.max_attempts) + return None + + async def arun_function_loop( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + raise_errors: Optional[bool] = True, + **kwargs, + ) -> Optional[FunctionSchemaT]: + """ + Runs the function loop + """ + chat = self.get_chat_client(model = model, **kwargs) + response = await self.arun_chat_function( + chat = chat, + messages = messages, + model = model, + **kwargs, + ) + + result = self.parse_response(response, include_name = True) + if result is not None: return result + + # Try Again + attempts = 1 + _ = kwargs.pop('cachable', None) + while attempts < self.max_attempts: + chat = self.get_chat_client(model = model, **kwargs) + response = await self.arun_chat_function( + chat = chat, + messages = messages, + model = model, + **kwargs, + ) + result = self.parse_response(response, include_name = True) + if result is not None: return result + attempts += 1 + self.autologger.error(f"Unable to parse the response for {self.name} after {self.max_attempts} attempts.") + if raise_errors: raise errors.MaxRetriesExhausted(name = self.name, func_name = self.name, model = model, attempts = attempts, max_attempts = self.max_attempts) + return None + + + +FunctionT = TypeVar('FunctionT', bound = BaseFunction) + + + +class FunctionManager(ABC): + """ + The Functions Manager Class that handles registering and managing functions + + - Additionally supports caching through `kvdb` + """ + + name: Optional[str] = 'functions' + + def __init__( + self, + **kwargs, + ): + from async_openai.utils.config import settings + from async_openai.utils.logs import logger, null_logger + self.logger = logger + self.null_logger = null_logger + self.settings = settings + self.debug_enabled = self.settings.debug_enabled + self.cache_enabled = self.settings.function_cache_enabled + + self._api: Optional['OpenAISessionManager'] = None + self._cache: Optional['PersistentDict'] = None + self.functions: Dict[str, BaseFunction] = {} + self._kwargs = kwargs + try: + import xxhash + self._hash_func = xxhash.xxh64 + except ImportError: + from hashlib import md5 + self._hash_func = md5 + + try: + import cloudpickle + self._pickle = cloudpickle + except ImportError: + import pickle + self._pickle = pickle + + @property + def api(self) -> 'OpenAISessionManager': + """ + Returns the API + """ + if self._api is None: + from async_openai.client import OpenAIManager + self._api = OpenAIManager + return self._api + + @property + def autologger(self) -> 'Logger': + """ + Returns the logger + """ + return self.logger if \ + (self.debug_enabled or self.settings.is_development_env) else self.null_logger + + @property + def cache(self) -> 'PersistentDict': + 
""" + Gets the cache + """ + if self._cache is None: + serializer_kwargs = { + 'compression': self._kwargs.get('serialization_compression', None), + 'compression_level': self._kwargs.get('serialization_compression_level', None), + 'raise_errors': True, + } + kwargs = { + 'base_key': f'openai.functions.{self.api.settings.app_env.name}.{self.api.settings.proxy.proxy_app_name or "default"}', + 'expiration': self._kwargs.get('cache_expiration', 60 * 60 * 24 * 3), + 'serializer': self._kwargs.get('serialization', 'json'), + 'serializer_kwargs': serializer_kwargs, + } + try: + import kvdb + self._cache = kvdb.create_persistence(session_name = 'openai', **kwargs) + except ImportError: + from lazyops.libs.persistence import PersistentDict + self._cache = PersistentDict(**kwargs) + return self._cache + + def register_function( + self, + func: Union[BaseFunction, Type[BaseFunction], str], + name: Optional[str] = None, + overwrite: Optional[bool] = False, + raise_error: Optional[bool] = False, + **kwargs, + ): + """ + Registers the function + """ + if isinstance(func, str): + from lazyops.utils.lazy import lazy_import + func = lazy_import(func) + if isinstance(func, type): + func = func(**kwargs) + name = name or func.name + if not overwrite and name in self.functions: + if raise_error: raise ValueError(f"Function {name} already exists") + return + self.functions[name] = func + self.autologger.info(f"Registered Function: |g|{name}|e|", colored=True) + + def create_hash(self, **kwargs) -> str: + """ + Creates a hash + """ + return self._hash_func(self._pickle.dumps(kwargs)).hexdigest() + + async def acreate_hash(self, **kwargs) -> str: + """ + Creates a hash + """ + return await self.api.pooler.asyncish(self.create_hash, **kwargs) + + def get(self, name: Union[str, 'FunctionT']) -> Optional['FunctionT']: + """ + Gets the function + """ + return name if isinstance(name, BaseFunction) else self.functions.get(name) + + def execute( + self, + function: Union['FunctionT', str], + *args, + item_hashkey: Optional[str] = None, + cachable: Optional[bool] = True, + overrides: Optional[List[str]] = None, + **function_kwargs + ) -> Optional['FunctionSchemaT']: + """ + Runs the function + """ + overwrite = overrides and 'functions' in overrides + function = self.get(function) + if overwrite and self.check_value_present(overrides, f'{function.name}.cachable'): + cachable = False + + if item_hashkey is None: item_hashkey = self.create_hash(**function_kwargs) + key = f'{item_hashkey}.{function.name}' + if function.has_diff_model_than_default: + key += f'.{function.default_model_func}' + + t = Timer() + result = None + cache_hit = False + if self.cache_enabled and not overwrite: + result: 'FunctionResultT' = self.cache.fetch(key) + if result: + if isinstance(result, dict): result = function.schema.model_validate(result) + result._name = function.name + cache_hit = True + + if result is None: + result = function(*args, cachable = cachable, is_async = False, **function_kwargs) + if self.cache_enabled and function.is_valid_response(result): + self.cache.set(key, result) + + self.autologger.info(f"Function: {function.name} in {t.total_s} (Cache Hit: {cache_hit})", prefix = key, colored = True) + return result if function.is_valid_response(result) else None + + + async def aexecute( + self, + function: Union['FunctionT', str], + *args, + item_hashkey: Optional[str] = None, + cachable: Optional[bool] = True, + overrides: Optional[List[str]] = None, + **function_kwargs + ) -> Optional['FunctionSchemaT']: + """ + Runs the 
function + """ + overwrite = overrides and 'functions' in overrides + function = self.get(function) + if overwrite and self.check_value_present(overrides, f'{function.name}.cachable'): + cachable = False + + if item_hashkey is None: item_hashkey = await self.acreate_hash(**function_kwargs) + key = f'{item_hashkey}.{function.name}' + if function.has_diff_model_than_default: + key += f'.{function.default_model_func}' + + t = Timer() + result = None + cache_hit = False + if self.cache_enabled and not overwrite: + result: 'FunctionResultT' = await self.cache.afetch(key) + if result: + if isinstance(result, dict): result = function.schema.model_validate(result) + result._name = function.name + cache_hit = True + + if result is None: + result = await function(*args, cachable = cachable, is_async = True, **function_kwargs) + if self.cache_enabled and function.is_valid_response(result): + await self.cache.aset(key, result) + + self.autologger.info(f"Function: {function.name} in {t.total_s} (Cache Hit: {cache_hit})", prefix = key, colored = True) + return result if function.is_valid_response(result) else None + + + + @property + def function_names(self) -> List[str]: + """ + Returns the function names + """ + return list(self.functions.keys()) + + def __getitem__(self, name: str) -> Optional['FunctionT']: + """ + Gets the function + """ + return self.get(name) + + def __setitem__(self, name: str, value: Union[FunctionT, Type[FunctionT], str]): + """ + Sets the function + """ + return self.register_function(value, name = name) + + def append(self, value: Union[FunctionT, Type[FunctionT], str]): + """ + Appends the function + """ + return self.register_function(value) + + + def check_value_present( + self, items: List[str], *values: str, + ) -> bool: + """ + Checks if the value is present + """ + if not values: + return any(self.name in item for item in items) + for value in values: + key = f'{self.name}.{value}' if value else self.name + if any((key in item or value in item) for item in items): + return True + return False + + def __call__( + self, + function: Union['FunctionT', str], + *args, + item_hashkey: Optional[str] = None, + cachable: Optional[bool] = True, + overrides: Optional[List[str]] = None, + is_async: Optional[bool] = True, + **function_kwargs + ) -> Union[Awaitable['FunctionSchemaT'], 'FunctionSchemaT']: + """ + Runs the function + """ + method = self.aexecute if is_async else self.execute + return method( + function = function, + *args, + item_hashkey = item_hashkey, + cachable = cachable, + overrides = overrides, + **function_kwargs + ) + + +OpenAIFunctions: FunctionManager = ProxyObject(FunctionManager) \ No newline at end of file diff --git a/async_openai/types/pricing.yaml b/async_openai/types/pricing.yaml index 9392edf..3dded0a 100644 --- a/async_openai/types/pricing.yaml +++ b/async_openai/types/pricing.yaml @@ -12,7 +12,20 @@ gpt-4-1106-preview: endpoints: - chat +gpt-4-0125-preview: + aliases: + - gpt-4-turbo-preview + context_length: 128000 + costs: + unit: 1000 + input: 0.01 + output: 0.03 + endpoints: + - chat + gpt-4: + aliases: + - gpt-4-0613 context_length: 8192 costs: unit: 1000 @@ -22,6 +35,8 @@ gpt-4: - chat gpt-4-32k: + aliases: + - gpt-4-32k-0613 context_length: 32768 costs: unit: 1000 @@ -62,6 +77,7 @@ gpt-3.5-turbo-16k: # - gpt-35-16k # - gpt-35-turbo-16k - gpt-3.5-turbo-16k + - gpt-3.5-turbo-16k-0613 context_length: 16384 costs: unit: 1000 @@ -109,6 +125,17 @@ gpt-3.5-turbo-0613: endpoints: - chat +gpt-3.5-turbo-0125: + aliases: + - gpt-3.5-0125 + 
context_length: 16384 + costs: + unit: 1000 + input: 0.0005 + output: 0.0015 + endpoints: + - chat + gpt-3.5-turbo-instruct: aliases: - gpt-3.5-instruct @@ -134,5 +161,19 @@ text-embedding-ada-002: endpoints: - embeddings +text-embedding-3-large: + context_length: 8191 + costs: + unit: 1000 + input: 0.00013 + endpoints: + - embeddings +text-embedding-3-small: + context_length: 8191 + costs: + unit: 1000 + input: 0.00002 + endpoints: + - embeddings diff --git a/async_openai/types/routes.py b/async_openai/types/routes.py index 04358ce..0b6fecf 100644 --- a/async_openai/types/routes.py +++ b/async_openai/types/routes.py @@ -70,6 +70,8 @@ class BaseRoute(BaseModel): is_azure: Optional[bool] = None azure_model_mapping: Optional[Dict[str, str]] = None + client_callbacks: Optional[List[Callable]] = None + @lazyproperty def api_resource(self): """ @@ -220,9 +222,13 @@ async def async_create( stream = input_object.get('stream'), **kwargs ) + # if input_object.get('stream'): + # await api_response.aread() data = self.handle_response(api_response) return await self.aprepare_response(data, input_object = input_object, parse_stream = parse_stream) + acreate = async_create + def batch_create( self, input_object: Optional[Type[BaseResource]] = None, @@ -298,6 +304,8 @@ async def async_batch_create( ) resp = self.handle_response(api_response) return await self.aprepare_response(resp, input_object = input_object) + + abatch_create = async_batch_create def retrieve( @@ -354,6 +362,8 @@ async def async_retrieve( ) data = self.handle_response(api_response) return self.prepare_response(data) + + aretrieve = async_retrieve def get( self, @@ -385,6 +395,7 @@ async def async_get( """ return await self.async_retrieve(resource_id = resource_id, params = params, headers = headers, **kwargs) + aget = async_get def list( self, @@ -440,6 +451,8 @@ async def async_list( ) data = self.handle_response(api_response) return await self.aprepare_response(data) + + alist = async_list def get_all( self, @@ -468,6 +481,8 @@ async def async_get_all( :return: Dict[str, Union[List[Type[BaseResource]], Dict[str, Any]]] """ return await self.async_retrieve(params = params, **kwargs) + + aget_all = async_get_all def delete( self, @@ -519,6 +534,7 @@ async def async_delete( data = self.handle_response(api_response) return self.prepare_response(data) + adelete = async_delete def update( self, @@ -605,8 +621,7 @@ async def async_update( data = self.handle_response(api_response) return self.prepare_response(data, input_object = input_object) - - + aupdate = async_update """ Extra Methods @@ -641,6 +656,8 @@ async def async_exists( return await self.async_get(resource_id = resource_id, **kwargs) except Exception: return False + + aexists = async_exists def upsert( self, @@ -702,7 +719,7 @@ async def async_upsert( return resource return await self.async_create(input_object = input_object, **kwargs) - + aupsert = async_upsert def upload( self, @@ -776,6 +793,8 @@ async def async_upload( data = self.handle_response(api_response) return self.prepare_response(data, input_object = input_object) + aupload = async_upload + def download( self, resource_id: str, @@ -826,6 +845,8 @@ async def async_download( data = self.handle_response(api_response) return self.prepare_response(data) + adownload = async_download + def prepare_response( self, data: aiohttpx.Response, @@ -842,10 +863,11 @@ def prepare_response( """ response_object = response_object or self.response_model if response_object: - return response_object.prepare_response(data, input_object = 
input_object, parse_stream = parse_stream) + response = response_object.prepare_response(data, input_object = input_object, parse_stream = parse_stream) + self.handle_callbacks(response, **kwargs) + return response raise NotImplementedError('Response model not defined for this resource.') - async def aprepare_response( self, data: aiohttpx.Response, @@ -862,9 +884,28 @@ async def aprepare_response( """ response_object = response_object or self.response_model if response_object: - return await response_object.aprepare_response(data, input_object = input_object, parse_stream = parse_stream) + response = await response_object.aprepare_response(data, input_object = input_object, parse_stream = parse_stream) + self.handle_callbacks(response, **kwargs) + return response raise NotImplementedError('Response model not defined for this resource.') + def handle_callbacks( + self, + response_object: BaseResource, + **kwargs + ): + """ + Handle the Callbacks for the Response as a Background Task + + This is useful for when you want to run a background task after a response is received + + The callback should be a function that takes the response object as the first argument + """ + if self.client_callbacks: + from lazyops.libs.pooler import ThreadPooler + for callback in self.client_callbacks: + ThreadPooler.background(callback, response_object, **kwargs) + def handle_response( self, response: aiohttpx.Response, diff --git a/async_openai/utils/config.py b/async_openai/utils/config.py index d6141c1..ac28bd7 100644 --- a/async_openai/utils/config.py +++ b/async_openai/utils/config.py @@ -2,8 +2,11 @@ import logging import pathlib import aiohttpx +import contextlib from typing import Optional, Dict, Union, Any from lazyops.types import BaseSettings, validator, BaseModel, lazyproperty, Field +from lazyops.libs.proxyobj import ProxyObject +from lazyops.libs.abcs.configs.types import AppEnv from async_openai.version import VERSION from async_openai.types.options import ApiType @@ -118,6 +121,7 @@ class BaseOpenAISettings(BaseSettings): keepalive_expiry: Optional[int] = 60 custom_headers: Optional[Dict[str, str]] = None + limit_monitor_enabled: Optional[bool] = True @validator("api_type") def validate_api_type(cls, v): @@ -420,7 +424,7 @@ class AzureOpenAISettings(BaseOpenAISettings): """ api_type: Optional[ApiType] = ApiType.azure - api_version: Optional[str] = "2023-07-01-preview" + api_version: Optional[str] = "2023-12-01-preview" api_path: Optional[str] = None class Config: @@ -437,15 +441,170 @@ def is_valid(self) -> bool: ) + +class OpenAIProxySettings(BaseSettings): + + proxy_enabled: Optional[bool] = None + proxy_endpoint: Optional[str] = None + + proxy_name: Optional[str] = None + proxy_kind: Optional[str] = 'helicone' + proxy_env_name: Optional[str] = None + proxy_app_name: Optional[str] = None + proxy_endpoints: Optional[Dict[str, str]] = Field(default_factory = dict) + proxy_apikeys: Optional[Dict[str, str]] = Field(default_factory = dict) + + @property + def endpoint(self) -> Optional[str]: + """ + Returns the Proxy Endpoint + """ + return self.proxy_endpoint + + @property + def enabled(self) -> Optional[bool]: + """ + Returns whether the proxy is enabled + """ + return self.proxy_enabled + + def get_proxy_endpoint(self) -> Optional[str]: + """ + Returns the proxy endpoint + """ + if self.proxy_name and self.proxy_endpoints.get(self.proxy_name): + return self.proxy_endpoints[self.proxy_name] + for name, endpoint in self.proxy_endpoints.items(): + with contextlib.suppress(Exception): + resp = 
aiohttpx.get(endpoint, timeout = 2.0) + # data = resp.json() + # if data.get('error'): + self.proxy_name = name + return endpoint + return None + + def init(self, config_path: Optional[pathlib.Path] = None): + """ + Initializes the core settings + """ + if config_path: self.load_proxy_config(config_path) + if self.proxy_endpoint is None: + self.proxy_endpoint = self.get_proxy_endpoint() + self.proxy_enabled = self.proxy_endpoint is not None + + def get_apikey( + self, source: Optional[str] = None, + ) -> str: + """ + Gets the appropriate API Key for the proxy + """ + if source: + source = source.lower() + for k, v in self.proxy_apikeys.items(): + if k in source: return v + return self.proxy_apikeys.get('default', None) + + def load_proxy_config( + self, + path: pathlib.Path, + ): + """ + Loads the Proxy Configuration from a File + """ + if not path.exists(): return + data: Dict[str, Union[Dict[str, str], str]] = json.loads(path.read_text()) + for k, v in data.items(): + if v is None: continue + if k in {'endpoint', 'enabled'}: k = f'proxy_{k}' + if hasattr(self, k): setattr(self, k, v) + elif hasattr(self, f'proxy_{k}'): setattr(self, f'proxy_{k}', v) + self.proxy_endpoint = None + self.proxy_enabled = None + + def update( + self, + **kwargs + ) -> None: + """ + Updates the Proxy Settings + """ + for k, v in kwargs.items(): + if v is None: continue + if k in {'endpoint', 'enabled'}: k = f'proxy_{k}' + if hasattr(self, k): setattr(self, k, v) + elif hasattr(self, f'proxy_{k}'): setattr(self, f'proxy_{k}', v) + self.proxy_endpoint = None + self.proxy_enabled = None + + + def create_proxy_headers_for_helicone( + self, + name: str, + config: Dict[str, Any], + **properties: Dict[str, str], + ) -> Dict[str, Any]: + """ + Creates the Proxy Headers for Helicone + """ + headers = { + 'Helicone-OpenAI-Api-Base': config.get('api_base', ''), + 'Helicone-Auth': f"Bearer {self.get_apikey(self.proxy_app_name)}", + "Helicone-Property-ClientName": name, + 'Content-Type': 'application/json', + } + user_id = '' + if self.proxy_app_name: + headers['Helicone-Property-AppName'] = self.proxy_app_name + user_id += self.proxy_app_name + if self.proxy_env_name: + headers['Helicone-Property-AppEnvironment'] = self.proxy_env_name + if user_id: user_id += f'-{self.proxy_env_name}' + if user_id: headers['Helicone-User-Id'] = user_id + if 'properties' in config: properties = config.pop('properties') + if properties: + for k, v in properties.items(): + if 'Helicone-Property-' not in k: k = f'Helicone-Property-{k}' + headers[k] = str(v) + return headers + + def create_proxy_headers( + self, + name: str, + config: Dict[str, Any], + kind: Optional[str] = None, + **properties: Dict[str, str], + ) -> Dict[str, Any]: + """ + Creates the Proxy Headers + """ + if kind is None: kind = self.proxy_kind + if kind == 'helicone': + return self.create_proxy_headers_for_helicone(name, config, **properties) + raise ValueError(f"Unsupported Proxy Kind: {kind}") + + class Config: + # We use a different prefix here to avoid conflicts + env_prefix = "OAI_" + case_sensitive = False + + class OpenAISettings(BaseOpenAISettings): """ The OpenAI Settings """ + app_env: Optional[AppEnv] = None + client_configurations: Optional[Dict[str, Dict[str, Any]]] = Field(default_factory = dict) + auto_loadbalance_clients: Optional[bool] = True + auto_healthcheck: Optional[bool] = True + + function_cache_enabled: Optional[bool] = True + class Config: env_prefix = 'OPENAI_' case_sensitive = False + @lazyproperty def azure(self) -> AzureOpenAISettings: """ @@ 
-453,6 +612,13 @@ def azure(self) -> AzureOpenAISettings: """ return AzureOpenAISettings() + @lazyproperty + def proxy(self) -> OpenAIProxySettings: + """ + Return the Proxy Settings + """ + return OpenAIProxySettings() + @property def has_valid_azure(self) -> bool: """ @@ -460,9 +626,25 @@ def has_valid_azure(self) -> bool: """ return self.azure.is_valid + def load_client_configurations( + self, + path: pathlib.Path, + ): + """ + Loads the Client Configurations + """ + if not path.exists(): return + data: Dict[str, Dict[str, Any]] = json.loads(path.read_text()) + self.client_configurations.update(data) def configure( self, + auto_healthcheck: Optional[bool] = None, + auto_loadbalance_clients: Optional[bool] = None, + proxy_app_name: Optional[str] = None, + proxy_env_name: Optional[str] = None, + proxy_config: Optional[Union[Dict[str, Any], pathlib.Path]] = None, + client_configurations: Optional[Union[Dict[str, Dict[str, Any]], pathlib.Path]] = None, **kwargs ): """ @@ -494,6 +676,19 @@ def configure( :param max_retries: The OpenAI Max Retries | Env: [`OPENAI_MAX_RETRIES`] :param kwargs: Additional Keyword Arguments """ + if auto_healthcheck is not None: self.auto_healthcheck = auto_healthcheck + if auto_loadbalance_clients is not None: self.auto_loadbalance_clients = auto_loadbalance_clients + if proxy_config: + if isinstance(proxy_config, pathlib.Path): + self.proxy.load_proxy_config(proxy_config) + else: self.proxy.update(**proxy_config) + self.proxy.init() + if proxy_app_name: self.proxy.proxy_app_name = proxy_app_name + if proxy_env_name: self.proxy.proxy_env_name = proxy_env_name + if client_configurations: + if isinstance(client_configurations, pathlib.Path): + self.load_client_configurations(client_configurations) + else: self.client_configurations.update(client_configurations) # Parse apart the azure setting configurations az_kwargs, rm_keys = {}, [] @@ -511,29 +706,64 @@ def configure( for k in rm_keys: kwargs.pop(k, None) super().configure(**kwargs) + + @validator('app_env', pre=True) + def validate_app_env(cls, value: Optional[Any]) -> Any: + """ + Validates the app environment + """ + if value is None: + from lazyops.libs.abcs.configs.base import get_app_env + return get_app_env(cls.__module__) + return AppEnv.from_env(value) if isinstance(value, str) else value + + @property + def in_k8s(self) -> bool: + """ + Returns whether the app is running in kubernetes + """ + from lazyops.utils.system import is_in_kubernetes + return is_in_kubernetes() + + @property + def is_local_env(self) -> bool: + """ + Returns whether the environment is local development + """ + return self.app_env in [AppEnv.DEVELOPMENT, AppEnv.LOCAL] and not self.in_k8s + + @property + def is_production_env(self) -> bool: + """ + Returns whether the environment is production + """ + return self.app_env == AppEnv.PRODUCTION and self.in_k8s + @property + def is_development_env(self) -> bool: + """ + Returns whether the environment is development + """ + return self.app_env in [AppEnv.DEVELOPMENT, AppEnv.LOCAL, AppEnv.CICD] -_settings: Optional[OpenAISettings] = None + +settings: OpenAISettings = ProxyObject(OpenAISettings) def get_settings(**kwargs) -> OpenAISettings: """ Returns the OpenAI Settings """ - global _settings - if _settings is None: - _settings = OpenAISettings() - if kwargs: _settings.configure(**kwargs) - return _settings + if kwargs: settings.configure(**kwargs) + return settings def get_default_headers() -> Dict[str, Any]: """ Returns the Default Headers """ - return get_settings().get_headers() - + return 
settings.get_headers() def get_max_retries() -> int: """ Returns the Max Retries """ - return get_settings().max_retries \ No newline at end of file + return settings.max_retries diff --git a/async_openai/utils/logs.py b/async_openai/utils/logs.py index 75e0293..9b2331b 100644 --- a/async_openai/utils/logs.py +++ b/async_openai/utils/logs.py @@ -1,5 +1,5 @@ import os -from lazyops.utils.logs import get_logger, change_logger_level +from lazyops.utils.logs import get_logger, change_logger_level, null_logger # to prevent recursive imports, we'll just use os environ here if os.getenv('DEBUG_ENABLED') == 'True': diff --git a/async_openai/version.py b/async_openai/version.py index 0a31f9b..36231bd 100644 --- a/async_openai/version.py +++ b/async_openai/version.py @@ -1 +1 @@ -VERSION = '0.0.42' \ No newline at end of file +VERSION = '0.0.50rc0' \ No newline at end of file diff --git a/setup.py b/setup.py index 32afd60..a38ebb0 100644 --- a/setup.py +++ b/setup.py @@ -12,19 +12,21 @@ version = root.joinpath('async_openai/version.py').read_text().split('VERSION = ', 1)[-1].strip().replace('-', '').replace("'", '') requirements = [ - 'aiohttpx', + 'aiohttpx >= 0.0.12', # 'file-io', 'backoff', 'tiktoken', - 'lazyops >= 0.2.60', # Pydantic Support + 'lazyops >= 0.2.72', # Pydantic Support 'pydantic', + 'jinja2', # 'pydantic-settings', # remove to allow for v1/v2 support ] if sys.version_info.minor < 8: requirements.append('typing_extensions') -extras = {} +extras = { +} args = { 'packages': find_packages(include = [f'{pkg_name}', f'{pkg_name}.*',]),