diff --git a/example/acct_async.py b/example/acct_async.py new file mode 100644 index 00000000..9df18cdd --- /dev/null +++ b/example/acct_async.py @@ -0,0 +1,101 @@ +#!/usr/bin/python + +import asyncio + +import logging +import traceback +from pyrad.dictionary import Dictionary +from pyrad.client_async import ClientAsync +from pyrad.packet import AccountingResponse + +logging.basicConfig(level="DEBUG", + format="%(asctime)s [%(levelname)-8s] %(message)s") +client = ClientAsync(server="127.0.0.1", + secret=b"Kah3choteereethiejeimaeziecumi", + timeout=3, debug=True, + dict=Dictionary("dictionary")) + +loop = asyncio.get_event_loop() + + +def create_request(client, user): + req = client.CreateAcctPacket(User_Name=user) + + req["NAS-IP-Address"] = "192.168.1.10" + req["NAS-Port"] = 0 + req["Service-Type"] = "Login-User" + req["NAS-Identifier"] = "trillian" + req["Called-Station-Id"] = "00-04-5F-00-0F-D1" + req["Calling-Station-Id"] = "00-01-24-80-B3-9C" + req["Framed-IP-Address"] = "10.0.0.100" + + return req + + +def print_reply(reply): + print("Received Accounting-Response") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + +def test_acct1(enable_message_authenticator=False): + + global client + + try: + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + # local_addr='127.0.0.1', + # local_auth_port=8000, + enable_acct=True, + enable_coa=True))) + + req = create_request(client, "wichert") + if enable_message_authenticator: + req.add_message_authenticator() + + future = client.SendPacket(req) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + future, + return_exceptions=True + ) + + )) + + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + + if reply.code == AccountingResponse: + print("Accounting accepted") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + print('END') + + del client + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + + +#test_acct1() +test_acct1(enable_message_authenticator=True) diff --git a/example/auth.py b/example/auth.py index 3c649878..6d14dba9 100755 --- a/example/auth.py +++ b/example/auth.py @@ -6,9 +6,11 @@ import sys import pyrad.packet -srv = Client(server="localhost", secret=b"Kah3choteereethiejeimaeziecumi", dict=Dictionary("dictionary")) +srv = Client(server="localhost", secret=b"Kah3choteereethiejeimaeziecumi", + dict=Dictionary("dictionary")) -req = srv.CreateAuthPacket(code=pyrad.packet.AccessRequest, User_Name="wichert") +req = srv.CreateAuthPacket(code=pyrad.packet.AccessRequest, + User_Name="wichert") req["NAS-IP-Address"] = "192.168.1.10" req["NAS-Port"] = 0 diff --git a/example/auth_async.py b/example/auth_async.py index 9ce4a419..89a01981 100644 --- a/example/auth_async.py +++ b/example/auth_async.py @@ -10,9 +10,9 @@ logging.basicConfig(level="DEBUG", format="%(asctime)s [%(levelname)-8s] %(message)s") -client = ClientAsync(server="localhost", +client = ClientAsync(server="127.0.0.1", secret=b"Kah3choteereethiejeimaeziecumi", - timeout=4, + timeout=3, debug=True, dict=Dictionary("dictionary")) loop = asyncio.get_event_loop() @@ -31,6 +31,7 @@ def create_request(client, user): return req + def print_reply(reply): if reply.code == AccessAccept: print("Access accepted") @@ -41,6 +42,7 @@ def print_reply(reply): for i in reply.keys(): print("%s: %s" % (i, reply[i])) + def test_auth1(): global client @@ -50,13 +52,11 @@ def test_auth1(): loop.run_until_complete( asyncio.ensure_future( client.initialize_transports(enable_auth=True, - local_addr='127.0.0.1', - local_auth_port=8000, + # local_addr='127.0.0.1', + # local_auth_port=8000, enable_acct=True, enable_coa=True))) - - req = client.CreateAuthPacket(User_Name="wichert") req["NAS-IP-Address"] = "192.168.1.10" @@ -107,6 +107,7 @@ def test_auth1(): loop.close() + def test_multi_auth(): global client @@ -117,15 +118,14 @@ def test_multi_auth(): asyncio.ensure_future( client.initialize_transports(enable_auth=True, local_addr='127.0.0.1', - local_auth_port=8000, + # local_auth_port=8000, enable_acct=True, enable_coa=True))) - - reqs = [] - for i in range(255): + for i in range(150): req = create_request(client, "user%s" % i) + print('CREATE REQUEST with id %d' % req.id) future = client.SendPacket(req) reqs.append(future) @@ -145,6 +145,7 @@ def test_multi_auth(): reply = future.result() print_reply(reply) + print('INVALID RESPONSE:', client.protocol_auth.errors) # Close transports loop.run_until_complete(asyncio.ensure_future( client.deinitialize_transports())) @@ -160,5 +161,143 @@ def test_multi_auth(): loop.close() -#test_multi_auth() -test_auth1() + +def test_multi_client(): + + clients = [] + n_clients = 73 + n_req4client = 50 + reqs = [] + + global loop + + try: + for i in range(n_clients): + client = ClientAsync(server="localhost", + secret=b"Kah3choteereethiejeimaeziecumi", + timeout=4, debug=True, + dict=Dictionary("dictionary"), + loop=loop) + + clients.append(client) + + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + enable_acct=False, + enable_coa=False))) + + # Send + for j in range(n_req4client): + req = create_request(client, "user%s" % j) + print('CREATE REQUEST with id %d' % req.id) + future = client.SendPacket(req) + reqs.append(future) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + *reqs, + return_exceptions=True + ) + + )) + + for future in reqs: + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + print_reply(reply) + + client = clients.pop() + while client: + + print('INVALID RESPONSE:', client.protocol_auth.errors) + print('RETRIES:', client.protocol_auth.retries_counter) + + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + del client + if len(clients) > 0: + client = clients.pop() + else: + client = None + + print('END') + except Exception as exc: + + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + + for client in clients: + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + + +def test_auth1_msg_authenticator(): + global client + + try: + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + # local_addr='127.0.0.1', + # local_auth_port=8000, + enable_acct=True, + enable_coa=True))) + + req = create_request(client, "wichert") + req.add_message_authenticator() + + future = client.SendPacket(req) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + future, + return_exceptions=True + ) + + )) + + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + + if reply.code == AccessAccept: + print("Access accepted") + else: + print("Access denied") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + print('END') + + del client + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + + +# test_multi_auth() +# test_auth1() +# test_multi_client() +test_auth1_msg_authenticator() diff --git a/example/pyrad.log b/example/pyrad.log deleted file mode 100644 index e69de29b..00000000 diff --git a/example/server_async.py b/example/server_async.py index 3b893dab..09992191 100644 --- a/example/server_async.py +++ b/example/server_async.py @@ -10,20 +10,22 @@ from pyrad.server import RemoteHost try: + # If available i try to use uvloop import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: pass -logging.basicConfig(level="DEBUG", +logging.basicConfig(level="INFO", format="%(asctime)s [%(levelname)-8s] %(message)s") class FakeServer(ServerAsync): - def __init__(self, loop, dictionary): + def __init__(self, loop, dictionary, enable_message_authenticator=False): ServerAsync.__init__(self, loop=loop, dictionary=dictionary, enable_pkt_verify=True, debug=True) + self.enable_message_authenticator = enable_message_authenticator def handle_auth_packet(self, protocol, pkt, addr): @@ -42,6 +44,10 @@ def handle_auth_packet(self, protocol, pkt, addr): }) reply.code = AccessAccept + + if self.enable_message_authenticator and pkt.message_authenticator: + reply.add_message_authenticator() + protocol.send_response(reply, addr) def handle_acct_packet(self, protocol, pkt, addr): @@ -52,6 +58,9 @@ def handle_acct_packet(self, protocol, pkt, addr): print("%s: %s" % (attr, pkt[attr])) reply = self.CreateReplyPacket(pkt) + + if self.enable_message_authenticator and pkt.message_authenticator: + reply.add_message_authenticator() protocol.send_response(reply, addr) def handle_coa_packet(self, protocol, pkt, addr): @@ -62,6 +71,8 @@ def handle_coa_packet(self, protocol, pkt, addr): print("%s: %s" % (attr, pkt[attr])) reply = self.CreateReplyPacket(pkt) + if self.enable_message_authenticator and pkt.message_authenticator: + reply.add_message_authenticator() protocol.send_response(reply, addr) def handle_disconnect_packet(self, protocol, pkt, addr): @@ -74,6 +85,9 @@ def handle_disconnect_packet(self, protocol, pkt, addr): reply = self.CreateReplyPacket(pkt) # COA NAK reply.code = 45 + + if self.enable_message_authenticator and pkt.message_authenticator: + reply.add_message_authenticator() protocol.send_response(reply, addr) @@ -81,7 +95,8 @@ def handle_disconnect_packet(self, protocol, pkt, addr): # create server and read dictionary loop = asyncio.get_event_loop() - server = FakeServer(loop=loop, dictionary=Dictionary('dictionary')) + server = FakeServer(loop=loop, dictionary=Dictionary('dictionary'), + enable_message_authenticator=True) # add clients (address, secret, name) server.hosts["127.0.0.1"] = RemoteHost("127.0.0.1", @@ -103,6 +118,8 @@ def handle_disconnect_packet(self, protocol, pkt, addr): except KeyboardInterrupt as k: pass + print('STATS', server.stats()) + # Close transports loop.run_until_complete(asyncio.ensure_future( server.deinitialize_transports())) diff --git a/pyrad/__init__.py b/pyrad/__init__.py index 0b45dd12..5cc47ebf 100644 --- a/pyrad/__init__.py +++ b/pyrad/__init__.py @@ -41,6 +41,6 @@ __author__ = 'Christian Giese ' __url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest' __copyright__ = 'Copyright 2002-2019 Wichert Akkerman and Christian Giese. All rights reserved.' -__version__ = '2.2' +__version__ = '2.3' __all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'dictfile'] diff --git a/pyrad/client_async.py b/pyrad/client_async.py index de08917c..b16d3ca9 100644 --- a/pyrad/client_async.py +++ b/pyrad/client_async.py @@ -9,6 +9,7 @@ import six import logging import random +import traceback from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket @@ -24,6 +25,8 @@ def __init__(self, server, port, logger, self.retries = retries self.timeout = timeout self.client = client + self.errors = 0 + self.retries_counter = 0 # Map of pending requests self.pending_requests = {} @@ -40,25 +43,43 @@ async def __timeout_handler__(self): while True: + socket = self.transport.get_extra_info('socket') \ + if self.transport else None req2delete = [] - now = datetime.now() next_weak_up = self.timeout # noinspection PyShadowingBuiltins for id, req in self.pending_requests.items(): - secs = (req['send_date'] - now).seconds + now = datetime.now() + secs = (now - req['sent_date']).seconds if secs > self.timeout: if req['retries'] == self.retries: - self.logger.debug('[%s:%d] For request %d execute all retries', self.server, self.port, id) + self.logger.debug( + '[%s:%d:%d] For request %d execute all retries' % ( + self.server, self.port, + socket.getsockname()[1] if socket else '', + id + ) + ) req['future'].set_exception( TimeoutError('Timeout on Reply') ) req2delete.append(id) else: # Send again packet - req['send_date'] = now + req['sent_date'] = now + req['packet'].sent_date = now req['retries'] += 1 - self.logger.debug('[%s:%d] For request %d execute retry %d', self.server, self.port, id, req['retries']) + self.retries_counter += 1 + self.logger.debug( + '[%s:%d:%d] For request %d reached %s secs. %s' % ( + self.server, self.port, + socket.getsockname()[1] if socket else '', + id, secs, + 'I execute retry %d.' % req['retries'] + ) + ) + self.transport.sendto(req['packet'].RequestPacket()) elif next_weak_up > secs: next_weak_up = secs @@ -77,15 +98,17 @@ def send_packet(self, packet, future): if packet.id in self.pending_requests: raise Exception('Packet with id %d already present' % packet.id) + sent_date = datetime.now() # Store packet on pending requests map self.pending_requests[packet.id] = { 'packet': packet, - 'creation_date': datetime.now(), + 'creation_date': sent_date, 'retries': 0, 'future': future, - 'send_date': datetime.now() + 'sent_date': sent_date } + packet.sent_date = sent_date # In queue packet raw on socket buffer self.transport.sendto(packet.RequestPacket()) @@ -94,9 +117,9 @@ def connection_made(self, transport): socket = transport.get_extra_info('socket') self.logger.info( '[%s:%d] Transport created with binding in %s:%d', - self.server, self.port, - socket.getsockname()[0], - socket.getsockname()[1] + self.server, self.port, + socket.getsockname()[0], + socket.getsockname()[1] ) pre_loop = asyncio.get_event_loop() @@ -118,31 +141,83 @@ def connection_lost(self, exc): # noinspection PyUnusedLocal def datagram_received(self, data, addr): + + socket = self.transport.get_extra_info('socket') \ + if self.transport else None try: - reply = Packet(packet=data, dict=self.client.dict) - if reply and reply.id in self.pending_requests: + received_date = datetime.now() + reply = Packet(packet=data, dict=self.client.dict, + creation_date=received_date) + + if reply is not None and reply.id in self.pending_requests: req = self.pending_requests[reply.id] packet = req['packet'] - reply.dict = packet.dict reply.secret = packet.secret if packet.VerifyReply(reply, data): - req['future'].set_result(reply) - # Remove request for map - del self.pending_requests[reply.id] + + if reply.message_authenticator and not \ + reply.verify_message_authenticator( + original_authenticator=packet.authenticator): + self.logger.warn( + '[%s:%d:%d] Received invalid reply for id %d. %s' % ( + self.server, self.port, + socket.getsockname()[1] if socket else '', + reply.id, + 'Invalid Message-Authenticator. Ignoring it.' + ) + ) + self.errors += 1 + else: + + req['future'].set_result(reply) + # Remove request for map + del self.pending_requests[reply.id] else: - self.logger.warn('[%s:%d] Ignore invalid reply for id %d. %s', self.server, self.port, reply.id) + self.logger.warn( + '[%s:%d:%d] Received invalid reply for id %d. %s' % ( + self.server, self.port, + socket.getsockname()[1] if socket else '', + reply.id, + 'Ignoring it.' + ) + ) + self.errors += 1 else: - self.logger.warn('[%s:%d] Ignore invalid reply: %d', self.server, self.port, data) + self.logger.warn( + '[%s:%d:%d] Received invalid reply with id %d: %s.\nIgnoring it.' % ( + self.server, self.port, + socket.getsockname()[1] if socket else '', + (-1, reply.id)[reply is not None], + data.hex(), + ) + ) + self.errors += 1 except Exception as exc: - self.logger.error('[%s:%d] Error on decode packet: %s', self.server, self.port, exc) + self.logger.error( + '[%s:%d:%d] Error on decode packet: %s.' % ( + self.server, self.port, + socket.getsockname()[1] if socket else '', + (exc, '\n'.join(traceback.format_exc().splitlines()))[ + self.client.debug + ] + ) + ) async def close_transport(self): if self.transport: - self.logger.debug('[%s:%d] Closing transport...', self.server, self.port) + + socket = self.transport.get_extra_info('socket') \ + if self.transport else None + self.logger.debug( + '[%s:%d:%d] Closing transport...' % ( + self.server, self.port, + socket.getsockname()[1] if socket else '' + ) + ) self.transport.close() self.transport = None if self.timeout_future: @@ -177,7 +252,7 @@ class ClientAsync: def __init__(self, server, auth_port=1812, acct_port=1813, coa_port=3799, secret=six.b(''), dict=None, loop=None, retries=3, timeout=30, - logger_name='pyrad'): + logger_name='pyrad', debug=False): """Constructor. @@ -217,10 +292,13 @@ def __init__(self, server, auth_port=1812, acct_port=1813, self.protocol_coa = None self.coa_port = coa_port + self.debug = debug + async def initialize_transports(self, enable_acct=False, enable_auth=False, enable_coa=False, local_addr=None, local_auth_port=None, - local_acct_port=None, local_coa_port=None): + local_acct_port=None, local_coa_port=None, + reuse_address=True, reuse_port=True): task_list = [] @@ -241,7 +319,7 @@ async def initialize_transports(self, enable_acct=False, acct_connect = self.loop.create_datagram_endpoint( self.protocol_acct, - reuse_address=True, reuse_port=True, + reuse_port=reuse_port, remote_addr=(self.server, self.acct_port), local_addr=bind_addr ) @@ -261,7 +339,7 @@ async def initialize_transports(self, enable_acct=False, auth_connect = self.loop.create_datagram_endpoint( self.protocol_auth, - reuse_address=True, reuse_port=True, + reuse_port=True, remote_addr=(self.server, self.auth_port), local_addr=bind_addr ) @@ -398,9 +476,14 @@ def SendPacket(self, pkt): if not self.protocol_acct: raise Exception('Transport not initialized') + self.protocol_acct.send_packet(pkt, ans) + elif isinstance(pkt, CoAPacket): if not self.protocol_coa: raise Exception('Transport not initialized') + + self.protocol_coa.send_packet(pkt, ans) + else: raise Exception('Unsupported packet') diff --git a/pyrad/packet.py b/pyrad/packet.py index 82902aff..d61bba99 100644 --- a/pyrad/packet.py +++ b/pyrad/packet.py @@ -5,10 +5,18 @@ # A RADIUS packet as defined in RFC 2138 from collections import OrderedDict +from datetime import datetime import struct import random # Hmac needed for Message-Authenticator import hmac + +import sys +if sys.version_info >= (3, 0): + hmac_new = lambda *x, **y: hmac.new(*x, digestmod='MD5', **y) +else: + hmac_new = hmac.new + try: import hashlib md5_constructor = hashlib.md5 @@ -63,6 +71,7 @@ class Packet(OrderedDict): """ def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, + creation_date=datetime.utcnow(), **attributes): """Constructor @@ -91,6 +100,7 @@ def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, raise TypeError('authenticator must be a binary string') self.authenticator = authenticator self.message_authenticator = None + self.creation_date = creation_date if 'dict' in attributes: self.dict = attributes['dict'] @@ -100,11 +110,14 @@ def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, if 'message_authenticator' in attributes: self.message_authenticator = attributes['message_authenticator'] + if 'creation_date' in attributes: + self.creation_date = attributes['creation_date'] + self.sent_date = None for (key, value) in attributes.items(): if key in [ 'dict', 'fd', 'packet', - 'message_authenticator', + 'message_authenticator', 'creation_date' ]: continue key = key.replace('_', '-') @@ -128,7 +141,7 @@ def get_message_authenticator(self): return self.message_authenticator def _refresh_message_authenticator(self): - hmac_constructor = hmac.new(self.secret) + hmac_constructor = hmac_new(self.secret) # Maintain a zero octets content for md5 and hmac calculation. self['Message-Authenticator'] = 16 * six.b('\00') @@ -182,7 +195,7 @@ def verify_message_authenticator(self, secret=None, header = struct.pack('!BBH', self.code, self.id, (20 + len(attr))) - hmac_constructor = hmac.new(key) + hmac_constructor = hmac_new(key) hmac_constructor.update(header) if self.code in (AccountingRequest, DisconnectRequest, CoARequest, AccountingResponse): @@ -329,7 +342,7 @@ def __delitem__(self, key): def __setitem__(self, key, item): if isinstance(key, six.string_types): - (key, item) = self._EncodeKeyValues(key, item) + (key, item) = self._EncodeKeyValues(key, [item]) OrderedDict.__setitem__(self, key, item) else: OrderedDict.__setitem__(self, key, item) @@ -561,9 +574,20 @@ def SaltCrypt(self, value): if self.authenticator is None: # self.authenticator = self.CreateAuthenticator() self.authenticator = 16 * six.b('\x00') + if six.PY3: + random_value = 32768 + random_generator.randrange(0, 32767) + salt_raw = struct.pack('!H', random_value ) + salt_str = chr(salt_raw[0]) + chr(salt_raw[0]) + salt = six.b(salt_str) + result = salt + else: + random_value = random_generator.randrange(0, 65535) + salt = struct.pack('!H', random_value ) + salt = chr(ord(salt[0]) | 1 << 7)+salt[1] + result = six.b(salt) - salt = struct.pack('!H', random_generator.randrange(0, 65535)) - salt = chr(ord(salt[0]) | 1 << 7)+salt[1] + #salt = struct.pack('!H', random_generator.randrange(0, 65535)) + #salt = chr(ord(salt[0]) | 1 << 7)+salt[1] length = struct.pack("B", len(value)) buf = length + value @@ -734,24 +758,21 @@ def VerifyChapPasswd(self, userpwd): chapid = chap_password[0] if six.PY3: - chapid = chr(chapid).encode('utf-8') + chapid = six.b(str(chr(chapid))) password = chap_password[1:] challenge = self.authenticator if 'CHAP-Challenge' in self: challenge = self['CHAP-Challenge'][0] - return password == md5_constructor(chapid + userpwd + challenge).digest() - def VerifyAuthRequest(self): - """Verify request authenticator. + c = "%s%s%s" % (chapid, userpwd, challenge) + md5 = md5_constructor( + chapid + + userpwd + + challenge + ).digest() - :return: True if verification failed else False - :rtype: boolean - """ - assert(self.raw_packet) - hash = md5_constructor(self.raw_packet[0:4] + 16 * six.b('\x00') + - self.raw_packet[20:] + self.secret).digest() - return hash == self.authenticator + return password == md5 class AcctPacket(Packet): @@ -820,10 +841,7 @@ def RequestPacket(self): self.authenticator = md5_constructor(header[0:4] + 16 * six.b('\x00') + attr + self.secret).digest() - ans = header + self.authenticator + attr - - return ans - + return header + self.authenticator + attr class CoAPacket(Packet): """RADIUS CoA packets. This class is a specialization diff --git a/pyrad/server_async.py b/pyrad/server_async.py index 070754d6..a054750d 100644 --- a/pyrad/server_async.py +++ b/pyrad/server_async.py @@ -36,6 +36,7 @@ def __init__(self, ip, port, logger, server, server_type, hosts, self.hosts = hosts self.server_type = server_type self.request_callback = request_callback + self.requests = 0 def connection_made(self, transport): self.transport = transport @@ -48,73 +49,141 @@ def connection_lost(self, exc): self.logger.info('[%s:%d] Transport closed', self.ip, self.port) def send_response(self, reply, addr): + if self.server.debug: + self.logger.info( + '[%s:%d] Sending Response to %s packet: %s' % ( + self.ip, self.port, addr, reply.ReplyPacket().hex() + ) + ) self.transport.sendto(reply.ReplyPacket(), addr) + def __get_remote_host__(self, addr): + ans = None + if addr in self.hosts.keys(): + ans = self.hosts[addr] + return ans + def datagram_received(self, data, addr): - self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, self.port, len(data), addr) + self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, + self.port, len(data), addr) receive_date = datetime.utcnow() - if addr[0] in self.hosts: - remote_host = self.hosts[addr[0]] - elif '0.0.0.0' in self.hosts: - remote_host = self.hosts['0.0.0.0'].secret - else: - self.logger.warn('[%s:%d] Drop package from unknown source %s', self.ip, self.port, addr) - return + remote_host = self.__get_remote_host__(addr[0]) - try: - self.logger.debug('[%s:%d] Received from %s packet: %s', self.ip, self.port, addr, data.hex()) - req = Packet(packet=data, dict=self.server.dict) - except Exception as exc: - self.logger.error('[%s:%d] Error on decode packet: %s', self.ip, self.port, exc) - return + if remote_host: - try: - if req.code in (AccountingResponse, AccessAccept, AccessReject, CoANAK, CoAACK, DisconnectNAK, DisconnectACK): - raise ServerPacketError('Invalid response packet %d' % req.code) - - elif self.server_type == ServerType.Auth: - if req.code != AccessRequest: - raise ServerPacketError('Received non-auth packet on auth port') - req = AuthPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) - if self.server.enable_pkt_verify: - if req.VerifyAuthRequest(): - raise PacketError('Packet verification failed') - - elif self.server_type == ServerType.Coa: - if req.code != DisconnectRequest and req.code != CoARequest: - raise ServerPacketError('Received non-coa packet on coa port') - req = CoAPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) - if self.server.enable_pkt_verify: - if req.VerifyCoARequest(): - raise PacketError('Packet verification failed') - - elif self.server_type == ServerType.Acct: - - if req.code != AccountingRequest: - raise ServerPacketError('Received non-acct packet on acct port') - req = AcctPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) - if self.server.enable_pkt_verify: - if req.VerifyAcctRequest(): - raise PacketError('Packet verification failed') - - # Call request callback - self.request_callback(self, req, addr) - except Exception as exc: - if self.server.debug: - self.logger.exception('[%s:%d] Error for packet from %s', self.ip, self.port, addr) - else: - self.logger.error('[%s:%d] Error for packet from %s: %s', self.ip, self.port, addr, exc) + try: + if self.server.debug: + self.logger.info( + '[%s:%d] Received from %s packet: %s.' % ( + self.ip, self.port, addr, data.hex() + ) + ) + req = Packet(packet=data, dict=self.server.dict) + + except Exception as exc: + self.logger.error( + '[%s:%d] Error on decode packet: %s. Ignore it.' % ( + self.ip, self.port, exc + ) + ) + req = None + + if not req: + return + + try: + if req.code in ( + AccountingResponse, + AccessAccept, + AccessReject, + CoANAK, + CoAACK, + DisconnectNAK, + DisconnectACK): + raise ServerPacketError('Invalid response packet %d' % + req.code) + + elif self.server_type == ServerType.Auth: + + if req.code != AccessRequest: + raise ServerPacketError( + 'Received not-authentication packet ' + 'on authentication port') + req = AuthPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + + if self.server.enable_pkt_verify and \ + req.message_authenticator and \ + not req.verify_message_authenticator(): + raise PacketError( + 'Received invalid Message-Authenticator' + ) + + elif self.server_type == ServerType.Coa: + + if req.code != DisconnectRequest and \ + req.code != CoARequest: + raise ServerPacketError( + 'Received not-coa packet on coa port' + ) + req = CoAPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + if self.server.enable_pkt_verify: + if not req.VerifyCoARequest(): + raise PacketError('Packet verification failed') + if req.message_authenticator and \ + not req.verify_message_authenticator(): + raise PacketError( + 'Received invalid Message-Authenticator' + ) + + elif self.server_type == ServerType.Acct: + + if req.code != AccountingRequest: + raise ServerPacketError( + 'Received not-accounting packet on ' + 'accounting port' + ) + req = AcctPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + + if self.server.enable_pkt_verify: + if not req.VerifyAcctRequest(): + raise PacketError('Packet verification failed') + if req.message_authenticator and not \ + req.verify_message_authenticator(): + raise PacketError( + 'Received invalid Message-Authenticator' + ) + + # Call request callback + self.request_callback(self, req, addr) + + self.requests += 1 + + except Exception as e: + self.logger.error( + '[%s:%d] Unexpected error for packet from %s: %s' % ( + self.ip, self.port, addr, + (e, '\n'.join(traceback.format_exc().splitlines()))[ + self.server.debug + ] + ) + ) + + else: + self.logger.error('[%s:%d] Drop package from unknown source %s', + self.ip, self.port, addr) process_date = datetime.utcnow() - self.logger.debug('[%s:%d] Request from %s processed in %d ms', self.ip, self.port, addr, (process_date-receive_date).microseconds/1000) + self.logger.debug('[%s:%d] Request from %s processed in %d ms', + self.ip, self.port, addr, + (process_date-receive_date).microseconds/1000) def error_received(self, exc): self.logger.error('[%s:%d] Error received: %s', self.ip, self.port, exc) @@ -218,7 +287,8 @@ def CreateReplyPacket(pkt, **attributes): async def initialize_transports(self, enable_acct=False, enable_auth=False, enable_coa=False, - addresses=None): + addresses=None, reuse_address=True, + reuse_port=True): task_list = [] @@ -243,7 +313,7 @@ async def initialize_transports(self, enable_acct=False, bind_addr = (addr, self.acct_port) acct_connect = self.loop.create_datagram_endpoint( protocol_acct, - reuse_address=True, reuse_port=True, + reuse_address=reuse_address, reuse_port=reuse_port, local_addr=bind_addr ) self.acct_protocols.append(protocol_acct) @@ -262,7 +332,7 @@ async def initialize_transports(self, enable_acct=False, auth_connect = self.loop.create_datagram_endpoint( protocol_auth, - reuse_address=True, reuse_port=True, + reuse_address=reuse_address, reuse_port=reuse_port, local_addr=bind_addr ) self.auth_protocols.append(protocol_auth) @@ -281,7 +351,7 @@ async def initialize_transports(self, enable_acct=False, coa_connect = self.loop.create_datagram_endpoint( protocol_coa, - reuse_address=True, reuse_port=True, + reuse_address=reuse_address, reuse_port=reuse_port, local_addr=bind_addr ) self.coa_protocols.append(protocol_coa) @@ -295,6 +365,18 @@ async def initialize_transports(self, enable_acct=False, loop=self.loop ) + def stats(self): + ans = {} + + for proto in self.coa_protocols: + ans['%s-%s' % (proto.ip, proto.port)] = proto.requests + for proto in self.auth_protocols: + ans['%s-%s' % (proto.ip, proto.port)] = proto.requests + for proto in self.acct_protocols: + ans['%s-%s' % (proto.ip, proto.port)] = proto.requests + + return ans + # noinspection SpellCheckingInspection async def deinitialize_transports(self, deinit_coa=True, deinit_auth=True, deinit_acct=True):