diff --git a/ollama/_client.py b/ollama/_client.py index cbe43c94..839cc253 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -1,3 +1,4 @@ +import contextlib import ipaddress import json import os @@ -70,7 +71,7 @@ T = TypeVar('T') -class BaseClient: +class BaseClient(contextlib.AbstractContextManager, contextlib.AbstractAsyncContextManager): def __init__( self, client, @@ -105,6 +106,12 @@ def __init__( **kwargs, ) + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download' @@ -113,6 +120,9 @@ class Client(BaseClient): def __init__(self, host: Optional[str] = None, **kwargs) -> None: super().__init__(httpx.Client, host, **kwargs) + def close(self): + self._client.close() + def _request_raw(self, *args, **kwargs): try: r = self._client.request(*args, **kwargs) @@ -617,6 +627,9 @@ class AsyncClient(BaseClient): def __init__(self, host: Optional[str] = None, **kwargs) -> None: super().__init__(httpx.AsyncClient, host, **kwargs) + async def close(self): + await self._client.aclose() + async def _request_raw(self, *args, **kwargs): try: r = await self._client.request(*args, **kwargs) diff --git a/tests/test_client.py b/tests/test_client.py index dacb953d..c93060aa 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1140,3 +1140,33 @@ async def test_async_client_connection_error(): with pytest.raises(ConnectionError) as exc_info: await client.show('model') assert str(exc_info.value) == 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download' + + +def test_client_close(): + client = Client() + client.close() + assert client._client.is_closed + + +@pytest.mark.asyncio +async def test_async_client_close(): + client = AsyncClient() + await client.close() + assert client._client.is_closed + + +def test_client_context_manager(): + with Client() as client: + assert isinstance(client, Client) + assert not client._client.is_closed + + assert client._client.is_closed + + +@pytest.mark.asyncio +async def test_async_client_context_manager(): + async with AsyncClient() as client: + assert isinstance(client, AsyncClient) + assert not client._client.is_closed + + assert client._client.is_closed