diff --git a/openvpn_api/events/__init__.py b/openvpn_api/events/__init__.py new file mode 100644 index 0000000..c7f1737 --- /dev/null +++ b/openvpn_api/events/__init__.py @@ -0,0 +1,17 @@ +import importlib +import typing + +from openvpn_api.events import client + +event_types = [importlib.import_module(".client", __name__)] + +_callbacks = [] + + +def raise_event(event: typing.Type) -> None: + for callback in _callbacks: + callback(event) + + +def register_callback(callback: typing.Callable[[typing.Type], typing.Any]) -> None: + _callbacks.append(callback) diff --git a/openvpn_api/events/client.py b/openvpn_api/events/client.py new file mode 100644 index 0000000..38c98de --- /dev/null +++ b/openvpn_api/events/client.py @@ -0,0 +1,93 @@ +import re +from typing import List, Dict + +from openvpn_api.util import errors + +EVENT_TYPE_REGEXES = { + "CONNECT": re.compile(r"^>CLIENT:CONNECT,(?P([^,]+)),(?P([^,]+))$"), + "REAUTH": re.compile(r"^>CLIENT:REAUTH,(?P([^,]+)),(?P([^,]+))$"), + "ESTABLISHED": re.compile(r"^>CLIENT:ESTABLISHED,(?P([^,]+))$"), + "DISCONNECT": re.compile(r"^>CLIENT:DISCONNECT,(?P([^,]+))$"), + "ADDRESS": re.compile(r"^>CLIENT:ADDRESS,(?P([^,]+)),(?P([^,]+)),(?P([^,]+))$"), +} + +FIRST_LINE_REGEX = re.compile(r"^>CLIENT:(?P([^,]+))(.*)$") +ENV_REGEX = re.compile(r">CLIENT:ENV,(?P([^=]+))=(?P(.+))") + + +class ClientEvent: + def __init__(self, event_type, cid=None, kid=None, pri=None, addr=None, environment: Dict[str, str] = dict): + self.type = event_type + self.cid = int(cid) if cid is not None else None + self.kid = int(kid) if kid is not None else None + self.pri = int(pri) if pri is not None else None + self.addr = int(addr) if addr is not None else None + self.environment = environment + + +def is_input_began(line: str) -> bool: + if not line: + return False + + match = FIRST_LINE_REGEX.match(line) + if not match: + return False + + event_type = match.group("event") + if event_type not in EVENT_TYPE_REGEXES: + return False + + return True + + +def is_input_ended(line: str) -> bool: + return line and (line.strip().startswith(">CLIENT:ADDRESS,") or line.strip() == ">CLIENT:ENV,END") + + +def parse_raw(lines: List[str]) -> "ClientEvent": + if not lines: + raise errors.ParseError("Event raw input is empty.") + + first_line = lines.pop(0) + match = FIRST_LINE_REGEX.match(first_line) + + if not match: + raise errors.ParseError("Syntax error in first line of client event (Line: %s)" % first_line) + + event_type = match.group("event") + + if event_type not in EVENT_TYPE_REGEXES: + raise errors.ParseError( + "This event type (%s) is not supported (Supported events: %s)" % (event_type, EVENT_TYPE_REGEXES) + ) + + match = EVENT_TYPE_REGEXES[event_type].match(first_line) + + if not match: + raise errors.ParseError("Syntax error in first line of client event (Line: %s)" % first_line) + + first_line_data = match.groupdict() + cid = int(first_line_data["CID"]) if "CID" in first_line_data else None + kid = int(first_line_data["KID"]) if "KID" in first_line_data else None + pri = int(first_line_data["KID"]) if "KID" in first_line_data else None + addr = int(first_line_data["ADDR"]) if "ADDR" in first_line_data else None + environment = {} + + if event_type != "ADDRESS": + + for line in lines: + if line.strip() == ">CLIENT:ENV,END": + break + + match = ENV_REGEX.match(line) + if not match: + raise errors.ParseError("Invalid line in client event (Line: %s)" % line) + + environment[match.group("key")] = match.group("value") + else: + raise errors.ParseError("The raw event doesn't have an >CLIENT:ENV,END line.") + + if not environment: + raise errors.ParseError("This event type (%s) doesn't support empty environment." % event_type) + + return ClientEvent(event_type=event_type, cid=cid, kid=kid, pri=pri, addr=addr, environment=environment) diff --git a/openvpn_api/vpn.py b/openvpn_api/vpn.py index 6e79e8f..09e8cd0 100644 --- a/openvpn_api/vpn.py +++ b/openvpn_api/vpn.py @@ -1,10 +1,14 @@ import logging +import queue import socket import re import contextlib +import threading from typing import Optional, Generator import openvpn_status # type: ignore + +from openvpn_api import events from openvpn_api.util import errors from openvpn_api.models.state import State from openvpn_api.models.stats import ServerStats @@ -29,6 +33,17 @@ def __init__(self, host: Optional[str] = None, port: Optional[int] = None, socke self._mgmt_port = port self._type = VPNType.IP self._socket = None + self._socket_file = None + self._socket_io_lock = threading.Lock() + + self._listener_thread = None + self._writer_thread = None + + self._recv_queue = queue.Queue() + self._send_queue = queue.Queue() + + self._active_event = None + self.clear_cache() # Initialise release info and daemon state caches @property @@ -51,10 +66,21 @@ def connect(self) -> Optional[bool]: """ try: if self.type == VPNType.IP: - self._socket = socket.create_connection((self._mgmt_host, self._mgmt_port), timeout=3) + self._socket = socket.create_connection((self._mgmt_host, self._mgmt_port), timeout=None) else: self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self._socket.connect(self._mgmt_socket) + + self._socket_file = self._socket.makefile("r") + + self._listener_thread = threading.Thread( + target=self._socket_listener_thread, daemon=True, name="mgmt-listener" + ) + self._writer_thread = threading.Thread(target=self._socket_writer_thread, daemon=True, name="mgmt-writer") + + self._listener_thread.start() + self._writer_thread.start() + resp = self._socket_recv() assert resp.startswith(">INFO"), "Did not get expected response from interface when opening socket." return True @@ -67,6 +93,8 @@ def disconnect(self, _quit=True) -> None: if self._socket is not None: if _quit: self._socket_send("quit\n") + + self._socket_file.close() self._socket.close() self._socket = None @@ -86,15 +114,58 @@ def connection(self) -> Generator: finally: self.disconnect() + def _socket_listener_thread(self): + """This thread handles the socket's output and handles any events before adding the output to the receive queue. + """ + active_event_lines = [] + while True: + if not self.is_connected: + break + + line = self._socket_file.readline().strip() + + if self._active_event is None: + for event in events.event_types: + if event.is_input_began(line): + active_event_lines = [] + if event.is_input_ended(line): + events.raise_event(event.parse_raw([line])) + else: + self._socket_io_lock.acquire() + self._active_event = event + active_event_lines.append(line) + break + else: + self._recv_queue.put(line) + else: + active_event_lines.append(line) + if self._active_event.is_input_ended(line): + events.raise_event(self._active_event.parse_raw(active_event_lines)) + active_event_lines = [] + self._active_event = None + self._socket_io_lock.release() + + def _socket_writer_thread(self): + while True: + if not self.is_connected: + break + + try: + data = self._send_queue.get() + self._socket_io_lock.acquire() + self._socket.send(bytes(data, "utf-8")) + finally: + self._socket_io_lock.release() + def _socket_send(self, data) -> None: """Convert data to bytes and send to socket. """ - self._socket.send(bytes(data, "utf-8")) + self._send_queue.put(data) def _socket_recv(self) -> str: """Receive bytes from socket and convert to string. """ - return self._socket.recv(4096).decode("utf-8") + return self._recv_queue.get() def send_command(self, cmd) -> Optional[str]: """Send command to management interface and fetch response. diff --git a/tests/test_event_client.py b/tests/test_event_client.py new file mode 100644 index 0000000..f01f4fd --- /dev/null +++ b/tests/test_event_client.py @@ -0,0 +1,109 @@ +import unittest + +from openvpn_api.events import client as ClientEvent +from openvpn_api.util import errors + + +class TestEventClient(unittest.TestCase): + def test_input_began(self): + self.assertTrue(ClientEvent.is_input_began(">CLIENT:CONNECT,14,23,")) + + def test_input_not_began(self): + self.assertFalse(ClientEvent.is_input_began(">BYTES:14,23")) + + def test_input_not_began_invalid_type(self): + self.assertFalse(ClientEvent.is_input_began(">CLIENT:INVALID-TYPE,14,23,")) + + def test_input_not_began_empty_input(self): + self.assertFalse(ClientEvent.is_input_began("")) + + def test_input_ended_normal(self): + self.assertTrue(ClientEvent.is_input_ended(">CLIENT:ENV,END")) + + def test_input_ended_one_liner(self): + self.assertTrue(ClientEvent.is_input_ended(">CLIENT:ADDRESS,14,3,1.1.1.1")) + + def test_empty_lines(self): + with self.assertRaises(errors.ParseError) as ctx: + ClientEvent.parse_raw([]) + self.assertEqual("Event raw input is empty.", str(ctx.exception)) + + def test_deserialize_connect_event(self): + event = ClientEvent.parse_raw( + [ + ">CLIENT:CONNECT,14,43", + ">CLIENT:ENV,common_name=test_cn", + ">CLIENT:ENV,time_unix=12343212343", + ">CLIENT:ENV,END", + ] + ) + self.assertEqual("CONNECT", event.type) + self.assertEqual(14, event.cid) + self.assertEqual(43, event.kid) + self.assertEqual({"common_name": "test_cn", "time_unix": "12343212343"}, event.environment) + + def test_deserialize_reauth_event(self): + event = ClientEvent.parse_raw( + [ + ">CLIENT:REAUTH,14,43", + ">CLIENT:ENV,common_name=test_cn", + ">CLIENT:ENV,time_unix=12343212343", + ">CLIENT:ENV,END", + ] + ) + self.assertEqual("REAUTH", event.type) + self.assertEqual(14, event.cid) + self.assertEqual(43, event.kid) + self.assertEqual({"common_name": "test_cn", "time_unix": "12343212343"}, event.environment) + + def test_deserialize_established_event(self): + event = ClientEvent.parse_raw( + [ + ">CLIENT:ESTABLISHED,14", + ">CLIENT:ENV,common_name=test_cn", + ">CLIENT:ENV,time_unix=12343212343", + ">CLIENT:ENV,END", + ] + ) + self.assertEqual("ESTABLISHED", event.type) + self.assertEqual(14, event.cid) + self.assertEqual({"common_name": "test_cn", "time_unix": "12343212343"}, event.environment) + + def test_deserialize_disconnect_event(self): + event = ClientEvent.parse_raw( + [ + ">CLIENT:DISCONNECT,14", + ">CLIENT:ENV,common_name=test_cn", + ">CLIENT:ENV,time_unix=12343212343", + ">CLIENT:ENV,END", + ] + ) + self.assertEqual("DISCONNECT", event.type) + self.assertEqual(14, event.cid) + self.assertEqual({"common_name": "test_cn", "time_unix": "12343212343"}, event.environment) + + def test_empty_environment(self): + with self.assertRaises(errors.ParseError) as ctx: + a = ClientEvent.parse_raw([">CLIENT:DISCONNECT,14", ">CLIENT:ENV,END",]) + + self.assertEqual("This event type (DISCONNECT) doesn't support empty environment.", str(ctx.exception)) + + def test_missing_environment(self): + with self.assertRaises(errors.ParseError) as ctx: + ClientEvent.parse_raw( + [">CLIENT:DISCONNECT,14",] + ) + + self.assertEqual("The raw event doesn't have an >CLIENT:ENV,END line.", str(ctx.exception)) + + def test_invalid_type(self): + with self.assertRaises(errors.ParseError) as ctx: + ClientEvent.parse_raw( + [">CLIENT:NOT-SUPPORTED,14",] + ) + + self.assertEqual( + "This event type (NOT-SUPPORTED) is not supported (Supported events: %s)" + % (ClientEvent.EVENT_TYPE_REGEXES), + str(ctx.exception), + )