diff --git a/varlink/client.py b/varlink/client.py index 2b60e7e..fd92b5c 100644 --- a/varlink/client.py +++ b/varlink/client.py @@ -9,6 +9,7 @@ import sys import tempfile import threading +from typing import Optional from .error import InterfaceNotFound, VarlinkEncoder, VarlinkError from .scanner import Interface, _Method @@ -40,11 +41,11 @@ def __init__(self, interface, namespaced=False): if isinstance(member, _Method): self._add_method(member) - def close(self): + def close(self) -> None: """To be implemented.""" raise NotImplementedError - def _send_message(self, out): + def _send_message(self, out) -> None: """To be implemented. This should send a varlink message to the varlink service adding a trailing zero byte. @@ -205,7 +206,7 @@ def __enter__(self): def __exit__(self, _type, _value, _traceback): self.close() - def close(self): + def close(self) -> None: try: if hasattr(self._connection, "shutdown"): self._connection.shutdown(socket.SHUT_RDWR) @@ -214,7 +215,7 @@ def close(self): self._connection.close() - def _send_message(self, out): + def _send_message(self, out) -> None: if self._send_bytes: self._connection.send_bytes(out + b"\0") elif self._sendall: @@ -443,7 +444,7 @@ def new_bridge_socket(): self._child_pid = p.pid return sp[0] - def new_bridge_socket_compat(): + def new_bridge_socket_compat() -> socket.socket: sp = socket.socketpair() p = subprocess.Popen( " ".join(argv), shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, close_fds=True @@ -546,7 +547,7 @@ def _with_resolved_interface(self, interface, resolver_address=None): return self - def cleanup(self): + def cleanup(self) -> None: if hasattr(self, "_tmpdir") and self._tmpdir is not None: try: shutil.rmtree(self._tmpdir) @@ -563,7 +564,9 @@ def cleanup(self): except Exception: # TODO: maybe just ChildProcessError? pass - def open(self, interface_name, namespaced=False, connection=None): + def open( + self, interface_name: str, namespaced: bool = False, connection: Optional[socket.socket] = None + ) -> SimpleClientInterfaceHandler: """Open a new connection and get a client interface handle with the varlink methods installed. :param interface_name: an interface name, which the service this client object is @@ -586,11 +589,12 @@ def open(self, interface_name, namespaced=False, connection=None): return self.handler(self._interfaces[interface_name], connection, namespaced=namespaced) - def open_connection(self): + def open_connection(self) -> socket.socket: """Open a new connection and return the socket. :exception OSError: anything socket.connect() throws """ + assert self._socket_fn, "socket_fn not initialised" return self._socket_fn() def get_interfaces(self, socket_connection=None): @@ -628,7 +632,7 @@ def get_interface(self, interface_name, socket_connection=None): return interface - def add_interface(self, interface): + def add_interface(self, interface: Interface) -> None: """Manually add or overwrite an interface definition from an Interface object. :param interface: an Interface() object diff --git a/varlink/mock.py b/varlink/mock.py index 1a452c7..c86a9d6 100644 --- a/varlink/mock.py +++ b/varlink/mock.py @@ -16,7 +16,7 @@ def cast_type(typeof): return cast.get(typeof, typeof) -def get_ignored(): +def get_ignored() -> list[str]: ignore = dir(MockedService) return ignore @@ -248,7 +248,7 @@ def __init__( product="mock", version=1, url="http://localhost", - ): + ) -> None: if not name: module = service.__module__ try: @@ -258,7 +258,7 @@ def __init__( else: self.name = name self.identifier = str(uuid.uuid4()) - self.interface_description = None + self.interface_description: list[str] = [] self.service = service self.types = types self.address = address @@ -278,7 +278,7 @@ def __init__( } self.generate_interface() - def generate_interface(self): + def generate_interface(self) -> None: ignore = get_ignored() self.interface_description = [f"interface {self.name}"] if self.types: @@ -288,25 +288,25 @@ def generate_interface(self): for attr in attributs["callables"]: self.interface_description.append(generate_callable_interface(self.service, attr)) - def get_interface_file_path(self): + def get_interface_file_path(self) -> str: return f"/tmp/{self.name}" - def generate_interface_file(self): + def generate_interface_file(self) -> None: tfp = open(self.get_interface_file_path(), "w+") tfp.write("\n".join(self.interface_description)) tfp.close() - def delete_interface_files(self): + def delete_interface_files(self) -> None: os.remove(self.get_interface_file_path()) os.remove(self.mocked_service_file) - def service_start(self): + def service_start(self) -> None: self.service_pid = subprocess.Popen( [sys.executable, self.mocked_service_file], env={"PYTHONPATH": ":".join(sys.path)} ) time.sleep(2) - def service_stop(self): + def service_stop(self) -> None: self.service_pid.kill() self.service_pid.communicate() diff --git a/varlink/server.py b/varlink/server.py index 8031cd5..f8a128c 100644 --- a/varlink/server.py +++ b/varlink/server.py @@ -12,7 +12,7 @@ from socketserver import ForkingMixIn from types import GeneratorType -from typing import Optional +from typing import Optional, Union class Service: @@ -91,7 +91,7 @@ def GetInfo(self): "interfaces": list(self.interfaces.keys()), } - def GetInterfaceDescription(self, interface): + def GetInterfaceDescription(self, interface: str) -> dict[str, str]: """The standardized org.varlink.service.GetInterfaceDescription() varlink method.""" try: i = self.interfaces[interface] @@ -282,7 +282,7 @@ def decorator(interface_class): return decorator -def get_listen_fd(): +def get_listen_fd() -> Union[int, None]: if "LISTEN_FDS" not in os.environ: return None if "LISTEN_PID" not in os.environ: @@ -319,6 +319,8 @@ def get_listen_fd(): except OSError: return None + return None + class RequestHandler(StreamRequestHandler): """Varlink request handler @@ -500,7 +502,7 @@ def server_close(self): pass self.socket.close() - def fileno(self): + def fileno(self) -> int: """Return socket file number. Interface required by selector. diff --git a/varlink/tests/test_basic_network.py b/varlink/tests/test_basic_network.py index fcabe3b..611c80e 100755 --- a/varlink/tests/test_basic_network.py +++ b/varlink/tests/test_basic_network.py @@ -43,18 +43,39 @@ def do_run(self, address): server.shutdown() server.server_close() - def test_tcp(self): + def test_tcp(self) -> None: self.do_run("tcp:127.0.0.1:23450") - def test_anon_unix(self): + def test_anon_unix(self) -> None: if platform.startswith("linux"): self.do_run(f"unix:@org.varlink.service_anon_test{os.getpid()}{threading.current_thread().name}") - def test_unix(self): + def test_unix(self) -> None: if hasattr(socket, "AF_UNIX"): self.do_run(f"unix:org.varlink.service_anon_test_{os.getpid()}{threading.current_thread().name}") - def test_wrong_url(self): + def test_wrong_url(self) -> None: self.assertRaises( ConnectionError, self.do_run, f"uenix:org.varlink.service_wrong_url_test_{os.getpid()}" ) + + def test_reuse_open(self) -> None: + address = "tcp:127.0.0.1:23450" + server = varlink.ThreadingServer(address, ServiceRequestHandler) + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + try: + with varlink.Client(address) as client: + connection = client.open_connection() + re_use = client.open("org.varlink.service", False, connection) + + info = re_use.GetInfo() + self.assertEqual(len(info["interfaces"]), 1) + self.assertEqual(info["interfaces"][0], "org.varlink.service") + self.assertEqual(info, service.GetInfo()) + connection.close() + finally: + server.shutdown() + server.server_close() diff --git a/varlink/tests/test_orgexamplemore.py b/varlink/tests/test_orgexamplemore.py index 46ae92b..561e231 100755 --- a/varlink/tests/test_orgexamplemore.py +++ b/varlink/tests/test_orgexamplemore.py @@ -88,7 +88,7 @@ def __init__(self, reason): @service.interface("org.example.more") class Example: - sleep_duration = 1 + sleep_duration = 1.0 def TestMore(self, n, _more=True, _server=None): try: @@ -214,7 +214,7 @@ def epilog(): class TestService(unittest.TestCase): - def test_service(self): + def test_service(self) -> None: address = "tcp:127.0.0.1:23451" Example.sleep_duration = 0.1