Skip to content

Commit f3d19ff

Browse files
Support plugins defined as inner classes (#1318)
* Support plugins defined as inner classes * Prefer __qualname__ over __name__ for classes --------- Co-authored-by: Abhinav Singh <[email protected]>
1 parent 93f6fd6 commit f3d19ff

File tree

16 files changed

+122
-35
lines changed

16 files changed

+122
-35
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support plugins defined as inner classes

proxy/common/plugins.py

+38-14
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import logging
1414
import importlib
1515
import itertools
16+
from types import ModuleType
1617
from typing import Any, Dict, List, Tuple, Union, Optional
1718

1819
from .utils import text_, bytes_
@@ -75,31 +76,54 @@ def load(
7576
# this plugin_ is implementing
7677
base_klass = None
7778
for k in mro:
78-
if bytes_(k.__name__) in p:
79+
if bytes_(k.__qualname__) in p:
7980
base_klass = k
8081
break
8182
if base_klass is None:
8283
raise ValueError('%s is NOT a valid plugin' % text_(plugin_))
83-
if klass not in p[bytes_(base_klass.__name__)]:
84-
p[bytes_(base_klass.__name__)].append(klass)
85-
logger.info('Loaded plugin %s.%s', module_name, klass.__name__)
84+
if klass not in p[bytes_(base_klass.__qualname__)]:
85+
p[bytes_(base_klass.__qualname__)].append(klass)
86+
logger.info('Loaded plugin %s.%s', module_name, klass.__qualname__)
8687
# print(p)
8788
return p
8889

8990
@staticmethod
9091
def importer(plugin: Union[bytes, type]) -> Tuple[type, str]:
9192
"""Import and returns the plugin."""
9293
if isinstance(plugin, type):
93-
return (plugin, '__main__')
94+
if inspect.isclass(plugin):
95+
return (plugin, plugin.__module__ or '__main__')
96+
raise ValueError('%s is not a valid reference to a plugin class' % text_(plugin))
9497
plugin_ = text_(plugin.strip())
9598
assert plugin_ != ''
96-
module_name, klass_name = plugin_.rsplit(text_(DOT), 1)
97-
klass = getattr(
98-
importlib.import_module(
99-
module_name.replace(
100-
os.path.sep, text_(DOT),
101-
),
102-
),
103-
klass_name,
104-
)
99+
path = plugin_.split(text_(DOT))
100+
klass = None
101+
102+
def locate_klass(klass_module_name: str, klass_path: List[str]) -> Union[type, None]:
103+
klass_module_name = klass_module_name.replace(os.path.sep, text_(DOT))
104+
try:
105+
klass_module = importlib.import_module(klass_module_name)
106+
except ModuleNotFoundError:
107+
return None
108+
klass_container: Union[ModuleType, type] = klass_module
109+
for klass_path_part in klass_path:
110+
try:
111+
klass_container = getattr(klass_container, klass_path_part)
112+
except AttributeError:
113+
return None
114+
if not isinstance(klass_container, type) or not inspect.isclass(klass_container):
115+
return None
116+
return klass_container
117+
118+
module_name = None
119+
for module_name_parts in range(len(path) - 1, 0, -1):
120+
module_name = '.'.join(path[0:module_name_parts])
121+
klass = locate_klass(module_name, path[module_name_parts:])
122+
if klass:
123+
break
124+
if klass is None:
125+
module_name = '__main__'
126+
klass = locate_klass(module_name, path)
127+
if klass is None or module_name is None:
128+
raise ValueError('%s is not resolvable as a plugin class' % text_(plugin))
105129
return (klass, module_name)

proxy/core/acceptor/acceptor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def _work(self, conn: socket.socket, addr: Optional[HostPort]) -> None:
246246
conn,
247247
addr,
248248
event_queue=self.event_queue,
249-
publisher_id=self.__class__.__name__,
249+
publisher_id=self.__class__.__qualname__,
250250
)
251251
# TODO: Move me into target method
252252
logger.debug( # pragma: no cover

proxy/core/work/fd/fd.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def work(self, *args: Any) -> None:
3939
self.works[fileno].publish_event(
4040
event_name=eventNames.WORK_STARTED,
4141
event_payload={'fileno': fileno, 'addr': addr},
42-
publisher_id=self.__class__.__name__,
42+
publisher_id=self.__class__.__qualname__,
4343
)
4444
try:
4545
self.works[fileno].initialize()

proxy/core/work/work.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def shutdown(self) -> None:
8383
self.publish_event(
8484
event_name=eventNames.WORK_FINISHED,
8585
event_payload={},
86-
publisher_id=self.__class__.__name__,
86+
publisher_id=self.__class__.__qualname__,
8787
)
8888

8989
def run(self) -> None:

proxy/http/exception/http_request_rejected.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
self.reason: Optional[bytes] = reason
3737
self.headers: Optional[Dict[bytes, bytes]] = headers
3838
self.body: Optional[bytes] = body
39-
klass_name = self.__class__.__name__
39+
klass_name = self.__class__.__qualname__
4040
super().__init__(
4141
message='%s %r' % (klass_name, reason)
4242
if reason

proxy/http/exception/proxy_auth_failed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ProxyAuthenticationFailed(HttpProtocolException):
2828
incoming request doesn't present necessary credentials."""
2929

3030
def __init__(self, **kwargs: Any) -> None:
31-
super().__init__(self.__class__.__name__, **kwargs)
31+
super().__init__(self.__class__.__qualname__, **kwargs)
3232

3333
def response(self, _request: 'HttpParser') -> memoryview:
3434
return PROXY_AUTH_FAILED_RESPONSE_PKT

proxy/http/exception/proxy_conn_failed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, host: str, port: int, reason: str, **kwargs: Any):
2929
self.host: str = host
3030
self.port: int = port
3131
self.reason: str = reason
32-
super().__init__('%s %s' % (self.__class__.__name__, reason), **kwargs)
32+
super().__init__('%s %s' % (self.__class__.__qualname__, reason), **kwargs)
3333

3434
def response(self, _request: 'HttpParser') -> memoryview:
3535
return BAD_GATEWAY_RESPONSE_PKT

proxy/http/proxy/plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def name(self) -> str:
5151
5252
Defaults to name of the class. This helps plugin developers to directly
5353
access a specific plugin by its name."""
54-
return self.__class__.__name__ # pragma: no cover
54+
return self.__class__.__qualname__ # pragma: no cover
5555

5656
def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['HostPort']]:
5757
"""Resolve upstream server host to an IP address.

proxy/http/proxy/server.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ def emit_request_complete(self) -> None:
883883
if self.request.method == httpMethods.POST
884884
else None,
885885
},
886-
publisher_id=self.__class__.__name__,
886+
publisher_id=self.__class__.__qualname__,
887887
)
888888

