|
13 | 13 | import logging
|
14 | 14 | import importlib
|
15 | 15 | import itertools
|
| 16 | +from types import ModuleType |
16 | 17 | from typing import Any, Dict, List, Tuple, Union, Optional
|
17 | 18 |
|
18 | 19 | from .utils import text_, bytes_
|
@@ -75,31 +76,54 @@ def load(
|
75 | 76 | # this plugin_ is implementing
|
76 | 77 | base_klass = None
|
77 | 78 | for k in mro:
|
78 |
| - if bytes_(k.__name__) in p: |
| 79 | + if bytes_(k.__qualname__) in p: |
79 | 80 | base_klass = k
|
80 | 81 | break
|
81 | 82 | if base_klass is None:
|
82 | 83 | raise ValueError('%s is NOT a valid plugin' % text_(plugin_))
|
83 |
| - if klass not in p[bytes_(base_klass.__name__)]: |
84 |
| - p[bytes_(base_klass.__name__)].append(klass) |
85 |
| - logger.info('Loaded plugin %s.%s', module_name, klass.__name__) |
| 84 | + if klass not in p[bytes_(base_klass.__qualname__)]: |
| 85 | + p[bytes_(base_klass.__qualname__)].append(klass) |
| 86 | + logger.info('Loaded plugin %s.%s', module_name, klass.__qualname__) |
86 | 87 | # print(p)
|
87 | 88 | return p
|
88 | 89 |
|
89 | 90 | @staticmethod
|
90 | 91 | def importer(plugin: Union[bytes, type]) -> Tuple[type, str]:
|
91 | 92 | """Import and returns the plugin."""
|
92 | 93 | if isinstance(plugin, type):
|
93 |
| - return (plugin, '__main__') |
| 94 | + if inspect.isclass(plugin): |
| 95 | + return (plugin, plugin.__module__ or '__main__') |
| 96 | + raise ValueError('%s is not a valid reference to a plugin class' % text_(plugin)) |
94 | 97 | plugin_ = text_(plugin.strip())
|
95 | 98 | assert plugin_ != ''
|
96 |
| - module_name, klass_name = plugin_.rsplit(text_(DOT), 1) |
97 |
| - klass = getattr( |
98 |
| - importlib.import_module( |
99 |
| - module_name.replace( |
100 |
| - os.path.sep, text_(DOT), |
101 |
| - ), |
102 |
| - ), |
103 |
| - klass_name, |
104 |
| - ) |
| 99 | + path = plugin_.split(text_(DOT)) |
| 100 | + klass = None |
| 101 | + |
| 102 | + def locate_klass(klass_module_name: str, klass_path: List[str]) -> Union[type, None]: |
| 103 | + klass_module_name = klass_module_name.replace(os.path.sep, text_(DOT)) |
| 104 | + try: |
| 105 | + klass_module = importlib.import_module(klass_module_name) |
| 106 | + except ModuleNotFoundError: |
| 107 | + return None |
| 108 | + klass_container: Union[ModuleType, type] = klass_module |
| 109 | + for klass_path_part in klass_path: |
| 110 | + try: |
| 111 | + klass_container = getattr(klass_container, klass_path_part) |
| 112 | + except AttributeError: |
| 113 | + return None |
| 114 | + if not isinstance(klass_container, type) or not inspect.isclass(klass_container): |
| 115 | + return None |
| 116 | + return klass_container |
| 117 | + |
| 118 | + module_name = None |
| 119 | + for module_name_parts in range(len(path) - 1, 0, -1): |
| 120 | + module_name = '.'.join(path[0:module_name_parts]) |
| 121 | + klass = locate_klass(module_name, path[module_name_parts:]) |
| 122 | + if klass: |
| 123 | + break |
| 124 | + if klass is None: |
| 125 | + module_name = '__main__' |
| 126 | + klass = locate_klass(module_name, path) |
| 127 | + if klass is None or module_name is None: |
| 128 | + raise ValueError('%s is not resolvable as a plugin class' % text_(plugin)) |
105 | 129 | return (klass, module_name)
|
0 commit comments