889889
def emit_response_events(self, chunk_size: int) -> None:
@@ -911,7 +911,7 @@ def emit_response_headers_complete(self) -> None:
911911
for k, v in self.response.headers.items()
912912
},
913913
},
914-
publisher_id=self.__class__.__name__,
914+
publisher_id=self.__class__.__qualname__,
915915
)
916916

917917
def emit_response_chunk_received(self, chunk_size: int) -> None:
@@ -925,7 +925,7 @@ def emit_response_chunk_received(self, chunk_size: int) -> None:
925925
'chunk_size': chunk_size,
926926
'encoded_chunk_size': chunk_size,
927927
},
928-
publisher_id=self.__class__.__name__,
928+
publisher_id=self.__class__.__qualname__,
929929
)
930930

931931
def emit_response_complete(self) -> None:
@@ -938,7 +938,7 @@ def emit_response_complete(self) -> None:
938938
event_payload={
939939
'encoded_response_size': self.response.total_size,
940940
},
941-
publisher_id=self.__class__.__name__,
941+
publisher_id=self.__class__.__qualname__,
942942
)
943943

944944
#

proxy/http/server/plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def name(self) -> str:
7272
7373
Defaults to name of the class. This helps plugin developers to directly
7474
access a specific plugin by its name."""
75-
return self.__class__.__name__ # pragma: no cover
75+
return self.__class__.__qualname__ # pragma: no cover
7676

7777
@abstractmethod
7878
def routes(self) -> List[Tuple[int, str]]:

tests/common/my_plugins/__init__.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
proxy.py
4+
~~~~~~~~
5+
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
6+
Network monitoring, controls & Application development, testing, debugging.
7+
8+
:copyright: (c) 2013-present by Abhinav Singh and contributors.
9+
:license: BSD, see LICENSE for more details.
10+
"""
11+
from typing import Any
12+
13+
from proxy.http.proxy import HttpProxyPlugin
14+
15+
16+
class MyHttpProxyPlugin(HttpProxyPlugin):
17+
def __init__(self, *args: Any, **kwargs: Any) -> None:
18+
super().__init__(*args, **kwargs)
19+
20+
21+
class OuterClass:
22+
23+
class MyHttpProxyPlugin(HttpProxyPlugin):
24+
def __init__(self, *args: Any, **kwargs: Any) -> None:
25+
super().__init__(*args, **kwargs)

tests/common/test_flags.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11-
from typing import Dict, List
11+
from typing import Any, Dict, List
1212

1313
import unittest
1414
from unittest import mock
@@ -19,6 +19,7 @@
1919
from proxy.common.utils import bytes_
2020
from proxy.common.version import __version__
2121
from proxy.common.constants import PLUGIN_HTTP_PROXY, PY2_DEPRECATION_MESSAGE
22+
from . import my_plugins
2223

2324

2425
class TestFlags(unittest.TestCase):
@@ -140,6 +141,42 @@ def test_unique_plugin_from_class(self) -> None:
140141
],
141142
})
142143

144+
def test_plugin_from_inner_class_by_type(self) -> None:
145+
self.flags = FlagParser.initialize(
146+
[], plugins=[
147+
TestFlags.MyHttpProxyPlugin,
148+
my_plugins.MyHttpProxyPlugin,
149+
my_plugins.OuterClass.MyHttpProxyPlugin,
150+
],
151+
)
152+
self.assert_plugins({
153+
'HttpProtocolHandlerPlugin': [
154+
TestFlags.MyHttpProxyPlugin,
155+
my_plugins.MyHttpProxyPlugin,
156+
my_plugins.OuterClass.MyHttpProxyPlugin,
157+
],
158+
})
159+
160+
def test_plugin_from_inner_class_by_name(self) -> None:
161+
self.flags = FlagParser.initialize(
162+
[], plugins=[
163+
b'tests.common.test_flags.TestFlags.MyHttpProxyPlugin',
164+
b'tests.common.my_plugins.MyHttpProxyPlugin',
165+
b'tests.common.my_plugins.OuterClass.MyHttpProxyPlugin',
166+
],
167+
)
168+
self.assert_plugins({
169+
'HttpProtocolHandlerPlugin': [
170+
TestFlags.MyHttpProxyPlugin,
171+
my_plugins.MyHttpProxyPlugin,
172+
my_plugins.OuterClass.MyHttpProxyPlugin,
173+
],
174+
})
175+
176+
class MyHttpProxyPlugin(HttpProxyPlugin):
177+
def __init__(self, *args: Any, **kwargs: Any) -> None:
178+
super().__init__(*args, **kwargs)
179+
143180
def test_basic_auth_flag_is_base64_encoded(self) -> None:
144181
flags = FlagParser.initialize(['--basic-auth', 'user:pass'])
145182
self.assertEqual(flags.auth_code, b'dXNlcjpwYXNz')

tests/core/test_event_dispatcher.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_empties_queue(self) -> None:
4040
request_id='1234',
4141
event_name=eventNames.WORK_STARTED,
4242
event_payload={'hello': 'events'},
43-
publisher_id=self.__class__.__name__,
43+
publisher_id=self.__class__.__qualname__,
4444
)
4545
self.dispatcher.run_once()
4646
with self.assertRaises(queue.Empty):
@@ -64,7 +64,7 @@ def subscribe(self, mock_time: mock.Mock) -> connection.Connection:
6464
request_id='1234',
6565
event_name=eventNames.WORK_STARTED,
6666
event_payload={'hello': 'events'},
67-
publisher_id=self.__class__.__name__,
67+
publisher_id=self.__class__.__qualname__,
6868
)
6969
# consume
7070
self.dispatcher.run_once()
@@ -79,7 +79,7 @@ def subscribe(self, mock_time: mock.Mock) -> connection.Connection:
7979
'event_timestamp': 1234567,
8080
'event_name': eventNames.WORK_STARTED,
8181
'event_payload': {'hello': 'events'},
82-
'publisher_id': self.__class__.__name__,
82+
'publisher_id': self.__class__.__qualname__,
8383
},
8484
)
8585
return relay_recv
@@ -101,7 +101,7 @@ def test_unsubscribe(self) -> None:
101101
request_id='1234',
102102
event_name=eventNames.WORK_STARTED,
103103
event_payload={'hello': 'events'},
104-
publisher_id=self.__class__.__name__,
104+
publisher_id=self.__class__.__qualname__,
105105
)
106106
self.dispatcher.run_once()
107107
with self.assertRaises(EOFError):

tests/core/test_event_queue.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_publish(self, mock_time: mock.Mock) -> None:
3434
request_id='1234',
3535
event_name=eventNames.WORK_STARTED,
3636
event_payload={'hello': 'events'},
37-
publisher_id=self.__class__.__name__,
37+
publisher_id=self.__class__.__qualname__,
3838
)
3939
self.assertEqual(
4040
evq.queue.get(), {
@@ -44,7 +44,7 @@ def test_publish(self, mock_time: mock.Mock) -> None:
4444
'event_timestamp': 1234567,
4545
'event_name': eventNames.WORK_STARTED,
4646
'event_payload': {'hello': 'events'},
47-
'publisher_id': self.__class__.__name__,
47+
'publisher_id': self.__class__.__qualname__,
4848
},
4949
)
5050

tests/core/test_event_subscriber.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_event_subscriber(self, mock_time: mock.Mock) -> None:
5050
request_id='1234',
5151
event_name=eventNames.WORK_STARTED,
5252
event_payload={'hello': 'events'},
53-
publisher_id=self.__class__.__name__,
53+
publisher_id=self.__class__.__qualname__,
5454
)
5555
self.dispatcher.run_once()
5656
self.subscriber.unsubscribe()
@@ -69,6 +69,6 @@ def callback(self, ev: Dict[str, Any]) -> None:
6969
'event_timestamp': 1234567,
7070
'event_name': eventNames.WORK_STARTED,
7171
'event_payload': {'hello': 'events'},
72-
'publisher_id': self.__class__.__name__,
72+
'publisher_id': self.__class__.__qualname__,
7373
},
7474
)

0 commit comments

Comments
 (0)