diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index c1b993fef60..2595ed35645 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -72,7 +72,11 @@ jobs:
 
       - name: Install R ${{ matrix.ver }} system dependencies
         if: matrix.os == 'ubuntu-22.04'
-        run: sudo apt-get update; sudo apt-get install -y libcurl4-openssl-dev qpdf libgit2-dev libharfbuzz-dev libfribidi-dev
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y \
+            libcurl4-openssl-dev qpdf libgit2-dev libharfbuzz-dev libfribidi-dev \
+            libfreetype6-dev libpng-dev libtiff5-dev libjpeg-dev libwebp-dev pkg-config
 
       - name: Install R ${{ matrix.ver }} Rlang dependencies
         run: |
diff --git a/metaflow/_vendor/__init__.py b/metaflow/_vendor/__init__.py
index ae7b11a6298..30011733fda 100644
--- a/metaflow/_vendor/__init__.py
+++ b/metaflow/_vendor/__init__.py
@@ -1,10 +1,10 @@
 """
-metaflow._vendor is for vendoring dependencies of metaflow. Files 
-inside of metaflow._vendor should be considered immutable and 
-should only be updated to versions from upstream. 
+metaflow._vendor is for vendoring dependencies of metaflow. Files
+inside of metaflow._vendor should be considered immutable and
+should only be updated to versions from upstream.
 
 This folder is generated by `python vendor.py`
 
-If you would like to debundle the vendored dependencies, please 
+If you would like to debundle the vendored dependencies, please
 reach out to the maintainers at chat.metaflow.org
 """
diff --git a/metaflow/_vendor/click/__init__.py b/metaflow/_vendor/click/__init__.py
index 2b6008f2dd4..a098e317ec4 100644
--- a/metaflow/_vendor/click/__init__.py
+++ b/metaflow/_vendor/click/__init__.py
@@ -4,6 +4,7 @@
 around a simple API that does not come with too much magic and is
 composable.
 """
+
 from .core import Argument
 from .core import BaseCommand
 from .core import Command
diff --git a/metaflow/_vendor/click/_compat.py b/metaflow/_vendor/click/_compat.py
index 60cb115bc50..7aec09977e6 100644
--- a/metaflow/_vendor/click/_compat.py
+++ b/metaflow/_vendor/click/_compat.py
@@ -270,7 +270,6 @@ def filename_to_ui(value):
             value = value.decode(get_filesystem_encoding(), "replace")
         return value
 
-
 else:
     import io
 
@@ -725,7 +724,6 @@ def get_winterm_size():
         ).srWindow
         return win.Right - win.Left, win.Bottom - win.Top
 
-
 else:
 
     def _get_argv_encoding():
diff --git a/metaflow/_vendor/click/_termui_impl.py b/metaflow/_vendor/click/_termui_impl.py
index 88bec37701c..cd9e3d016cd 100644
--- a/metaflow/_vendor/click/_termui_impl.py
+++ b/metaflow/_vendor/click/_termui_impl.py
@@ -459,7 +459,9 @@ def edit_file(self, filename):
             environ = None
         try:
             c = subprocess.Popen(
-                '{} "{}"'.format(editor, filename), env=environ, shell=True,
+                '{} "{}"'.format(editor, filename),
+                env=environ,
+                shell=True,
             )
             exit_code = c.wait()
             if exit_code != 0:
@@ -563,11 +565,11 @@ def _unquote_file(url):
 
 
 def _translate_ch_to_exc(ch):
-    if ch == u"\x03":
+    if ch == "\x03":
         raise KeyboardInterrupt()
-    if ch == u"\x04" and not WIN:  # Unix-like, Ctrl+D
+    if ch == "\x04" and not WIN:  # Unix-like, Ctrl+D
         raise EOFError()
-    if ch == u"\x1a" and WIN:  # Windows, Ctrl+Z
+    if ch == "\x1a" and WIN:  # Windows, Ctrl+Z
         raise EOFError()
 
 
@@ -614,14 +616,13 @@ def getchar(echo):
             func = msvcrt.getwch
 
         rv = func()
-        if rv in (u"\x00", u"\xe0"):
+        if rv in ("\x00", "\xe0"):
            # \x00 and \xe0 are control characters that indicate special key,
            # see above.
            rv += func()
        _translate_ch_to_exc(rv)
        return rv
 
-
 else:
     import tty
     import termios
diff --git a/metaflow/_vendor/click/core.py b/metaflow/_vendor/click/core.py
index f58bf26d2f9..2dd19bd82c6 100644
--- a/metaflow/_vendor/click/core.py
+++ b/metaflow/_vendor/click/core.py
@@ -1463,6 +1463,7 @@ class Parameter(object):
         parameter. The old callback format will still work, but it will
         raise a warning to give you a chance to migrate the code easier.
     """
+
     param_type_name = "parameter"
 
     def __init__(
diff --git a/metaflow/_vendor/click/globals.py b/metaflow/_vendor/click/globals.py
index 1649f9a0bfb..feac2e91f5d 100644
--- a/metaflow/_vendor/click/globals.py
+++ b/metaflow/_vendor/click/globals.py
@@ -36,7 +36,7 @@ def pop_context():
 
 
 def resolve_color_default(color=None):
-    """"Internal helper to get the default value of the color flag. If a
+    """ "Internal helper to get the default value of the color flag. If a
     value is passed it's returned unchanged, otherwise it's looked up from
     the current context.
     """
diff --git a/metaflow/_vendor/click/utils.py b/metaflow/_vendor/click/utils.py
index 79265e732d4..1423596c664 100644
--- a/metaflow/_vendor/click/utils.py
+++ b/metaflow/_vendor/click/utils.py
@@ -234,9 +234,9 @@ def echo(message=None, file=None, nl=True, err=False, color=None):
         message = text_type(message)
 
     if nl:
-        message = message or u""
+        message = message or ""
         if isinstance(message, text_type):
-            message += u"\n"
+            message += "\n"
         else:
             message += b"\n"
diff --git a/metaflow/_vendor/imghdr/__init__.py b/metaflow/_vendor/imghdr/__init__.py
index c448ffafe4d..f0739aba4e6 100644
--- a/metaflow/_vendor/imghdr/__init__.py
+++ b/metaflow/_vendor/imghdr/__init__.py
@@ -11,13 +11,15 @@
     f"{__name__} was removed in Python 3.13. "
     f"Please be aware that you are currently NOT using standard '{__name__}', "
     f"but instead a separately installed 'standard-{__name__}'.",
-    DeprecationWarning, stacklevel=2
+    DeprecationWarning,
+    stacklevel=2,
 )
 
-#-------------------------#
+# -------------------------#
 # Recognize image headers #
-#-------------------------#
+# -------------------------#
+
 
 def what(file, h=None):
     """Return the type of image contained in a file or byte stream."""
@@ -25,7 +27,7 @@ def what(file, h=None):
     try:
         if h is None:
             if isinstance(file, (str, PathLike)):
-                f = open(file, 'rb')
+                f = open(file, "rb")
                 h = f.read(32)
             else:
                 location = file.tell()
@@ -36,151 +38,181 @@ def what(file, h=None):
         if res:
             return res
     finally:
-        if f: f.close()
+        if f:
+            f.close()
     return None
 
 
-#---------------------------------#
+# ---------------------------------#
 # Subroutines per image file type #
-#---------------------------------#
+# ---------------------------------#
 
 tests = []
 
+
 def test_jpeg(h, f):
     """Test for JPEG data with JFIF or Exif markers; and raw JPEG."""
-    if h[6:10] in (b'JFIF', b'Exif'):
-        return 'jpeg'
-    elif h[:4] == b'\xff\xd8\xff\xdb':
-        return 'jpeg'
+    if h[6:10] in (b"JFIF", b"Exif"):
+        return "jpeg"
+    elif h[:4] == b"\xff\xd8\xff\xdb":
+        return "jpeg"
+
 
 tests.append(test_jpeg)
 
+
 def test_png(h, f):
     """Verify if the image is a PNG."""
-    if h.startswith(b'\211PNG\r\n\032\n'):
-        return 'png'
+    if h.startswith(b"\211PNG\r\n\032\n"):
+        return "png"
+
 
 tests.append(test_png)
 
+
 def test_gif(h, f):
     """Verify if the image is a GIF ('87 or '89 variants)."""
-    if h[:6] in (b'GIF87a', b'GIF89a'):
-        return 'gif'
+    if h[:6] in (b"GIF87a", b"GIF89a"):
+        return "gif"
+
 
 tests.append(test_gif)
 
+
 def test_tiff(h, f):
     """Verify if the image is a TIFF (can be in Motorola or Intel byte order)."""
-    if h[:2] in (b'MM', b'II'):
-        return 'tiff'
+    if h[:2] in (b"MM", b"II"):
+        return "tiff"
+
 
 tests.append(test_tiff)
 
+
 def test_rgb(h, f):
     """test for the SGI image library."""
-    if h.startswith(b'\001\332'):
-        return 'rgb'
+    if h.startswith(b"\001\332"):
+        return "rgb"
+
 
 tests.append(test_rgb)
 
+
 def test_pbm(h, f):
     """Verify if the image is a PBM (portable bitmap)."""
-    if len(h) >= 3 and \
-        h[0] == ord(b'P') and h[1] in b'14' and h[2] in b' \t\n\r':
-        return 'pbm'
+    if len(h) >= 3 and h[0] == ord(b"P") and h[1] in b"14" and h[2] in b" \t\n\r":
+        return "pbm"
+
 
 tests.append(test_pbm)
 
+
 def test_pgm(h, f):
     """Verify if the image is a PGM (portable graymap)."""
-    if len(h) >= 3 and \
-        h[0] == ord(b'P') and h[1] in b'25' and h[2] in b' \t\n\r':
-        return 'pgm'
+    if len(h) >= 3 and h[0] == ord(b"P") and h[1] in b"25" and h[2] in b" \t\n\r":
+        return "pgm"
+
 
 tests.append(test_pgm)
 
+
 def test_ppm(h, f):
     """Verify if the image is a PPM (portable pixmap)."""
-    if len(h) >= 3 and \
-        h[0] == ord(b'P') and h[1] in b'36' and h[2] in b' \t\n\r':
-        return 'ppm'
+    if len(h) >= 3 and h[0] == ord(b"P") and h[1] in b"36" and h[2] in b" \t\n\r":
+        return "ppm"
+
 
 tests.append(test_ppm)
 
+
 def test_rast(h, f):
     """test for the Sun raster file."""
-    if h.startswith(b'\x59\xA6\x6A\x95'):
-        return 'rast'
+    if h.startswith(b"\x59\xa6\x6a\x95"):
+        return "rast"
+
 
 tests.append(test_rast)
 
+
 def test_xbm(h, f):
     """Verify if the image is a X bitmap (X10 or X11)."""
-    if h.startswith(b'#define '):
-        return 'xbm'
+    if h.startswith(b"#define "):
+        return "xbm"
+
 
 tests.append(test_xbm)
 
+
 def test_bmp(h, f):
     """Verify if the image is a BMP file."""
-    if h.startswith(b'BM'):
-        return 'bmp'
+    if h.startswith(b"BM"):
+        return "bmp"
+
 
 tests.append(test_bmp)
 
+
 def test_webp(h, f):
     """Verify if the image is a WebP."""
-    if h.startswith(b'RIFF') and h[8:12] == b'WEBP':
-        return 'webp'
+    if h.startswith(b"RIFF") and h[8:12] == b"WEBP":
+        return "webp"
+
 
 tests.append(test_webp)
 
+
 def test_exr(h, f):
     """Verify if the image is an OpenEXR file."""
-    if h.startswith(b'\x76\x2f\x31\x01'):
-        return 'exr'
+    if h.startswith(b"\x76\x2f\x31\x01"):
+        return "exr"
+
 
 tests.append(test_exr)
 
 
-#--------------------#
+# --------------------#
 # Small test program #
-#--------------------#
+# --------------------#
+
 
 def test():
     import sys
+
     recursive = 0
-    if sys.argv[1:] and sys.argv[1] == '-r':
+    if sys.argv[1:] and sys.argv[1] == "-r":
         del sys.argv[1:2]
         recursive = 1
     try:
         if sys.argv[1:]:
             testall(sys.argv[1:], recursive, 1)
         else:
-            testall(['.'], recursive, 1)
+            testall(["."], recursive, 1)
     except KeyboardInterrupt:
-        sys.stderr.write('\n[Interrupted]\n')
+        sys.stderr.write("\n[Interrupted]\n")
         sys.exit(1)
 
+
 def testall(list, recursive, toplevel):
     import sys
     import os
+
     for filename in list:
         if os.path.isdir(filename):
-            print(filename + '/:', end=' ')
+            print(filename + "/:", end=" ")
             if recursive or toplevel:
-                print('recursing down:')
+                print("recursing down:")
                 import glob
-                names = glob.glob(os.path.join(glob.escape(filename), '*'))
+
+                names = glob.glob(os.path.join(glob.escape(filename), "*"))
                 testall(names, recursive, 0)
             else:
-                print('*** directory (use -r) ***')
+                print("*** directory (use -r) ***")
         else:
-            print(filename + ':', end=' ')
+            print(filename + ":", end=" ")
             sys.stdout.flush()
             try:
                 print(what(filename))
             except OSError:
-                print('*** not found ***')
+                print("*** not found ***")
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test()
diff --git a/metaflow/_vendor/importlib_metadata/__init__.py b/metaflow/_vendor/importlib_metadata/__init__.py
index d6c84fb70e9..436dcc3f8ad 100644
--- a/metaflow/_vendor/importlib_metadata/__init__.py
+++ b/metaflow/_vendor/importlib_metadata/__init__.py
@@ -33,18 +33,18 @@
 
 __all__ = [
-    'Distribution',
-    'DistributionFinder',
-    'PackageMetadata',
-    'PackageNotFoundError',
-    'distribution',
-    'distributions',
-    'entry_points',
-    'files',
-    'metadata',
-    'packages_distributions',
-    'requires',
-    'version',
+    "Distribution",
+    "DistributionFinder",
+    "PackageMetadata",
+    "PackageNotFoundError",
+    "distribution",
+    "distributions",
+    "entry_points",
+    "files",
+    "metadata",
+    "packages_distributions",
+    "requires",
+    "version",
 ]
 
 
@@ -114,15 +114,15 @@ def read(text, filter_=None):
         lines = filter(filter_, map(str.strip, text.splitlines()))
         name = None
         for value in lines:
-            section_match = value.startswith('[') and value.endswith(']')
+            section_match = value.startswith("[") and value.endswith("]")
             if section_match:
-                name = value.strip('[]')
+                name = value.strip("[]")
                 continue
             yield Pair(name, value)
 
     @staticmethod
     def valid(line):
-        return line and not line.startswith('#')
+        return line and not line.startswith("#")
 
 
 class DeprecatedTuple:
@@ -160,9 +160,9 @@ class EntryPoint(DeprecatedTuple):
     """
 
     pattern = re.compile(
-        r'(?P<module>[\w.]+)\s*'
-        r'(:\s*(?P<attr>[\w.]+))?\s*'
-        r'(?P<extras>\[.*\])?\s*$'
+        r"(?P<module>[\w.]+)\s*"
+        r"(:\s*(?P<attr>[\w.]+))?\s*"
+        r"(?P<extras>\[.*\])?\s*$"
     )
     """
     A regular expression describing the syntax for an entry point,
@@ -180,7 +180,7 @@ class EntryPoint(DeprecatedTuple):
     following the attr, and following any extras.
     """
 
-    dist: Optional['Distribution'] = None
+    dist: Optional["Distribution"] = None
 
     def __init__(self, name, value, group):
         vars(self).update(name=name, value=value, group=group)
@@ -191,24 +191,24 @@ def load(self):
        return the named object.
""" match = self.pattern.match(self.value) - module = import_module(match.group('module')) - attrs = filter(None, (match.group('attr') or '').split('.')) + module = import_module(match.group("module")) + attrs = filter(None, (match.group("attr") or "").split(".")) return functools.reduce(getattr, attrs, module) @property def module(self): match = self.pattern.match(self.value) - return match.group('module') + return match.group("module") @property def attr(self): match = self.pattern.match(self.value) - return match.group('attr') + return match.group("attr") @property def extras(self): match = self.pattern.match(self.value) - return list(re.finditer(r'\w+', match.group('extras') or '')) + return list(re.finditer(r"\w+", match.group("extras") or "")) def _for(self, dist): vars(self).update(dist=dist) @@ -243,8 +243,8 @@ def __setattr__(self, name, value): def __repr__(self): return ( - f'EntryPoint(name={self.name!r}, value={self.value!r}, ' - f'group={self.group!r})' + f"EntryPoint(name={self.name!r}, value={self.value!r}, " + f"group={self.group!r})" ) def __hash__(self): @@ -298,16 +298,16 @@ def wrapped(self, *args, **kwargs): return wrapped for method_name in [ - '__setitem__', - '__delitem__', - 'append', - 'reverse', - 'extend', - 'pop', - 'remove', - '__iadd__', - 'insert', - 'sort', + "__setitem__", + "__delitem__", + "append", + "reverse", + "extend", + "pop", + "remove", + "__iadd__", + "insert", + "sort", ]: locals()[method_name] = _wrap_deprecated_method(method_name) @@ -382,7 +382,7 @@ def _from_text_for(cls, text, dist): def _from_text(text): return ( EntryPoint(name=item.value.name, value=item.value.value, group=item.name) - for item in Sectioned.section_pairs(text or '') + for item in Sectioned.section_pairs(text or "") ) @@ -449,7 +449,7 @@ class SelectableGroups(Deprecated, dict): @classmethod def load(cls, eps): - by_group = operator.attrgetter('group') + by_group = operator.attrgetter("group") ordered = sorted(eps, key=by_group) grouped = itertools.groupby(ordered, by_group) return cls((group, EntryPoints(eps)) for group, eps in grouped) @@ -484,12 +484,12 @@ def select(self, **params): class PackagePath(pathlib.PurePosixPath): """A reference to a path in a package""" - def read_text(self, encoding='utf-8'): + def read_text(self, encoding="utf-8"): with self.locate().open(encoding=encoding) as stream: return stream.read() def read_binary(self): - with self.locate().open('rb') as stream: + with self.locate().open("rb") as stream: return stream.read() def locate(self): @@ -499,10 +499,10 @@ def locate(self): class FileHash: def __init__(self, spec): - self.mode, _, self.value = spec.partition('=') + self.mode, _, self.value = spec.partition("=") def __repr__(self): - return f'' + return f"" class Distribution: @@ -551,7 +551,7 @@ def discover(cls, **kwargs): :context: A ``DistributionFinder.Context`` object. :return: Iterable of Distribution objects for all packages. 
""" - context = kwargs.pop('context', None) + context = kwargs.pop("context", None) if context and kwargs: raise ValueError("cannot accept context and kwargs") context = context or DistributionFinder.Context(**kwargs) @@ -572,12 +572,12 @@ def at(path): def _discover_resolvers(): """Search the meta_path for resolvers.""" declared = ( - getattr(finder, 'find_distributions', None) for finder in sys.meta_path + getattr(finder, "find_distributions", None) for finder in sys.meta_path ) return filter(None, declared) @classmethod - def _local(cls, root='.'): + def _local(cls, root="."): from pep517 import build, meta system = build.compat_system(root) @@ -596,19 +596,19 @@ def metadata(self) -> _meta.PackageMetadata: metadata. See PEP 566 for details. """ text = ( - self.read_text('METADATA') - or self.read_text('PKG-INFO') + self.read_text("METADATA") + or self.read_text("PKG-INFO") # This last clause is here to support old egg-info files. Its # effect is to just end up using the PathDistribution's self._path # (which points to the egg-info file) attribute unchanged. - or self.read_text('') + or self.read_text("") ) return _adapters.Message(email.message_from_string(text)) @property def name(self): """Return the 'Name' metadata for the distribution package.""" - return self.metadata['Name'] + return self.metadata["Name"] @property def _normalized_name(self): @@ -618,11 +618,11 @@ def _normalized_name(self): @property def version(self): """Return the 'Version' metadata for the distribution package.""" - return self.metadata['Version'] + return self.metadata["Version"] @property def entry_points(self): - return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self) + return EntryPoints._from_text_for(self.read_text("entry_points.txt"), self) @property def files(self): @@ -653,7 +653,7 @@ def _read_files_distinfo(self): """ Read the lines of RECORD """ - text = self.read_text('RECORD') + text = self.read_text("RECORD") return text and text.splitlines() def _read_files_egginfo(self): @@ -661,7 +661,7 @@ def _read_files_egginfo(self): SOURCES.txt might contain literal commas, so wrap each line in quotes. """ - text = self.read_text('SOURCES.txt') + text = self.read_text("SOURCES.txt") return text and map('"{}"'.format, text.splitlines()) @property @@ -671,10 +671,10 @@ def requires(self): return reqs and list(reqs) def _read_dist_info_reqs(self): - return self.metadata.get_all('Requires-Dist') + return self.metadata.get_all("Requires-Dist") def _read_egg_info_reqs(self): - source = self.read_text('requires.txt') + source = self.read_text("requires.txt") return source and self._deps_from_requires_text(source) @classmethod @@ -697,12 +697,12 @@ def make_condition(name): return name and f'extra == "{name}"' def quoted_marker(section): - section = section or '' - extra, sep, markers = section.partition(':') + section = section or "" + extra, sep, markers = section.partition(":") if extra and markers: - markers = f'({markers})' + markers = f"({markers})" conditions = list(filter(None, [markers, make_condition(extra)])) - return '; ' + ' and '.join(conditions) if conditions else '' + return "; " + " and ".join(conditions) if conditions else "" def url_req_space(req): """ @@ -710,7 +710,7 @@ def url_req_space(req): Ref python/importlib_metadata#357. """ # '@' is uniquely indicative of a url_req. 
-            return ' ' * ('@' in req)
+            return " " * ("@" in req)
 
         for section in sections:
             space = url_req_space(section.value)
@@ -752,7 +752,7 @@ def path(self):
         Typically refers to Python installed package paths such as
         "site-packages" directories and defaults to ``sys.path``.
         """
-        return vars(self).get('path', sys.path)
+        return vars(self).get("path", sys.path)
 
     @abc.abstractmethod
     def find_distributions(self, context=Context()):
@@ -786,7 +786,7 @@ def joinpath(self, child):
 
     def children(self):
         with suppress(Exception):
-            return os.listdir(self.root or '.')
+            return os.listdir(self.root or ".")
         with suppress(Exception):
             return self.zip_children()
         return []
@@ -868,7 +868,7 @@ def normalize(name):
         """
         PEP 503 normalization plus dashes as underscores.
         """
-        return re.sub(r"[-_.]+", "-", name).lower().replace('-', '_')
+        return re.sub(r"[-_.]+", "-", name).lower().replace("-", "_")
 
     @staticmethod
     def legacy_normalize(name):
@@ -876,7 +876,7 @@ def legacy_normalize(name):
         Normalize the package name as found in the
         convention in older packaging tools versions and specs.
         """
-        return name.lower().replace('-', '_')
+        return name.lower().replace("-", "_")
 
     def __bool__(self):
         return bool(self.name)
@@ -930,7 +930,7 @@ def read_text(self, filename):
             NotADirectoryError,
             PermissionError,
         ):
-            return self._path.joinpath(filename).read_text(encoding='utf-8')
+            return self._path.joinpath(filename).read_text(encoding="utf-8")
 
     read_text.__doc__ = Distribution.read_text.__doc__
 
@@ -948,9 +948,9 @@ def _normalized_name(self):
 
     def _name_from_stem(self, stem):
         name, ext = os.path.splitext(stem)
-        if ext not in ('.dist-info', '.egg-info'):
+        if ext not in (".dist-info", ".egg-info"):
             return
-        name, sep, rest = stem.partition('-')
+        name, sep, rest = stem.partition("-")
         return name
 
 
@@ -1007,7 +1007,7 @@ def entry_points(**params) -> Union[EntryPoints, SelectableGroups]:
 
    :return: EntryPoints or SelectableGroups for all installed packages.
""" - norm_name = operator.attrgetter('_normalized_name') + norm_name = operator.attrgetter("_normalized_name") unique = functools.partial(unique_everseen, key=norm_name) eps = itertools.chain.from_iterable( dist.entry_points for dist in unique(distributions()) @@ -1047,17 +1047,17 @@ def packages_distributions() -> Mapping[str, List[str]]: pkg_to_dist = collections.defaultdict(list) for dist in distributions(): for pkg in _top_level_declared(dist) or _top_level_inferred(dist): - pkg_to_dist[pkg].append(dist.metadata['Name']) + pkg_to_dist[pkg].append(dist.metadata["Name"]) return dict(pkg_to_dist) def _top_level_declared(dist): - return (dist.read_text('top_level.txt') or '').split() + return (dist.read_text("top_level.txt") or "").split() def _top_level_inferred(dist): return { - f.parts[0] if len(f.parts) > 1 else f.with_suffix('').name + f.parts[0] if len(f.parts) > 1 else f.with_suffix("").name for f in always_iterable(dist.files) if f.suffix == ".py" } diff --git a/metaflow/_vendor/importlib_metadata/_adapters.py b/metaflow/_vendor/importlib_metadata/_adapters.py index aa460d3eda5..49cfa02e666 100644 --- a/metaflow/_vendor/importlib_metadata/_adapters.py +++ b/metaflow/_vendor/importlib_metadata/_adapters.py @@ -10,16 +10,16 @@ class Message(email.message.Message): map( FoldedCase, [ - 'Classifier', - 'Obsoletes-Dist', - 'Platform', - 'Project-URL', - 'Provides-Dist', - 'Provides-Extra', - 'Requires-Dist', - 'Requires-External', - 'Supported-Platform', - 'Dynamic', + "Classifier", + "Obsoletes-Dist", + "Platform", + "Project-URL", + "Provides-Dist", + "Provides-Extra", + "Requires-Dist", + "Requires-External", + "Supported-Platform", + "Dynamic", ], ) ) @@ -42,13 +42,13 @@ def __iter__(self): def _repair_headers(self): def redent(value): "Correct for RFC822 indentation" - if not value or '\n' not in value: + if not value or "\n" not in value: return value - return textwrap.dedent(' ' * 8 + value) + return textwrap.dedent(" " * 8 + value) - headers = [(key, redent(value)) for key, value in vars(self)['_headers']] + headers = [(key, redent(value)) for key, value in vars(self)["_headers"]] if self._payload: - headers.append(('Description', self.get_payload())) + headers.append(("Description", self.get_payload())) return headers @property @@ -60,9 +60,9 @@ def json(self): def transform(key): value = self.get_all(key) if key in self.multiple_use_keys else self[key] - if key == 'Keywords': - value = re.split(r'\s+', value) - tk = key.lower().replace('-', '_') + if key == "Keywords": + value = re.split(r"\s+", value) + tk = key.lower().replace("-", "_") return tk, value return dict(map(transform, map(FoldedCase, self))) diff --git a/metaflow/_vendor/importlib_metadata/_collections.py b/metaflow/_vendor/importlib_metadata/_collections.py index cf0954e1a30..895678a23c3 100644 --- a/metaflow/_vendor/importlib_metadata/_collections.py +++ b/metaflow/_vendor/importlib_metadata/_collections.py @@ -18,13 +18,13 @@ class FreezableDefaultDict(collections.defaultdict): """ def __missing__(self, key): - return getattr(self, '_frozen', super().__missing__)(key) + return getattr(self, "_frozen", super().__missing__)(key) def freeze(self): self._frozen = lambda key: self.default_factory() -class Pair(collections.namedtuple('Pair', 'name value')): +class Pair(collections.namedtuple("Pair", "name value")): @classmethod def parse(cls, text): return cls(*map(str.strip, text.split("=", 1))) diff --git a/metaflow/_vendor/importlib_metadata/_compat.py b/metaflow/_vendor/importlib_metadata/_compat.py index 
15927dbb753..631b8f185d7 100644 --- a/metaflow/_vendor/importlib_metadata/_compat.py +++ b/metaflow/_vendor/importlib_metadata/_compat.py @@ -2,7 +2,7 @@ import platform -__all__ = ['install', 'NullFinder', 'Protocol'] +__all__ = ["install", "NullFinder", "Protocol"] try: @@ -35,8 +35,8 @@ def disable_stdlib_finder(): def matches(finder): return getattr( - finder, '__module__', None - ) == '_frozen_importlib_external' and hasattr(finder, 'find_distributions') + finder, "__module__", None + ) == "_frozen_importlib_external" and hasattr(finder, "find_distributions") for finder in filter(matches, sys.meta_path): # pragma: nocover del finder.find_distributions @@ -67,5 +67,5 @@ def pypy_partial(val): Workaround for #327. """ - is_pypy = platform.python_implementation() == 'PyPy' + is_pypy = platform.python_implementation() == "PyPy" return val + is_pypy diff --git a/metaflow/_vendor/importlib_metadata/_meta.py b/metaflow/_vendor/importlib_metadata/_meta.py index 37ee43e6ef4..bc502995bd7 100644 --- a/metaflow/_vendor/importlib_metadata/_meta.py +++ b/metaflow/_vendor/importlib_metadata/_meta.py @@ -6,17 +6,13 @@ class PackageMetadata(Protocol): - def __len__(self) -> int: - ... # pragma: no cover + def __len__(self) -> int: ... # pragma: no cover - def __contains__(self, item: str) -> bool: - ... # pragma: no cover + def __contains__(self, item: str) -> bool: ... # pragma: no cover - def __getitem__(self, key: str) -> str: - ... # pragma: no cover + def __getitem__(self, key: str) -> str: ... # pragma: no cover - def __iter__(self) -> Iterator[str]: - ... # pragma: no cover + def __iter__(self) -> Iterator[str]: ... # pragma: no cover def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]: """ @@ -35,14 +31,10 @@ class SimplePath(Protocol): A minimal subset of pathlib.Path required by PathDistribution. """ - def joinpath(self) -> 'SimplePath': - ... # pragma: no cover + def joinpath(self) -> "SimplePath": ... # pragma: no cover - def __truediv__(self) -> 'SimplePath': - ... # pragma: no cover + def __truediv__(self) -> "SimplePath": ... # pragma: no cover - def parent(self) -> 'SimplePath': - ... # pragma: no cover + def parent(self) -> "SimplePath": ... # pragma: no cover - def read_text(self) -> str: - ... # pragma: no cover + def read_text(self) -> str: ... 
# pragma: no cover diff --git a/metaflow/_vendor/importlib_metadata/_text.py b/metaflow/_vendor/importlib_metadata/_text.py index c88cfbb2349..376210d7096 100644 --- a/metaflow/_vendor/importlib_metadata/_text.py +++ b/metaflow/_vendor/importlib_metadata/_text.py @@ -94,6 +94,6 @@ def lower(self): def index(self, sub): return self.lower().index(sub.lower()) - def split(self, splitter=' ', maxsplit=0): + def split(self, splitter=" ", maxsplit=0): pattern = re.compile(re.escape(splitter), re.I) return pattern.split(self, maxsplit) diff --git a/metaflow/_vendor/typeguard/_pytest_plugin.py b/metaflow/_vendor/typeguard/_pytest_plugin.py index 5272be04366..41500c58344 100644 --- a/metaflow/_vendor/typeguard/_pytest_plugin.py +++ b/metaflow/_vendor/typeguard/_pytest_plugin.py @@ -4,7 +4,11 @@ import warnings from typing import TYPE_CHECKING, Any, Literal -from metaflow._vendor.typeguard._config import CollectionCheckStrategy, ForwardRefPolicy, global_config +from metaflow._vendor.typeguard._config import ( + CollectionCheckStrategy, + ForwardRefPolicy, + global_config, +) from metaflow._vendor.typeguard._exceptions import InstrumentationWarning from metaflow._vendor.typeguard._importhook import install_import_hook from metaflow._vendor.typeguard._utils import qualified_name, resolve_reference diff --git a/metaflow/_vendor/typing_extensions.py b/metaflow/_vendor/typing_extensions.py index edf1805f00f..9090dc5b6c6 100644 --- a/metaflow/_vendor/typing_extensions.py +++ b/metaflow/_vendor/typing_extensions.py @@ -12,125 +12,120 @@ __all__ = [ # Super-special typing primitives. - 'Any', - 'ClassVar', - 'Concatenate', - 'Final', - 'LiteralString', - 'ParamSpec', - 'ParamSpecArgs', - 'ParamSpecKwargs', - 'Self', - 'Type', - 'TypeVar', - 'TypeVarTuple', - 'Unpack', - + "Any", + "ClassVar", + "Concatenate", + "Final", + "LiteralString", + "ParamSpec", + "ParamSpecArgs", + "ParamSpecKwargs", + "Self", + "Type", + "TypeVar", + "TypeVarTuple", + "Unpack", # ABCs (from collections.abc). - 'Awaitable', - 'AsyncIterator', - 'AsyncIterable', - 'Coroutine', - 'AsyncGenerator', - 'AsyncContextManager', - 'Buffer', - 'ChainMap', - + "Awaitable", + "AsyncIterator", + "AsyncIterable", + "Coroutine", + "AsyncGenerator", + "AsyncContextManager", + "Buffer", + "ChainMap", # Concrete collection types. - 'ContextManager', - 'Counter', - 'Deque', - 'DefaultDict', - 'NamedTuple', - 'OrderedDict', - 'TypedDict', - + "ContextManager", + "Counter", + "Deque", + "DefaultDict", + "NamedTuple", + "OrderedDict", + "TypedDict", # Structural checks, a.k.a. protocols. - 'SupportsAbs', - 'SupportsBytes', - 'SupportsComplex', - 'SupportsFloat', - 'SupportsIndex', - 'SupportsInt', - 'SupportsRound', - + "SupportsAbs", + "SupportsBytes", + "SupportsComplex", + "SupportsFloat", + "SupportsIndex", + "SupportsInt", + "SupportsRound", # One-off things. 
-    'Annotated',
-    'assert_never',
-    'assert_type',
-    'clear_overloads',
-    'dataclass_transform',
-    'deprecated',
-    'Doc',
-    'get_overloads',
-    'final',
-    'get_args',
-    'get_origin',
-    'get_original_bases',
-    'get_protocol_members',
-    'get_type_hints',
-    'IntVar',
-    'is_protocol',
-    'is_typeddict',
-    'Literal',
-    'NewType',
-    'overload',
-    'override',
-    'Protocol',
-    'reveal_type',
-    'runtime',
-    'runtime_checkable',
-    'Text',
-    'TypeAlias',
-    'TypeAliasType',
-    'TypeGuard',
-    'TypeIs',
-    'TYPE_CHECKING',
-    'Never',
-    'NoReturn',
-    'ReadOnly',
-    'Required',
-    'NotRequired',
-
+    "Annotated",
+    "assert_never",
+    "assert_type",
+    "clear_overloads",
+    "dataclass_transform",
+    "deprecated",
+    "Doc",
+    "get_overloads",
+    "final",
+    "get_args",
+    "get_origin",
+    "get_original_bases",
+    "get_protocol_members",
+    "get_type_hints",
+    "IntVar",
+    "is_protocol",
+    "is_typeddict",
+    "Literal",
+    "NewType",
+    "overload",
+    "override",
+    "Protocol",
+    "reveal_type",
+    "runtime",
+    "runtime_checkable",
+    "Text",
+    "TypeAlias",
+    "TypeAliasType",
+    "TypeGuard",
+    "TypeIs",
+    "TYPE_CHECKING",
+    "Never",
+    "NoReturn",
+    "ReadOnly",
+    "Required",
+    "NotRequired",
     # Pure aliases, have always been in typing
-    'AbstractSet',
-    'AnyStr',
-    'BinaryIO',
-    'Callable',
-    'Collection',
-    'Container',
-    'Dict',
-    'ForwardRef',
-    'FrozenSet',
-    'Generator',
-    'Generic',
-    'Hashable',
-    'IO',
-    'ItemsView',
-    'Iterable',
-    'Iterator',
-    'KeysView',
-    'List',
-    'Mapping',
-    'MappingView',
-    'Match',
-    'MutableMapping',
-    'MutableSequence',
-    'MutableSet',
-    'NoDefault',
-    'Optional',
-    'Pattern',
-    'Reversible',
-    'Sequence',
-    'Set',
-    'Sized',
-    'TextIO',
-    'Tuple',
-    'Union',
-    'ValuesView',
-    'cast',
-    'no_type_check',
-    'no_type_check_decorator',
+    "AbstractSet",
+    "AnyStr",
+    "BinaryIO",
+    "Callable",
+    "Collection",
+    "Container",
+    "Dict",
+    "ForwardRef",
+    "FrozenSet",
+    "Generator",
+    "Generic",
+    "Hashable",
+    "IO",
+    "ItemsView",
+    "Iterable",
+    "Iterator",
+    "KeysView",
+    "List",
+    "Mapping",
+    "MappingView",
+    "Match",
+    "MutableMapping",
+    "MutableSequence",
+    "MutableSet",
+    "NoDefault",
+    "Optional",
+    "Pattern",
+    "Reversible",
+    "Sequence",
+    "Set",
+    "Sized",
+    "TextIO",
+    "Tuple",
+    "Union",
+    "ValuesView",
+    "cast",
+    "no_type_check",
+    "no_type_check_decorator",
 ]
 
 # for backward compatibility
@@ -151,14 +146,19 @@ def __repr__(self):
 
 
 if sys.version_info >= (3, 10):
+
     def _should_collect_from_parameters(t):
         return isinstance(
             t, (typing._GenericAlias, _types.GenericAlias, _types.UnionType)
         )
+
 elif sys.version_info >= (3, 9):
+
     def _should_collect_from_parameters(t):
         return isinstance(t, (typing._GenericAlias, _types.GenericAlias))
+
 else:
+
     def _should_collect_from_parameters(t):
         return isinstance(t, typing._GenericAlias) and not t._special
 
@@ -167,11 +167,11 @@ def _should_collect_from_parameters(t):
 
 # Some unconstrained type variables.  These are used by the container types.
 # (These are not for export.)
-T = typing.TypeVar('T')  # Any type.
-KT = typing.TypeVar('KT')  # Key type.
-VT = typing.TypeVar('VT')  # Value type.
-T_co = typing.TypeVar('T_co', covariant=True)  # Any type covariant containers.
-T_contra = typing.TypeVar('T_contra', contravariant=True)  # Ditto contravariant.
+T = typing.TypeVar("T")  # Any type.
+KT = typing.TypeVar("KT")  # Key type.
+VT = typing.TypeVar("VT")  # Value type.
+T_co = typing.TypeVar("T_co", covariant=True)  # Any type covariant containers.
+T_contra = typing.TypeVar("T_contra", contravariant=True)  # Ditto contravariant.
 
 if sys.version_info >= (3, 11):
@@ -181,7 +181,9 @@ def _should_collect_from_parameters(t):
     class _AnyMeta(type):
         def __instancecheck__(self, obj):
             if self is Any:
-                raise TypeError("typing_extensions.Any cannot be used with isinstance()")
+                raise TypeError(
+                    "typing_extensions.Any cannot be used with isinstance()"
+                )
             return super().__instancecheck__(obj)
 
         def __repr__(self):
@@ -198,6 +200,7 @@ class Any(metaclass=_AnyMeta):
         static type checkers. At runtime, Any should not be used with instance
         checks.
         """
+
         def __new__(cls, *args, **kwargs):
             if cls is Any:
                 raise TypeError("Any cannot be instantiated")
@@ -209,7 +212,7 @@ def __new__(cls, *args, **kwargs):
 
 class _ExtensionsSpecialForm(typing._SpecialForm, _root=True):
     def __repr__(self):
-        return 'typing_extensions.' + self._name
+        return "typing_extensions." + self._name
 
 
 Final = typing.Final
@@ -260,6 +263,7 @@ def IntVar(name):
 if sys.version_info >= (3, 10, 1):
     Literal = typing.Literal
 else:
+
     def _flatten_literal_params(parameters):
         """An internal helper for Literal creation: flatten Literals among parameters"""
         params = []
@@ -287,7 +291,7 @@ def __hash__(self):
 
     class _LiteralForm(_ExtensionsSpecialForm, _root=True):
         def __init__(self, doc: str):
-            self._name = 'Literal'
+            self._name = "Literal"
             self._doc = self.__doc__ = doc
 
         def __getitem__(self, parameters):
@@ -315,7 +319,8 @@ def __getitem__(self, parameters):
 
             return _LiteralGenericAlias(self, parameters)
 
-    Literal = _LiteralForm(doc="""\
+    Literal = _LiteralForm(
+        doc="""\
                            A type that can be used to indicate to type checkers
                            that the corresponding value has a value literally equivalent
                            to the provided parameter. For example:
@@ -327,7 +332,8 @@ def __getitem__(self, parameters):
 
                            Literal[...] cannot be subclassed. There is no runtime
                            checking verifying that the parameter is actually a value
-                           instead of a type.""")
+                           instead of a type."""
+    )
 
     _overload_dummy = typing._overload_dummy
 
@@ -420,8 +426,9 @@ def clear_overloads():
 if sys.version_info >= (3, 13, 0, "beta"):
     from typing import AsyncContextManager, AsyncGenerator, ContextManager, Generator
 else:
+
     def _is_dunder(attr):
-        return attr.startswith('__') and attr.endswith('__')
+        return attr.startswith("__") and attr.endswith("__")
 
     # Python <3.9 doesn't have typing._SpecialGenericAlias
     _special_generic_alias_base = getattr(
@@ -441,7 +448,7 @@ def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()):
             self._defaults = defaults
 
         def __setattr__(self, attr, val):
-            allowed_attrs = {'_name', '_inst', '_nparams', '_defaults'}
+            allowed_attrs = {"_name", "_inst", "_nparams", "_defaults"}
             if _special_generic_alias_base is typing._GenericAlias:
                 # Python <3.9
                 allowed_attrs.add("__origin__")
@@ -461,7 +468,7 @@ def __getitem__(self, params):
                 and len(params) < self._nparams
                 and len(params) + len(self._defaults) >= self._nparams
             ):
-                params = (*params, *self._defaults[len(params) - self._nparams:])
+                params = (*params, *self._defaults[len(params) - self._nparams :])
             actual_len = len(params)
 
             if actual_len != self._nparams:
@@ -489,28 +496,39 @@ def __getitem__(self, params):
         contextlib.AbstractContextManager,
         2,
         name="ContextManager",
-        defaults=(typing.Optional[bool],)
+        defaults=(typing.Optional[bool],),
     )
     AsyncContextManager = _SpecialGenericAlias(
         contextlib.AbstractAsyncContextManager,
         2,
         name="AsyncContextManager",
-        defaults=(typing.Optional[bool],)
+        defaults=(typing.Optional[bool],),
     )
 
 
 _PROTO_ALLOWLIST = {
-    'collections.abc': [
-        'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable',
-        'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', 'Buffer',
+    "collections.abc": [
+        "Callable",
+        "Awaitable",
+        "Iterable",
+        "Iterator",
+        "AsyncIterable",
+        "Hashable",
+        "Sized",
+        "Container",
+        "Collection",
+        "Reversible",
+        "Buffer",
     ],
-    'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'],
-    'typing_extensions': ['Buffer'],
+    "contextlib": ["AbstractContextManager", "AbstractAsyncContextManager"],
+    "typing_extensions": ["Buffer"],
 }
 
 
 _EXCLUDED_ATTRS = frozenset(typing.EXCLUDED_ATTRIBUTES) | {
-    "__match_args__", "__protocol_attrs__", "__non_callable_proto_members__",
+    "__match_args__",
+    "__protocol_attrs__",
+    "__non_callable_proto_members__",
     "__final__",
 }
 
@@ -518,18 +536,18 @@ def __getitem__(self, params):
 def _get_protocol_attrs(cls):
     attrs = set()
     for base in cls.__mro__[:-1]:  # without object
-        if base.__name__ in {'Protocol', 'Generic'}:
+        if base.__name__ in {"Protocol", "Generic"}:
             continue
-        annotations = getattr(base, '__annotations__', {})
+        annotations = getattr(base, "__annotations__", {})
         for attr in (*base.__dict__, *annotations):
-            if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS):
+            if not attr.startswith("_abc_") and attr not in _EXCLUDED_ATTRS:
                 attrs.add(attr)
     return attrs
 
 
 def _caller(depth=2):
     try:
-        return sys._getframe(depth).f_globals.get('__name__', '__main__')
+        return sys._getframe(depth).f_globals.get("__name__", "__main__")
     except (AttributeError, ValueError):  # For platforms without _getframe()
         return None
 
@@ -539,16 +557,17 @@ def _caller(depth=2):
 if sys.version_info >= (3, 13):
     Protocol = typing.Protocol
 else:
+
     def _allow_reckless_class_checks(depth=3):
         """Allow instance and class checks for special stdlib modules.
         The abc and functools modules indiscriminately call isinstance() and
         issubclass() on the whole MRO of a user class, which may contain protocols.
         """
-        return _caller(depth) in {'abc', 'functools', None}
+        return _caller(depth) in {"abc", "functools", None}
 
     def _no_init(self, *args, **kwargs):
         if type(self)._is_protocol:
-            raise TypeError('Protocols cannot be instantiated')
+            raise TypeError("Protocols cannot be instantiated")
 
     def _type_check_issubclass_arg_1(arg):
         """Raise TypeError if `arg` is not an instance of `type`
@@ -564,7 +583,7 @@ def _type_check_issubclass_arg_1(arg):
         """
         if not isinstance(arg, type):
            # Same error message as for issubclass(1, int).
-            raise TypeError('issubclass() arg 1 must be a class')
+            raise TypeError("issubclass() arg 1 must be a class")
 
     # Inheriting from typing._ProtocolMeta isn't actually desirable,
     # but is necessary to allow typing.Protocol and typing_extensions.Protocol
@@ -601,10 +620,10 @@ def __subclasscheck__(cls, other):
             if cls is Protocol:
                 return type.__subclasscheck__(cls, other)
             if (
-                getattr(cls, '_is_protocol', False)
+                getattr(cls, "_is_protocol", False)
                 and not _allow_reckless_class_checks()
             ):
-                if not getattr(cls, '_is_runtime_protocol', False):
+                if not getattr(cls, "_is_runtime_protocol", False):
                     _type_check_issubclass_arg_1(other)
                     raise TypeError(
                         "Instance and class checks can only be used with "
@@ -633,11 +652,13 @@ def __instancecheck__(cls, instance):
                 return abc.ABCMeta.__instancecheck__(cls, instance)
 
             if (
-                not getattr(cls, '_is_runtime_protocol', False) and
-                not _allow_reckless_class_checks()
+                not getattr(cls, "_is_runtime_protocol", False)
+                and not _allow_reckless_class_checks()
             ):
-                raise TypeError("Instance and class checks can only be used with"
-                                " @runtime_checkable protocols")
+                raise TypeError(
+                    "Instance and class checks can only be used with"
+                    " @runtime_checkable protocols"
+                )
 
             if abc.ABCMeta.__instancecheck__(cls, instance):
                 return True
@@ -671,7 +692,7 @@ def __hash__(cls) -> int:
 
     @classmethod
     def _proto_hook(cls, other):
-        if not cls.__dict__.get('_is_protocol', False):
+        if not cls.__dict__.get("_is_protocol", False):
             return NotImplemented
 
         for attr in cls.__protocol_attrs__:
@@ -683,7 +704,7 @@ def _proto_hook(cls, other):
                     break
 
                 # ...or in annotations, if it is a sub-protocol.
-                annotations = getattr(base, '__annotations__', {})
+                annotations = getattr(base, "__annotations__", {})
                 if (
                     isinstance(annotations, collections.abc.Mapping)
                     and attr in annotations
@@ -704,11 +725,11 @@ def __init_subclass__(cls, *args, **kwargs):
             super().__init_subclass__(*args, **kwargs)
 
         # Determine if this is a protocol or a concrete subclass.
-        if not cls.__dict__.get('_is_protocol', False):
+        if not cls.__dict__.get("_is_protocol", False):
             cls._is_protocol = any(b is Protocol for b in cls.__bases__)
 
         # Set (or override) the protocol subclass hook.
-        if '__subclasshook__' not in cls.__dict__:
+        if "__subclasshook__" not in cls.__dict__:
             cls.__subclasshook__ = _proto_hook
 
         # Prohibit instantiation for protocol classes
@@ -719,6 +740,7 @@ def __init_subclass__(cls, *args, **kwargs):
 if sys.version_info >= (3, 13):
     runtime_checkable = typing.runtime_checkable
 else:
+
     def runtime_checkable(cls):
         """Mark a protocol class as a runtime protocol.
 
@@ -738,9 +760,13 @@ def close(self): ...
         Warning: this will check only the presence of the required methods,
         not their type signatures!
         """
-        if not issubclass(cls, typing.Generic) or not getattr(cls, '_is_protocol', False):
-            raise TypeError(f'@runtime_checkable can be only applied to protocol classes,'
-                            f' got {cls!r}')
+        if not issubclass(cls, typing.Generic) or not getattr(
+            cls, "_is_protocol", False
+        ):
+            raise TypeError(
+                f"@runtime_checkable can be only applied to protocol classes,"
+                f" got {cls!r}"
+            )
         cls._is_runtime_protocol = True
 
         # typing.Protocol classes on <=3.11 break if we execute this block,
@@ -785,9 +811,11 @@ def close(self): ...
     SupportsAbs = typing.SupportsAbs
     SupportsRound = typing.SupportsRound
 else:
+
     @runtime_checkable
     class SupportsInt(Protocol):
         """An ABC with one abstract method __int__."""
+
         __slots__ = ()
 
         @abc.abstractmethod
@@ -797,6 +825,7 @@ def __int__(self) -> int:
     @runtime_checkable
     class SupportsFloat(Protocol):
         """An ABC with one abstract method __float__."""
+
         __slots__ = ()
 
         @abc.abstractmethod
@@ -806,6 +835,7 @@ def __float__(self) -> float:
     @runtime_checkable
     class SupportsComplex(Protocol):
         """An ABC with one abstract method __complex__."""
+
         __slots__ = ()
 
         @abc.abstractmethod
@@ -815,6 +845,7 @@ def __complex__(self) -> complex:
     @runtime_checkable
     class SupportsBytes(Protocol):
         """An ABC with one abstract method __bytes__."""
+
         __slots__ = ()
 
         @abc.abstractmethod
@@ -834,6 +865,7 @@ class SupportsAbs(Protocol[T_co]):
         """
         An ABC with one abstract method __abs__ that is covariant in its return type.
         """
+
         __slots__ = ()
 
         @abc.abstractmethod
@@ -845,6 +877,7 @@ class SupportsRound(Protocol[T_co]):
         """
         An ABC with one abstract method __round__ that is covariant in its return type.
         """
+
         __slots__ = ()
 
         @abc.abstractmethod
@@ -857,13 +890,14 @@ def inner(func):
         if sys.implementation.name == "pypy" and sys.version_info < (3, 9):
             cls_dict = {
                 "__call__": staticmethod(func),
-                "__mro_entries__": staticmethod(mro_entries)
+                "__mro_entries__": staticmethod(mro_entries),
             }
             t = type(func.__name__, (), cls_dict)
             return functools.update_wrapper(t(), func)
         else:
             func.__mro_entries__ = mro_entries
             return func
+
     return inner
 
 
@@ -902,13 +936,13 @@ def _get_typeddict_qualifiers(annotation_type):
             break
         elif annotation_origin is Required:
             yield Required
-            annotation_type, = get_args(annotation_type)
+            (annotation_type,) = get_args(annotation_type)
         elif annotation_origin is NotRequired:
             yield NotRequired
-            annotation_type, = get_args(annotation_type)
+            (annotation_type,) = get_args(annotation_type)
         elif annotation_origin is ReadOnly:
             yield ReadOnly
-            annotation_type, = get_args(annotation_type)
+            (annotation_type,) = get_args(annotation_type)
         else:
             break
 
@@ -923,8 +957,10 @@ def __new__(cls, name, bases, ns, *, total=True, closed=False):
             """
             for base in bases:
                 if type(base) is not _TypedDictMeta and base is not typing.Generic:
-                    raise TypeError('cannot inherit from both a TypedDict type '
-                                    'and a non-TypedDict base class')
+                    raise TypeError(
+                        "cannot inherit from both a TypedDict type "
+                        "and a non-TypedDict base class"
+                    )
 
             if any(issubclass(b, typing.Generic) for b in bases):
                 generic_base = (typing.Generic,)
@@ -933,12 +969,14 @@ def __new__(cls, name, bases, ns, *, total=True, closed=False):
 
             # typing.py generally doesn't let you inherit from plain Generic, unless
             # the name of the class happens to be "Protocol"
-            tp_dict = type.__new__(_TypedDictMeta, "Protocol", (*generic_base, dict), ns)
+            tp_dict = type.__new__(
+                _TypedDictMeta, "Protocol", (*generic_base, dict), ns
+            )
             tp_dict.__name__ = name
             if tp_dict.__qualname__ == "Protocol":
                 tp_dict.__qualname__ = name
 
-            if not hasattr(tp_dict, '__orig_bases__'):
+            if not hasattr(tp_dict, "__orig_bases__"):
                 tp_dict.__orig_bases__ = bases
 
             annotations = {}
@@ -957,8 +995,7 @@ def __new__(cls, name, bases, ns, *, total=True, closed=False):
                 }
             else:
                 own_annotations = {
-                    n: typing._type_check(tp, msg)
-                    for n, tp in own_annotations.items()
+                    n: typing._type_check(tp, msg) for n, tp in own_annotations.items()
                 }
             required_keys = set()
             optional_keys = set()
@@ -969,12 +1006,12 @@ def __new__(cls, name, bases, ns, *, total=True, closed=False):
 
             for base in bases:
                 base_dict = base.__dict__
 
-                annotations.update(base_dict.get('__annotations__', {}))
-                required_keys.update(base_dict.get('__required_keys__', ()))
-                optional_keys.update(base_dict.get('__optional_keys__', ()))
-                readonly_keys.update(base_dict.get('__readonly_keys__', ()))
-                mutable_keys.update(base_dict.get('__mutable_keys__', ()))
-                base_extra_items_type = base_dict.get('__extra_items__', None)
+                annotations.update(base_dict.get("__annotations__", {}))
+                required_keys.update(base_dict.get("__required_keys__", ()))
+                optional_keys.update(base_dict.get("__optional_keys__", ()))
+                readonly_keys.update(base_dict.get("__readonly_keys__", ()))
+                mutable_keys.update(base_dict.get("__mutable_keys__", ()))
+                base_extra_items_type = base_dict.get("__extra_items__", None)
                 if base_extra_items_type is not None:
                     extra_items_type = base_extra_items_type
 
@@ -985,13 +1022,11 @@ def __new__(cls, name, bases, ns, *, total=True, closed=False):
                 qualifiers = set(_get_typeddict_qualifiers(annotation_type))
                 if Required in qualifiers:
                     raise TypeError(
-                        "Special key __extra_items__ does not support "
-                        "Required"
+                        "Special key __extra_items__ does not support " "Required"
                     )
                 if NotRequired in qualifiers:
                     raise TypeError(
-                        "Special key __extra_items__ does not support "
-                        "NotRequired"
+                        "Special key __extra_items__ does not support " "NotRequired"
                     )
                 extra_items_type = annotation_type
 
@@ -1019,7 +1054,7 @@ def __new__(cls, name, bases, ns, *, total=True, closed=False):
             tp_dict.__optional_keys__ = frozenset(optional_keys)
             tp_dict.__readonly_keys__ = frozenset(readonly_keys)
             tp_dict.__mutable_keys__ = frozenset(mutable_keys)
-            if not hasattr(tp_dict, '__total__'):
+            if not hasattr(tp_dict, "__total__"):
                 tp_dict.__total__ = total
             tp_dict.__closed__ = closed
             tp_dict.__extra_items__ = extra_items_type
@@ -1029,11 +1064,11 @@ def __new__(cls, name, bases, ns, *, total=True, closed=False):
 
         def __subclasscheck__(cls, other):
             # Typed dicts are only for static structural subtyping.
-            raise TypeError('TypedDict does not support instance and class checks')
+            raise TypeError("TypedDict does not support instance and class checks")
 
         __instancecheck__ = __subclasscheck__
 
-    _TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {})
+    _TypedDict = type.__new__(_TypedDictMeta, "TypedDict", (), {})
 
     @_ensure_subclassable(lambda bases: (_TypedDict,))
     def TypedDict(typename, fields=_marker, /, *, total=True, closed=False, **kwargs):
@@ -1091,18 +1126,23 @@ class Point2D(TypedDict):
                 example = f"`{typename} = TypedDict({typename!r}, {{}})`"
 
             deprecation_msg = (
-                f"{deprecated_thing} is deprecated and will be disallowed in "
-                "Python 3.15. To create a TypedDict class with 0 fields "
-                "using the functional syntax, pass an empty dictionary, e.g. "
-            ) + example + "."
+                (
+                    f"{deprecated_thing} is deprecated and will be disallowed in "
+                    "Python 3.15. To create a TypedDict class with 0 fields "
+                    "using the functional syntax, pass an empty dictionary, e.g. "
+                )
+                + example
+                + "."
+            )
             warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2)
             if closed is not False and closed is not True:
                 kwargs["closed"] = closed
                 closed = False
             fields = kwargs
         elif kwargs:
-            raise TypeError("TypedDict takes either a dict or keyword arguments,"
-                            " but not both")
+            raise TypeError(
+                "TypedDict takes either a dict or keyword arguments," " but not both"
+            )
         if kwargs:
             if sys.version_info >= (3, 13):
                 raise TypeError("TypedDict takes no keyword arguments")
@@ -1114,11 +1154,11 @@ class Point2D(TypedDict):
                 stacklevel=2,
             )
 
-        ns = {'__annotations__': dict(fields)}
+        ns = {"__annotations__": dict(fields)}
         module = _caller()
         if module is not None:
             # Setting correct module is necessary to make typed dict classes pickleable.
-            ns['__module__'] = module
+            ns["__module__"] = module
 
         td = _TypedDictMeta(typename, (), ns, total=total, closed=closed)
         td.__orig_bases__ = (TypedDict,)
@@ -1150,6 +1190,7 @@ class Film(TypedDict):
 
     assert_type = typing.assert_type
 else:
+
     def assert_type(val, typ, /):
         """Assert (to the type checker) that the value is of the given type.
 
@@ -1174,7 +1215,11 @@ def _strip_extras(t):
     """Strips Annotated, Required and NotRequired from a given type."""
     if isinstance(t, _AnnotatedAlias):
         return _strip_extras(t.__origin__)
-    if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired, ReadOnly):
+    if hasattr(t, "__origin__") and t.__origin__ in (
+        Required,
+        NotRequired,
+        ReadOnly,
+    ):
         return _strip_extras(t.__args__[0])
     if isinstance(t, typing._GenericAlias):
         stripped_args = tuple(_strip_extras(a) for a in t.__args__)
@@ -1238,13 +1283,14 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
 
 
 # Python 3.9+ has PEP 593 (Annotated)
-if hasattr(typing, 'Annotated'):
+if hasattr(typing, "Annotated"):
     Annotated = typing.Annotated
     # Not exported and not a public API, but needed for get_origin() and get_args()
     # to work.
     _AnnotatedAlias = typing._AnnotatedAlias
 # 3.8
 else:
+
     class _AnnotatedAlias(typing._GenericAlias, _root=True):
         """Runtime representation of an annotated type.
 
@@ -1253,6 +1299,7 @@ class _AnnotatedAlias(typing._GenericAlias, _root=True):
         instantiating is the same as instantiating the underlying type, binding
         it to types is also the same.
         """
+
         def __init__(self, origin, metadata):
             if isinstance(origin, _AnnotatedAlias):
                 metadata = origin.__metadata__ + metadata
@@ -1266,13 +1313,13 @@ def copy_with(self, params):
             return _AnnotatedAlias(new_type, self.__metadata__)
 
         def __repr__(self):
-            return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, "
-                    f"{', '.join(repr(a) for a in self.__metadata__)}]")
+            return (
+                f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, "
+                f"{', '.join(repr(a) for a in self.__metadata__)}]"
+            )
 
         def __reduce__(self):
-            return operator.getitem, (
-                Annotated, (self.__origin__, *self.__metadata__)
-            )
+            return operator.getitem, (Annotated, (self.__origin__, *self.__metadata__))
 
         def __eq__(self, other):
             if not isinstance(other, _AnnotatedAlias):
@@ -1325,9 +1372,11 @@ def __new__(cls, *args, **kwargs):
         @typing._tp_cache
         def __class_getitem__(cls, params):
             if not isinstance(params, tuple) or len(params) < 2:
-                raise TypeError("Annotated[...] should be used "
-                                "with at least two arguments (a type and an "
-                                "annotation).")
+                raise TypeError(
+                    "Annotated[...] should be used "
+                    "with at least two arguments (a type and an "
+                    "annotation)."
+                )
             allowed_special_forms = (ClassVar, Final)
             if get_origin(params[0]) in allowed_special_forms:
                 origin = params[0]
@@ -1338,9 +1387,8 @@ def __class_getitem__(cls, params):
             return _AnnotatedAlias(origin, metadata)
 
         def __init_subclass__(cls, *args, **kwargs):
-            raise TypeError(
-                f"Cannot subclass {cls.__module__}.Annotated"
-            )
+            raise TypeError(f"Cannot subclass {cls.__module__}.Annotated")
+
 
 # Python 3.8 has get_origin() and get_args() but those implementations aren't
 # Annotated-aware, so we can't use those. Python 3.9's versions don't support
@@ -1378,8 +1426,16 @@ def get_origin(tp):
     """
     if isinstance(tp, _AnnotatedAlias):
         return Annotated
-    if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias, _BaseGenericAlias,
-                       ParamSpecArgs, ParamSpecKwargs)):
+    if isinstance(
+        tp,
+        (
+            typing._GenericAlias,
+            _typing_GenericAlias,
+            _BaseGenericAlias,
+            ParamSpecArgs,
+            ParamSpecKwargs,
+        ),
+    ):
         return tp.__origin__
     if tp is typing.Generic:
         return typing.Generic
@@ -1409,10 +1465,11 @@ def get_args(tp):
 
 
 # 3.10+
-if hasattr(typing, 'TypeAlias'):
+if hasattr(typing, "TypeAlias"):
     TypeAlias = typing.TypeAlias
 # 3.9
 elif sys.version_info[:2] >= (3, 9):
+
     @_ExtensionsSpecialForm
     def TypeAlias(self, parameters):
         """Special marker indicating that an assignment should
@@ -1426,10 +1483,12 @@ def TypeAlias(self, parameters):
         It's invalid when used anywhere except as in the example above.
         """
         raise TypeError(f"{self} is not subscriptable")
+
+
 # 3.8
 else:
     TypeAlias = _ExtensionsSpecialForm(
-        'TypeAlias',
+        "TypeAlias",
         doc="""Special marker indicating that an assignment should
         be recognized as a proper type alias definition by type
         checkers.
@@ -1439,13 +1498,14 @@ def TypeAlias(self, parameters):
         Predicate: TypeAlias = Callable[..., bool]
 
         It's invalid when used anywhere except as in the example
-        above."""
+        above.""",
     )
 
 
 if hasattr(typing, "NoDefault"):
     NoDefault = typing.NoDefault
 else:
+
     class NoDefaultTypeMeta(type):
         def __setattr__(cls, attr, value):
             # TypeError is consistent with the behavior of NoneType
@@ -1479,7 +1539,7 @@ def _set_default(type_param, default):
 def _set_module(typevarlike):
     # for pickling:
     def_mod = _caller(depth=3)
-    if def_mod != 'typing_extensions':
+    if def_mod != "typing_extensions":
         typevarlike.__module__ = def_mod
 
 
@@ -1505,28 +1565,46 @@ class TypeVar(metaclass=_TypeVarLikeMeta):
 
     _backported_typevarlike = typing.TypeVar
 
-    def __new__(cls, name, *constraints, bound=None,
-                covariant=False, contravariant=False,
-                default=NoDefault, infer_variance=False):
+    def __new__(
+        cls,
+        name,
+        *constraints,
+        bound=None,
+        covariant=False,
+        contravariant=False,
+        default=NoDefault,
+        infer_variance=False,
+    ):
         if hasattr(typing, "TypeAliasType"):
             # PEP 695 implemented (3.12+), can pass infer_variance to typing.TypeVar
-            typevar = typing.TypeVar(name, *constraints, bound=bound,
-                                     covariant=covariant, contravariant=contravariant,
-                                     infer_variance=infer_variance)
+            typevar = typing.TypeVar(
+                name,
+                *constraints,
+                bound=bound,
+                covariant=covariant,
+                contravariant=contravariant,
+                infer_variance=infer_variance,
+            )
         else:
-            typevar = typing.TypeVar(name, *constraints, bound=bound,
-                                     covariant=covariant, contravariant=contravariant)
+            typevar = typing.TypeVar(
+                name,
+                *constraints,
+                bound=bound,
+                covariant=covariant,
+                contravariant=contravariant,
+            )
             if infer_variance and (covariant or contravariant):
-                raise ValueError("Variance cannot be specified with infer_variance.")
+                raise ValueError(
+                    "Variance cannot be specified with infer_variance."
+                )
             typevar.__infer_variance__ = infer_variance
 
         _set_default(typevar, default)
         _set_module(typevar)
 
         def _tvar_prepare_subst(alias, args):
-            if (
-                typevar.has_default()
-                and alias.__parameters__.index(typevar) == len(args)
+            if typevar.has_default() and alias.__parameters__.index(typevar) == len(
+                args
             ):
                 args += (typevar.__default__,)
             return args
@@ -1539,13 +1617,15 @@ def __init_subclass__(cls) -> None:
 
 
 # Python 3.10+ has PEP 612
-if hasattr(typing, 'ParamSpecArgs'):
+if hasattr(typing, "ParamSpecArgs"):
     ParamSpecArgs = typing.ParamSpecArgs
     ParamSpecKwargs = typing.ParamSpecKwargs
 # 3.8-3.9
 else:
+
     class _Immutable:
         """Mixin to indicate that object should not be copied."""
+
         __slots__ = ()
 
         def __copy__(self):
@@ -1566,6 +1646,7 @@ class ParamSpecArgs(_Immutable):
         This type is meant for runtime introspection and has no special meaning
         to static type checkers.
         """
+
         def __init__(self, origin):
             self.__origin__ = origin
 
@@ -1589,6 +1670,7 @@ class ParamSpecKwargs(_Immutable):
         This type is meant for runtime introspection and has no special meaning
         to static type checkers.
         """
+
         def __init__(self, origin):
             self.__origin__ = origin
 
@@ -1605,7 +1687,7 @@ def __eq__(self, other):
     from typing import ParamSpec
 
 # 3.10+
-elif hasattr(typing, 'ParamSpec'):
+elif hasattr(typing, "ParamSpec"):
 
     # Add default parameter - PEP 696
     class ParamSpec(metaclass=_TypeVarLikeMeta):
@@ -1613,19 +1695,29 @@ class ParamSpec(metaclass=_TypeVarLikeMeta):
 
         _backported_typevarlike = typing.ParamSpec
 
-        def __new__(cls, name, *, bound=None,
-                    covariant=False, contravariant=False,
-                    infer_variance=False, default=NoDefault):
+        def __new__(
+            cls,
+            name,
+            *,
+            bound=None,
+            covariant=False,
+            contravariant=False,
+            infer_variance=False,
+            default=NoDefault,
+        ):
             if hasattr(typing, "TypeAliasType"):
                 # PEP 695 implemented, can pass infer_variance to typing.TypeVar
-                paramspec = typing.ParamSpec(name, bound=bound,
-                                             covariant=covariant,
-                                             contravariant=contravariant,
-                                             infer_variance=infer_variance)
+                paramspec = typing.ParamSpec(
+                    name,
+                    bound=bound,
+                    covariant=covariant,
+                    contravariant=contravariant,
+                    infer_variance=infer_variance,
+                )
             else:
-                paramspec = typing.ParamSpec(name, bound=bound,
-                                             covariant=covariant,
-                                             contravariant=contravariant)
+                paramspec = typing.ParamSpec(
+                    name, bound=bound, covariant=covariant, contravariant=contravariant
+                )
                 paramspec.__infer_variance__ = infer_variance
 
             _set_default(paramspec, default)
@@ -1644,14 +1736,17 @@ def _paramspec_prepare_subst(alias, args):
                     args = (args,)
                # Convert lists to tuples to help other libraries cache the results.
elif isinstance(args[i], list): - args = (*args[:i], tuple(args[i]), *args[i + 1:]) + args = (*args[:i], tuple(args[i]), *args[i + 1 :]) return args paramspec.__typing_prepare_subst__ = _paramspec_prepare_subst return paramspec def __init_subclass__(cls) -> None: - raise TypeError(f"type '{__name__}.ParamSpec' is not an acceptable base type") + raise TypeError( + f"type '{__name__}.ParamSpec' is not an acceptable base type" + ) + # 3.8-3.9 else: @@ -1715,33 +1810,41 @@ def args(self): def kwargs(self): return ParamSpecKwargs(self) - def __init__(self, name, *, bound=None, covariant=False, contravariant=False, - infer_variance=False, default=NoDefault): + def __init__( + self, + name, + *, + bound=None, + covariant=False, + contravariant=False, + infer_variance=False, + default=NoDefault, + ): list.__init__(self, [self]) self.__name__ = name self.__covariant__ = bool(covariant) self.__contravariant__ = bool(contravariant) self.__infer_variance__ = bool(infer_variance) if bound: - self.__bound__ = typing._type_check(bound, 'Bound must be a type.') + self.__bound__ = typing._type_check(bound, "Bound must be a type.") else: self.__bound__ = None _DefaultMixin.__init__(self, default) # for pickling: def_mod = _caller() - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod def __repr__(self): if self.__infer_variance__: - prefix = '' + prefix = "" elif self.__covariant__: - prefix = '+' + prefix = "+" elif self.__contravariant__: - prefix = '-' + prefix = "-" else: - prefix = '~' + prefix = "~" return prefix + self.__name__ def __hash__(self): @@ -1759,7 +1862,7 @@ def __call__(self, *args, **kwargs): # 3.8-3.9 -if not hasattr(typing, 'Concatenate'): +if not hasattr(typing, "Concatenate"): # Inherits from list as a workaround for Callable checks in Python < 3.9.2. class _ConcatenateGenericAlias(list): @@ -1776,8 +1879,10 @@ def __init__(self, origin, args): def __repr__(self): _type_repr = typing._type_repr - return (f'{_type_repr(self.__origin__)}' - f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]') + return ( + f"{_type_repr(self.__origin__)}" + f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]' + ) def __hash__(self): return hash((self.__origin__, self.__args__)) @@ -1789,7 +1894,9 @@ def __call__(self, *args, **kwargs): @property def __parameters__(self): return tuple( - tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) + tp + for tp in self.__args__ + if isinstance(tp, (typing.TypeVar, ParamSpec)) ) @@ -1801,19 +1908,21 @@ def _concatenate_getitem(self, parameters): if not isinstance(parameters, tuple): parameters = (parameters,) if not isinstance(parameters[-1], ParamSpec): - raise TypeError("The last parameter to Concatenate should be a " - "ParamSpec variable.") + raise TypeError( + "The last parameter to Concatenate should be a " "ParamSpec variable." + ) msg = "Concatenate[arg, ...]: each arg must be a type." parameters = tuple(typing._type_check(p, msg) for p in parameters) return _ConcatenateGenericAlias(self, parameters) # 3.10+ -if hasattr(typing, 'Concatenate'): +if hasattr(typing, "Concatenate"): Concatenate = typing.Concatenate _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # 3.9 elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm def Concatenate(self, parameters): """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a @@ -1827,14 +1936,17 @@ def Concatenate(self, parameters): See PEP 612 for detailed information. 
""" return _concatenate_getitem(self, parameters) + + # 3.8 else: + class _ConcatenateForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): return _concatenate_getitem(self, parameters) Concatenate = _ConcatenateForm( - 'Concatenate', + "Concatenate", doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a higher order function which adds, removes or transforms parameters of a callable. @@ -1844,13 +1956,15 @@ def __getitem__(self, parameters): Callable[Concatenate[int, P], int] See PEP 612 for detailed information. - """) + """, + ) # 3.10+ -if hasattr(typing, 'TypeGuard'): +if hasattr(typing, "TypeGuard"): TypeGuard = typing.TypeGuard # 3.9 elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm def TypeGuard(self, parameters): """Special typing form used to annotate the return type of a user-defined @@ -1895,18 +2009,22 @@ def is_str(val: Union[str, float]): ``TypeGuard`` also works with type variables. For more information, see PEP 647 (User-Defined Type Guards). """ - item = typing._type_check(parameters, f'{self} accepts only a single type.') + item = typing._type_check(parameters, f"{self} accepts only a single type.") return typing._GenericAlias(self, (item,)) + + # 3.8 else: + class _TypeGuardForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type" + ) return typing._GenericAlias(self, (item,)) TypeGuard = _TypeGuardForm( - 'TypeGuard', + "TypeGuard", doc="""Special typing form used to annotate the return type of a user-defined type guard function. ``TypeGuard`` only accepts a single type argument. At runtime, functions marked this way should return a boolean. @@ -1948,13 +2066,15 @@ def is_str(val: Union[str, float]): ``TypeGuard`` also works with type variables. For more information, see PEP 647 (User-Defined Type Guards). - """) + """, + ) # 3.13+ -if hasattr(typing, 'TypeIs'): +if hasattr(typing, "TypeIs"): TypeIs = typing.TypeIs # 3.9 elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm def TypeIs(self, parameters): """Special typing form used to annotate the return type of a user-defined @@ -1993,18 +2113,22 @@ def f(val: Union[int, Awaitable[int]]) -> int: ``TypeIs`` also works with type variables. For more information, see PEP 742 (Narrowing types with TypeIs). """ - item = typing._type_check(parameters, f'{self} accepts only a single type.') + item = typing._type_check(parameters, f"{self} accepts only a single type.") return typing._GenericAlias(self, (item,)) + + # 3.8 else: + class _TypeIsForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type" + ) return typing._GenericAlias(self, (item,)) TypeIs = _TypeIsForm( - 'TypeIs', + "TypeIs", doc="""Special typing form used to annotate the return type of a user-defined type narrower function. ``TypeIs`` only accepts a single type argument. At runtime, functions marked this way should return a boolean. @@ -2040,12 +2164,13 @@ def f(val: Union[int, Awaitable[int]]) -> int: ``TypeIs`` also works with type variables. For more information, see PEP 742 (Narrowing types with TypeIs). 
- """) + """, + ) # Vendored from cpython typing._SpecialFrom class _SpecialForm(typing._Final, _root=True): - __slots__ = ('_name', '__doc__', '_getitem') + __slots__ = ("_name", "__doc__", "_getitem") def __init__(self, getitem): self._getitem = getitem @@ -2053,7 +2178,7 @@ def __init__(self, getitem): self.__doc__ = getitem.__doc__ def __getattr__(self, item): - if item in {'__name__', '__qualname__'}: + if item in {"__name__", "__qualname__"}: return self._name raise AttributeError(item) @@ -2062,7 +2187,7 @@ def __mro_entries__(self, bases): raise TypeError(f"Cannot subclass {self!r}") def __repr__(self): - return f'typing_extensions.{self._name}' + return f"typing_extensions.{self._name}" def __reduce__(self): return self._name @@ -2090,6 +2215,7 @@ def __getitem__(self, parameters): if hasattr(typing, "LiteralString"): # 3.11+ LiteralString = typing.LiteralString else: + @_SpecialForm def LiteralString(self, params): """Represents an arbitrary literal string. @@ -2113,6 +2239,7 @@ def query(sql: LiteralString) -> ...: if hasattr(typing, "Self"): # 3.11+ Self = typing.Self else: + @_SpecialForm def Self(self, params): """Used to spell the type of "self" in classes. @@ -2134,6 +2261,7 @@ def parse(self, data: bytes) -> Self: if hasattr(typing, "Never"): # 3.11+ Never = typing.Never else: + @_SpecialForm def Never(self, params): """The bottom type, a type that has no members. @@ -2161,10 +2289,11 @@ def int_or_str(arg: int | str) -> None: raise TypeError(f"{self} is not subscriptable") -if hasattr(typing, 'Required'): # 3.11+ +if hasattr(typing, "Required"): # 3.11+ Required = typing.Required NotRequired = typing.NotRequired elif sys.version_info[:2] >= (3, 9): # 3.9-3.10 + @_ExtensionsSpecialForm def Required(self, parameters): """A special typing construct to mark a key of a total=False TypedDict @@ -2182,7 +2311,9 @@ class Movie(TypedDict, total=False): There is no runtime checking that a required key is actually provided when instantiating a related TypedDict. """ - item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return typing._GenericAlias(self, (item,)) @_ExtensionsSpecialForm @@ -2199,18 +2330,22 @@ class Movie(TypedDict): year=1999, ) """ - item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return typing._GenericAlias(self, (item,)) else: # 3.8 + class _RequiredForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return typing._GenericAlias(self, (item,)) Required = _RequiredForm( - 'Required', + "Required", doc="""A special typing construct to mark a key of a total=False TypedDict as required. For example: @@ -2225,9 +2360,10 @@ class Movie(TypedDict, total=False): There is no runtime checking that a required key is actually provided when instantiating a related TypedDict. - """) + """, + ) NotRequired = _RequiredForm( - 'NotRequired', + "NotRequired", doc="""A special typing construct to mark a key of a TypedDict as potentially missing. 
For example:

@@ -2239,12 +2375,14 @@ class Movie(TypedDict):
title='The Matrix',  # typechecker error if key is omitted
year=1999,
)
- """)
+ """,
+ )


-if hasattr(typing, 'ReadOnly'):
+if hasattr(typing, "ReadOnly"):
ReadOnly = typing.ReadOnly
elif sys.version_info[:2] >= (3, 9):  # 3.9-3.12
+
@_ExtensionsSpecialForm
def ReadOnly(self, parameters):
"""A special typing construct to mark an item of a TypedDict as read-only.
@@ -2261,18 +2399,22 @@ def mutate_movie(m: Movie) -> None:

There is no runtime checking for this property.
"""
- item = typing._type_check(parameters, f'{self._name} accepts only a single type.')
+ item = typing._type_check(
+ parameters, f"{self._name} accepts only a single type."
+ )
return typing._GenericAlias(self, (item,))

else:  # 3.8
+
class _ReadOnlyForm(_ExtensionsSpecialForm, _root=True):
def __getitem__(self, parameters):
- item = typing._type_check(parameters,
- f'{self._name} accepts only a single type.')
+ item = typing._type_check(
+ parameters, f"{self._name} accepts only a single type."
+ )
return typing._GenericAlias(self, (item,))

ReadOnly = _ReadOnlyForm(
- 'ReadOnly',
+ "ReadOnly",
doc="""A special typing construct to mark a key of a TypedDict as read-only.

For example:
@@ -2286,7 +2428,8 @@ def mutate_movie(m: Movie) -> None:
m["title"] = "The Matrix"  # typechecker error

There is no runtime checking for this property.
- """)
+ """,
+ )


_UNPACK_DOC = """\
@@ -2338,6 +2481,7 @@ def _is_unpack(obj):
return get_origin(obj) is Unpack

elif sys.version_info[:2] >= (3, 9):  # 3.9+
+
class _UnpackSpecialForm(_ExtensionsSpecialForm, _root=True):
def __init__(self, getitem):
super().__init__(getitem)
@@ -2350,7 +2494,7 @@ class _UnpackAlias(typing._GenericAlias, _root=True):
def __typing_unpacked_tuple_args__(self):
assert self.__origin__ is Unpack
assert len(self.__args__) == 1
- arg, = self.__args__
+ (arg,) = self.__args__
if isinstance(arg, (typing._GenericAlias, _types.GenericAlias)):
if arg.__origin__ is not tuple:
raise TypeError("Unpack[...] must be used with a tuple type")
@@ -2359,23 +2503,27 @@ def __typing_unpacked_tuple_args__(self):

@_UnpackSpecialForm
def Unpack(self, parameters):
- item = typing._type_check(parameters, f'{self._name} accepts only a single type.')
+ item = typing._type_check(
+ parameters, f"{self._name} accepts only a single type."
+ )
return _UnpackAlias(self, (item,))

def _is_unpack(obj):
return isinstance(obj, _UnpackAlias)

else:  # 3.8
+
class _UnpackAlias(typing._GenericAlias, _root=True):
__class__ = typing.TypeVar

class _UnpackForm(_ExtensionsSpecialForm, _root=True):
def __getitem__(self, parameters):
- item = typing._type_check(parameters,
- f'{self._name} accepts only a single type.')
+ item = typing._type_check(
+ parameters, f"{self._name} accepts only a single type." 
+ ) return _UnpackAlias(self, (item,)) - Unpack = _UnpackForm('Unpack', doc=_UNPACK_DOC) + Unpack = _UnpackForm("Unpack", doc=_UNPACK_DOC) def _is_unpack(obj): return isinstance(obj, _UnpackAlias) @@ -2389,7 +2537,7 @@ def _is_unpack(obj): def _unpack_args(*args): newargs = [] for arg in args: - subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) + subargs = getattr(arg, "__typing_unpacked_tuple_args__", None) if subargs is not None and not (subargs and subargs[-1] is ...): newargs.extend(subargs) else: @@ -2410,7 +2558,7 @@ def __new__(cls, name, *, default=NoDefault): def _typevartuple_prepare_subst(alias, args): params = alias.__parameters__ typevartuple_index = params.index(tvt) - for param in params[typevartuple_index + 1:]: + for param in params[typevartuple_index + 1 :]: if isinstance(param, TypeVarTuple): raise TypeError( f"More than one TypeVarTuple parameter in {alias}" @@ -2424,7 +2572,7 @@ def _typevartuple_prepare_subst(alias, args): fillarg = None for k, arg in enumerate(args): if not isinstance(arg, type): - subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) + subargs = getattr(arg, "__typing_unpacked_tuple_args__", None) if subargs and len(subargs) == 2 and subargs[-1] is ...: if var_tuple_index is not None: raise TypeError( @@ -2437,19 +2585,21 @@ def _typevartuple_prepare_subst(alias, args): left = min(left, var_tuple_index) right = min(right, alen - var_tuple_index - 1) elif left + right > alen: - raise TypeError(f"Too few arguments for {alias};" - f" actual {alen}, expected at least {plen - 1}") + raise TypeError( + f"Too few arguments for {alias};" + f" actual {alen}, expected at least {plen - 1}" + ) if left == alen - right and tvt.has_default(): replacement = _unpack_args(tvt.__default__) else: - replacement = args[left: alen - right] + replacement = args[left : alen - right] return ( *args[:left], *([fillarg] * (typevartuple_index - left)), replacement, *([fillarg] * (plen - right - left - typevartuple_index - 1)), - *args[alen - right:], + *args[alen - right :], ) tvt.__typing_prepare_subst__ = _typevartuple_prepare_subst @@ -2459,6 +2609,7 @@ def __init_subclass__(self, *args, **kwds): raise TypeError("Cannot subclass special typing classes") else: # <=3.10 + class TypeVarTuple(_DefaultMixin): """Type variable tuple. @@ -2515,7 +2666,7 @@ def __init__(self, name, *, default=NoDefault): # for pickling: def_mod = _caller() - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod self.__unpacked__ = Unpack[self] @@ -2533,13 +2684,14 @@ def __reduce__(self): return self.__name__ def __init_subclass__(self, *args, **kwds): - if '_root' not in kwds: + if "_root" not in kwds: raise TypeError("Cannot subclass special typing classes") if hasattr(typing, "reveal_type"): # 3.11+ reveal_type = typing.reveal_type else: # <=3.10 + def reveal_type(obj: T, /) -> T: """Reveal the inferred type of a variable. @@ -2569,6 +2721,7 @@ def reveal_type(obj: T, /) -> T: if hasattr(typing, "assert_never"): # 3.11+ assert_never = typing.assert_never else: # <=3.10 + def assert_never(arg: Never, /) -> Never: """Assert to the type checker that a line of code is unreachable. @@ -2591,7 +2744,7 @@ def int_or_str(arg: int | str) -> None: """ value = repr(arg) if len(value) > _ASSERT_NEVER_REPR_MAX_LENGTH: - value = value[:_ASSERT_NEVER_REPR_MAX_LENGTH] + '...' + value = value[:_ASSERT_NEVER_REPR_MAX_LENGTH] + "..." 
raise AssertionError(f"Expected code to be unreachable, but got: {value}") @@ -2599,6 +2752,7 @@ def int_or_str(arg: int | str) -> None: # dataclass_transform exists in 3.11 but lacks the frozen_default parameter dataclass_transform = typing.dataclass_transform else: # <=3.11 + def dataclass_transform( *, eq_default: bool = True, @@ -2606,8 +2760,7 @@ def dataclass_transform( kw_only_default: bool = False, frozen_default: bool = False, field_specifiers: typing.Tuple[ - typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], - ... + typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], ... ] = (), **kwargs: typing.Any, ) -> typing.Callable[[T], T]: @@ -2672,6 +2825,7 @@ class CustomerModel(ModelBase): See PEP 681 for details. """ + def decorator(cls_or_fn): cls_or_fn.__dataclass_transform__ = { "eq_default": eq_default, @@ -2682,6 +2836,7 @@ def decorator(cls_or_fn): "kwargs": kwargs, } return cls_or_fn + return decorator @@ -2773,6 +2928,7 @@ def g(x: str) -> int: ... See PEP 702 for details. """ + def __init__( self, message: str, @@ -2834,6 +2990,7 @@ def __init_subclass__(*args, **kwargs): # Or otherwise, which likely means it's a builtin such as # object's implementation of __init_subclass__. else: + @functools.wraps(original_init_subclass) def __init_subclass__(*args, **kwargs): warnings.warn(msg, category=category, stacklevel=stacklevel + 1) @@ -2869,6 +3026,7 @@ def wrapper(*args, **kwargs): # counting generic parameters, so that when we subscript a generic, # the runtime doesn't try to substitute the Unpack with the subscripted type. if not hasattr(typing, "TypeVarTuple"): + def _check_generic(cls, parameters, elen=_marker): """Check correct count for parameters of a generic cls (internal helper). @@ -2895,21 +3053,26 @@ def _check_generic(cls, parameters, elen=_marker): # since we validate TypeVarLike default in _collect_type_vars # or _collect_parameters we can safely check parameters[alen] if ( - getattr(parameters[alen], '__default__', NoDefault) + getattr(parameters[alen], "__default__", NoDefault) is not NoDefault ): return - num_default_tv = sum(getattr(p, '__default__', NoDefault) - is not NoDefault for p in parameters) + num_default_tv = sum( + getattr(p, "__default__", NoDefault) is not NoDefault + for p in parameters + ) elen -= num_default_tv expect_val = f"at least {elen}" things = "arguments" if sys.version_info >= (3, 10) else "parameters" - raise TypeError(f"Too {'many' if alen > elen else 'few'} {things}" - f" for {cls}; actual {alen}, expected {expect_val}") + raise TypeError( + f"Too {'many' if alen > elen else 'few'} {things}" + f" for {cls}; actual {alen}, expected {expect_val}" + ) + else: # Python 3.11+ @@ -2932,20 +3095,25 @@ def _check_generic(cls, parameters, elen): # since we validate TypeVarLike default in _collect_type_vars # or _collect_parameters we can safely check parameters[alen] if ( - getattr(parameters[alen], '__default__', NoDefault) + getattr(parameters[alen], "__default__", NoDefault) is not NoDefault ): return - num_default_tv = sum(getattr(p, '__default__', NoDefault) - is not NoDefault for p in parameters) + num_default_tv = sum( + getattr(p, "__default__", NoDefault) is not NoDefault + for p in parameters + ) elen -= num_default_tv expect_val = f"at least {elen}" - raise TypeError(f"Too {'many' if alen > elen else 'few'} arguments" - f" for {cls}; actual {alen}, expected {expect_val}") + raise TypeError( + f"Too {'many' if alen > elen else 'few'} arguments" + f" for {cls}; actual {alen}, expected 
{expect_val}" + ) + if not _PEP_696_IMPLEMENTED: typing._check_generic = _check_generic @@ -2967,7 +3135,9 @@ def _has_generic_or_protocol_as_origin() -> bool: origin = frame.f_locals.get("origin") # Cannot use "in" because origin may be an object with a buggy __eq__ that # throws an error. - return origin is typing.Generic or origin is Protocol or origin is typing.Protocol + return ( + origin is typing.Generic or origin is Protocol or origin is typing.Protocol + ) _TYPEVARTUPLE_TYPES = {TypeVarTuple, getattr(typing, "TypeVarTuple", None)} @@ -2977,15 +3147,12 @@ def _is_unpacked_typevartuple(x) -> bool: if get_origin(x) is not Unpack: return False args = get_args(x) - return ( - bool(args) - and len(args) == 1 - and type(args[0]) in _TYPEVARTUPLE_TYPES - ) + return bool(args) and len(args) == 1 and type(args[0]) in _TYPEVARTUPLE_TYPES # Python 3.11+ _collect_type_vars was renamed to _collect_parameters -if hasattr(typing, '_collect_type_vars'): +if hasattr(typing, "_collect_type_vars"): + def _collect_type_vars(types, typevar_types=None): """Collect all type variable contained in types in order of first appearance (lexicographic order). For example:: @@ -3009,15 +3176,18 @@ def _collect_type_vars(types, typevar_types=None): type_var_tuple_encountered = True elif isinstance(t, typevar_types) and t not in tvars: if enforce_default_ordering: - has_default = getattr(t, '__default__', NoDefault) is not NoDefault + has_default = getattr(t, "__default__", NoDefault) is not NoDefault if has_default: if type_var_tuple_encountered: - raise TypeError('Type parameter with a default' - ' follows TypeVarTuple') + raise TypeError( + "Type parameter with a default" " follows TypeVarTuple" + ) default_encountered = True elif default_encountered: - raise TypeError(f'Type parameter {t!r} without a default' - ' follows type parameter with a default') + raise TypeError( + f"Type parameter {t!r} without a default" + " follows type parameter with a default" + ) tvars.append(t) if _should_collect_from_parameters(t): @@ -3026,6 +3196,7 @@ def _collect_type_vars(types, typevar_types=None): typing._collect_type_vars = _collect_type_vars else: + def _collect_parameters(args): """Collect all type variables and parameter specifications in args in order of first appearance (lexicographic order). 
@@ -3055,28 +3226,31 @@ def _collect_parameters(args): for collected in _collect_parameters([x]): if collected not in parameters: parameters.append(collected) - elif hasattr(t, '__typing_subst__'): + elif hasattr(t, "__typing_subst__"): if t not in parameters: if enforce_default_ordering: has_default = ( - getattr(t, '__default__', NoDefault) is not NoDefault + getattr(t, "__default__", NoDefault) is not NoDefault ) if type_var_tuple_encountered and has_default: - raise TypeError('Type parameter with a default' - ' follows TypeVarTuple') + raise TypeError( + "Type parameter with a default" " follows TypeVarTuple" + ) if has_default: default_encountered = True elif default_encountered: - raise TypeError(f'Type parameter {t!r} without a default' - ' follows type parameter with a default') + raise TypeError( + f"Type parameter {t!r} without a default" + " follows type parameter with a default" + ) parameters.append(t) else: if _is_unpacked_typevartuple(t): type_var_tuple_encountered = True - for x in getattr(t, '__parameters__', ()): + for x in getattr(t, "__parameters__", ()): if x not in parameters: parameters.append(x) @@ -3093,12 +3267,14 @@ def _collect_parameters(args): if sys.version_info >= (3, 13): NamedTuple = typing.NamedTuple else: + def _make_nmtuple(name, types, module, defaults=()): fields = [n for n, t in types] - annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") - for n, t in types} - nm_tpl = collections.namedtuple(name, fields, - defaults=defaults, module=module) + annotations = { + n: typing._type_check(t, f"field {n} annotation must be a type") + for n, t in types + } + nm_tpl = collections.namedtuple(name, fields, defaults=defaults, module=module) nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = annotations # The `_field_types` attribute was removed in 3.9; # in earlier versions, it is the same as the `__annotations__` attribute @@ -3107,7 +3283,9 @@ def _make_nmtuple(name, types, module, defaults=()): return nm_tpl _prohibited_namedtuple_fields = typing._prohibited - _special_namedtuple_fields = frozenset({'__module__', '__name__', '__annotations__'}) + _special_namedtuple_fields = frozenset( + {"__module__", "__name__", "__annotations__"} + ) class _NamedTupleMeta(type): def __new__(cls, typename, bases, ns): @@ -3115,7 +3293,8 @@ def __new__(cls, typename, bases, ns): for base in bases: if base is not _NamedTuple and base is not typing.Generic: raise TypeError( - 'can only inherit from a NamedTuple type and Generic') + "can only inherit from a NamedTuple type and Generic" + ) bases = tuple(tuple if base is _NamedTuple else base for base in bases) if "__annotations__" in ns: types = ns["__annotations__"] @@ -3129,19 +3308,24 @@ def __new__(cls, typename, bases, ns): if field_name in ns: default_names.append(field_name) elif default_names: - raise TypeError(f"Non-default namedtuple field {field_name} " - f"cannot follow default field" - f"{'s' if len(default_names) > 1 else ''} " - f"{', '.join(default_names)}") + raise TypeError( + f"Non-default namedtuple field {field_name} " + f"cannot follow default field" + f"{'s' if len(default_names) > 1 else ''} " + f"{', '.join(default_names)}" + ) nm_tpl = _make_nmtuple( - typename, types.items(), + typename, + types.items(), defaults=[ns[n] for n in default_names], - module=ns['__module__'] + module=ns["__module__"], ) nm_tpl.__bases__ = bases if typing.Generic in bases: - if hasattr(typing, '_generic_class_getitem'): # 3.12+ - nm_tpl.__class_getitem__ = 
classmethod(typing._generic_class_getitem) + if hasattr(typing, "_generic_class_getitem"): # 3.12+ + nm_tpl.__class_getitem__ = classmethod( + typing._generic_class_getitem + ) else: class_getitem = typing.Generic.__class_getitem__.__func__ nm_tpl.__class_getitem__ = classmethod(class_getitem) @@ -3179,7 +3363,7 @@ def __new__(cls, typename, bases, ns): nm_tpl.__init_subclass__() return nm_tpl - _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) + _NamedTuple = type.__new__(_NamedTupleMeta, "NamedTuple", (), {}) def _namedtuple_mro_entries(bases): assert NamedTuple in bases @@ -3217,11 +3401,15 @@ class Employee(NamedTuple): deprecated_thing = "Failing to pass a value for the 'fields' parameter" example = f"`{typename} = NamedTuple({typename!r}, [])`" deprecation_msg = ( - "{name} is deprecated and will be disallowed in Python {remove}. " - "To create a NamedTuple class with 0 fields " - "using the functional syntax, " - "pass an empty list, e.g. " - ) + example + "." + ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + + example + + "." + ) elif fields is None: if kwargs: raise TypeError( @@ -3232,14 +3420,20 @@ class Employee(NamedTuple): deprecated_thing = "Passing `None` as the 'fields' parameter" example = f"`{typename} = NamedTuple({typename!r}, [])`" deprecation_msg = ( - "{name} is deprecated and will be disallowed in Python {remove}. " - "To create a NamedTuple class with 0 fields " - "using the functional syntax, " - "pass an empty list, e.g. " - ) + example + "." + ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + + example + + "." + ) elif kwargs: - raise TypeError("Either list of fields or keywords" - " can be provided to NamedTuple, not both") + raise TypeError( + "Either list of fields or keywords" + " can be provided to NamedTuple, not both" + ) if fields is _marker or fields is None: warnings.warn( deprecation_msg.format(name=deprecated_thing, remove="3.15"), @@ -3255,6 +3449,7 @@ class Employee(NamedTuple): if hasattr(collections.abc, "Buffer"): Buffer = collections.abc.Buffer else: + class Buffer(abc.ABC): # noqa: B024 """Base class for classes that implement the buffer protocol. @@ -3285,6 +3480,7 @@ class Buffer(abc.ABC): # noqa: B024 if hasattr(_types, "get_original_bases"): get_original_bases = _types.get_original_bases else: + def get_original_bases(cls, /): """Return the class's "original" bases prior to modification by `__mro_entries__`. @@ -3310,7 +3506,7 @@ class Baz(list[str]): ... return cls.__dict__.get("__orig_bases__", cls.__bases__) except AttributeError: raise TypeError( - f'Expected an instance of type, not {type(cls).__name__!r}' + f"Expected an instance of type, not {type(cls).__name__!r}" ) from None @@ -3319,6 +3515,7 @@ class Baz(list[str]): ... if sys.version_info >= (3, 11): NewType = typing.NewType else: + class NewType: """NewType creates simple unique types with almost zero runtime overhead. NewType(name, tp) is considered a subtype of tp @@ -3338,12 +3535,12 @@ def __call__(self, obj, /): def __init__(self, name, tp): self.__qualname__ = name - if '.' in name: - name = name.rpartition('.')[-1] + if "." 
in name: + name = name.rpartition(".")[-1] self.__name__ = name self.__supertype__ = tp def_mod = _caller() - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod def __mro_entries__(self, bases): @@ -3363,7 +3560,7 @@ def __init_subclass__(cls): return (Dummy,) def __repr__(self): - return f'{self.__module__}.{self.__qualname__}' + return f"{self.__module__}.{self.__qualname__}" def __reduce__(self): return self.__qualname__ @@ -3382,14 +3579,18 @@ def __ror__(self, other): if hasattr(typing, "TypeAliasType"): TypeAliasType = typing.TypeAliasType else: + def _is_unionable(obj): """Corresponds to is_unionable() in unionobject.c in CPython.""" - return obj is None or isinstance(obj, ( - type, - _types.GenericAlias, - _types.UnionType, - TypeAliasType, - )) + return obj is None or isinstance( + obj, + ( + type, + _types.GenericAlias, + _types.UnionType, + TypeAliasType, + ), + ) class TypeAliasType: """Create named, parameterized type aliases. @@ -3433,7 +3634,7 @@ def __init__(self, name: str, value, *, type_params=()): parameters.append(type_param) self.__parameters__ = tuple(parameters) def_mod = _caller() - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod # Setting this attribute closes the TypeAliasType from further modification self.__name__ = name @@ -3450,7 +3651,12 @@ def _raise_attribute_error(self, name: str) -> Never: # Match the Python 3.12 error messages exactly if name == "__name__": raise AttributeError("readonly attribute") - elif name in {"__value__", "__type_params__", "__parameters__", "__module__"}: + elif name in { + "__value__", + "__type_params__", + "__parameters__", + "__module__", + }: raise AttributeError( f"attribute '{name}' of 'typing.TypeAliasType' objects " "is not writable" @@ -3468,7 +3674,7 @@ def __getitem__(self, parameters): parameters = (parameters,) parameters = [ typing._type_check( - item, f'Subscripting {self.__name__} requires a type.' + item, f"Subscripting {self.__name__} requires a type." ) for item in parameters ] @@ -3488,6 +3694,7 @@ def __call__(self): raise TypeError("Type alias is not callable") if sys.version_info >= (3, 10): + def __or__(self, right): # For forward compatibility with 3.12, reject Unions # that are not accepted by the built-in Union. @@ -3505,6 +3712,7 @@ def __ror__(self, left): is_protocol = typing.is_protocol get_protocol_members = typing.get_protocol_members else: + def is_protocol(tp: type, /) -> bool: """Return True if the given type is a Protocol. @@ -3521,7 +3729,7 @@ def is_protocol(tp: type, /) -> bool: """ return ( isinstance(tp, type) - and getattr(tp, '_is_protocol', False) + and getattr(tp, "_is_protocol", False) and tp is not Protocol and tp is not typing.Protocol ) @@ -3541,8 +3749,8 @@ def get_protocol_members(tp: type, /) -> typing.FrozenSet[str]: Raise a TypeError for arguments that are not Protocols. 
""" if not is_protocol(tp): - raise TypeError(f'{tp!r} is not a Protocol') - if hasattr(tp, '__protocol_attrs__'): + raise TypeError(f"{tp!r} is not a Protocol") + if hasattr(tp, "__protocol_attrs__"): return frozenset(tp.__protocol_attrs__) return frozenset(_get_protocol_attrs(tp)) @@ -3550,6 +3758,7 @@ def get_protocol_members(tp: type, /) -> typing.FrozenSet[str]: if hasattr(typing, "Doc"): Doc = typing.Doc else: + class Doc: """Define the documentation of a type annotation using ``Annotated``, to be used in class attributes, function and method parameters, return values, @@ -3567,6 +3776,7 @@ class Doc: >>> from typing_extensions import Annotated, Doc >>> def hi(to: Annotated[str, Doc("Who to say hi to")]) -> None: ... """ + def __init__(self, documentation: str, /) -> None: self.documentation = documentation diff --git a/metaflow/_vendor/v3_6/__init__.py b/metaflow/_vendor/v3_6/__init__.py index 22ae0c5f40e..932b79829cf 100644 --- a/metaflow/_vendor/v3_6/__init__.py +++ b/metaflow/_vendor/v3_6/__init__.py @@ -1 +1 @@ -# Empty file \ No newline at end of file +# Empty file diff --git a/metaflow/_vendor/v3_6/importlib_metadata/__init__.py b/metaflow/_vendor/v3_6/importlib_metadata/__init__.py index 8d3b7814d50..10a743115e4 100644 --- a/metaflow/_vendor/v3_6/importlib_metadata/__init__.py +++ b/metaflow/_vendor/v3_6/importlib_metadata/__init__.py @@ -33,18 +33,18 @@ __all__ = [ - 'Distribution', - 'DistributionFinder', - 'PackageMetadata', - 'PackageNotFoundError', - 'distribution', - 'distributions', - 'entry_points', - 'files', - 'metadata', - 'packages_distributions', - 'requires', - 'version', + "Distribution", + "DistributionFinder", + "PackageMetadata", + "PackageNotFoundError", + "distribution", + "distributions", + "entry_points", + "files", + "metadata", + "packages_distributions", + "requires", + "version", ] @@ -114,15 +114,15 @@ def read(text, filter_=None): lines = filter(filter_, map(str.strip, text.splitlines())) name = None for value in lines: - section_match = value.startswith('[') and value.endswith(']') + section_match = value.startswith("[") and value.endswith("]") if section_match: - name = value.strip('[]') + name = value.strip("[]") continue yield Pair(name, value) @staticmethod def valid(line): - return line and not line.startswith('#') + return line and not line.startswith("#") class DeprecatedTuple: @@ -160,9 +160,9 @@ class EntryPoint(DeprecatedTuple): """ pattern = re.compile( - r'(?P[\w.]+)\s*' - r'(:\s*(?P[\w.]+))?\s*' - r'(?P\[.*\])?\s*$' + r"(?P[\w.]+)\s*" + r"(:\s*(?P[\w.]+))?\s*" + r"(?P\[.*\])?\s*$" ) """ A regular expression describing the syntax for an entry point, @@ -180,7 +180,7 @@ class EntryPoint(DeprecatedTuple): following the attr, and following any extras. """ - dist: Optional['Distribution'] = None + dist: Optional["Distribution"] = None def __init__(self, name, value, group): vars(self).update(name=name, value=value, group=group) @@ -191,24 +191,24 @@ def load(self): return the named object. 
""" match = self.pattern.match(self.value) - module = import_module(match.group('module')) - attrs = filter(None, (match.group('attr') or '').split('.')) + module = import_module(match.group("module")) + attrs = filter(None, (match.group("attr") or "").split(".")) return functools.reduce(getattr, attrs, module) @property def module(self): match = self.pattern.match(self.value) - return match.group('module') + return match.group("module") @property def attr(self): match = self.pattern.match(self.value) - return match.group('attr') + return match.group("attr") @property def extras(self): match = self.pattern.match(self.value) - return list(re.finditer(r'\w+', match.group('extras') or '')) + return list(re.finditer(r"\w+", match.group("extras") or "")) def _for(self, dist): vars(self).update(dist=dist) @@ -243,8 +243,8 @@ def __setattr__(self, name, value): def __repr__(self): return ( - f'EntryPoint(name={self.name!r}, value={self.value!r}, ' - f'group={self.group!r})' + f"EntryPoint(name={self.name!r}, value={self.value!r}, " + f"group={self.group!r})" ) def __hash__(self): @@ -298,16 +298,16 @@ def wrapped(self, *args, **kwargs): return wrapped for method_name in [ - '__setitem__', - '__delitem__', - 'append', - 'reverse', - 'extend', - 'pop', - 'remove', - '__iadd__', - 'insert', - 'sort', + "__setitem__", + "__delitem__", + "append", + "reverse", + "extend", + "pop", + "remove", + "__iadd__", + "insert", + "sort", ]: locals()[method_name] = _wrap_deprecated_method(method_name) @@ -382,7 +382,7 @@ def _from_text_for(cls, text, dist): def _from_text(text): return ( EntryPoint(name=item.value.name, value=item.value.value, group=item.name) - for item in Sectioned.section_pairs(text or '') + for item in Sectioned.section_pairs(text or "") ) @@ -449,7 +449,7 @@ class SelectableGroups(Deprecated, dict): @classmethod def load(cls, eps): - by_group = operator.attrgetter('group') + by_group = operator.attrgetter("group") ordered = sorted(eps, key=by_group) grouped = itertools.groupby(ordered, by_group) return cls((group, EntryPoints(eps)) for group, eps in grouped) @@ -484,12 +484,12 @@ def select(self, **params): class PackagePath(pathlib.PurePosixPath): """A reference to a path in a package""" - def read_text(self, encoding='utf-8'): + def read_text(self, encoding="utf-8"): with self.locate().open(encoding=encoding) as stream: return stream.read() def read_binary(self): - with self.locate().open('rb') as stream: + with self.locate().open("rb") as stream: return stream.read() def locate(self): @@ -499,10 +499,10 @@ def locate(self): class FileHash: def __init__(self, spec): - self.mode, _, self.value = spec.partition('=') + self.mode, _, self.value = spec.partition("=") def __repr__(self): - return f'' + return f"" class Distribution: @@ -551,7 +551,7 @@ def discover(cls, **kwargs): :context: A ``DistributionFinder.Context`` object. :return: Iterable of Distribution objects for all packages. 
""" - context = kwargs.pop('context', None) + context = kwargs.pop("context", None) if context and kwargs: raise ValueError("cannot accept context and kwargs") context = context or DistributionFinder.Context(**kwargs) @@ -572,12 +572,12 @@ def at(path): def _discover_resolvers(): """Search the meta_path for resolvers.""" declared = ( - getattr(finder, 'find_distributions', None) for finder in sys.meta_path + getattr(finder, "find_distributions", None) for finder in sys.meta_path ) return filter(None, declared) @classmethod - def _local(cls, root='.'): + def _local(cls, root="."): from pep517 import build, meta system = build.compat_system(root) @@ -596,19 +596,19 @@ def metadata(self) -> _meta.PackageMetadata: metadata. See PEP 566 for details. """ text = ( - self.read_text('METADATA') - or self.read_text('PKG-INFO') + self.read_text("METADATA") + or self.read_text("PKG-INFO") # This last clause is here to support old egg-info files. Its # effect is to just end up using the PathDistribution's self._path # (which points to the egg-info file) attribute unchanged. - or self.read_text('') + or self.read_text("") ) return _adapters.Message(email.message_from_string(text)) @property def name(self): """Return the 'Name' metadata for the distribution package.""" - return self.metadata['Name'] + return self.metadata["Name"] @property def _normalized_name(self): @@ -618,11 +618,11 @@ def _normalized_name(self): @property def version(self): """Return the 'Version' metadata for the distribution package.""" - return self.metadata['Version'] + return self.metadata["Version"] @property def entry_points(self): - return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self) + return EntryPoints._from_text_for(self.read_text("entry_points.txt"), self) @property def files(self): @@ -653,7 +653,7 @@ def _read_files_distinfo(self): """ Read the lines of RECORD """ - text = self.read_text('RECORD') + text = self.read_text("RECORD") return text and text.splitlines() def _read_files_egginfo(self): @@ -661,7 +661,7 @@ def _read_files_egginfo(self): SOURCES.txt might contain literal commas, so wrap each line in quotes. """ - text = self.read_text('SOURCES.txt') + text = self.read_text("SOURCES.txt") return text and map('"{}"'.format, text.splitlines()) @property @@ -671,10 +671,10 @@ def requires(self): return reqs and list(reqs) def _read_dist_info_reqs(self): - return self.metadata.get_all('Requires-Dist') + return self.metadata.get_all("Requires-Dist") def _read_egg_info_reqs(self): - source = self.read_text('requires.txt') + source = self.read_text("requires.txt") return source and self._deps_from_requires_text(source) @classmethod @@ -697,12 +697,12 @@ def make_condition(name): return name and f'extra == "{name}"' def quoted_marker(section): - section = section or '' - extra, sep, markers = section.partition(':') + section = section or "" + extra, sep, markers = section.partition(":") if extra and markers: - markers = f'({markers})' + markers = f"({markers})" conditions = list(filter(None, [markers, make_condition(extra)])) - return '; ' + ' and '.join(conditions) if conditions else '' + return "; " + " and ".join(conditions) if conditions else "" def url_req_space(req): """ @@ -710,7 +710,7 @@ def url_req_space(req): Ref python/importlib_metadata#357. """ # '@' is uniquely indicative of a url_req. 
- return ' ' * ('@' in req) + return " " * ("@" in req) for section in sections: space = url_req_space(section.value) @@ -752,7 +752,7 @@ def path(self): Typically refers to Python installed package paths such as "site-packages" directories and defaults to ``sys.path``. """ - return vars(self).get('path', sys.path) + return vars(self).get("path", sys.path) @abc.abstractmethod def find_distributions(self, context=Context()): @@ -786,7 +786,7 @@ def joinpath(self, child): def children(self): with suppress(Exception): - return os.listdir(self.root or '.') + return os.listdir(self.root or ".") with suppress(Exception): return self.zip_children() return [] @@ -868,7 +868,7 @@ def normalize(name): """ PEP 503 normalization plus dashes as underscores. """ - return re.sub(r"[-_.]+", "-", name).lower().replace('-', '_') + return re.sub(r"[-_.]+", "-", name).lower().replace("-", "_") @staticmethod def legacy_normalize(name): @@ -876,7 +876,7 @@ def legacy_normalize(name): Normalize the package name as found in the convention in older packaging tools versions and specs. """ - return name.lower().replace('-', '_') + return name.lower().replace("-", "_") def __bool__(self): return bool(self.name) @@ -930,7 +930,7 @@ def read_text(self, filename): NotADirectoryError, PermissionError, ): - return self._path.joinpath(filename).read_text(encoding='utf-8') + return self._path.joinpath(filename).read_text(encoding="utf-8") read_text.__doc__ = Distribution.read_text.__doc__ @@ -948,9 +948,9 @@ def _normalized_name(self): def _name_from_stem(self, stem): name, ext = os.path.splitext(stem) - if ext not in ('.dist-info', '.egg-info'): + if ext not in (".dist-info", ".egg-info"): return - name, sep, rest = stem.partition('-') + name, sep, rest = stem.partition("-") return name @@ -1007,7 +1007,7 @@ def entry_points(**params) -> Union[EntryPoints, SelectableGroups]: :return: EntryPoints or SelectableGroups for all installed packages. 
""" - norm_name = operator.attrgetter('_normalized_name') + norm_name = operator.attrgetter("_normalized_name") unique = functools.partial(unique_everseen, key=norm_name) eps = itertools.chain.from_iterable( dist.entry_points for dist in unique(distributions()) @@ -1047,17 +1047,17 @@ def packages_distributions() -> Mapping[str, List[str]]: pkg_to_dist = collections.defaultdict(list) for dist in distributions(): for pkg in _top_level_declared(dist) or _top_level_inferred(dist): - pkg_to_dist[pkg].append(dist.metadata['Name']) + pkg_to_dist[pkg].append(dist.metadata["Name"]) return dict(pkg_to_dist) def _top_level_declared(dist): - return (dist.read_text('top_level.txt') or '').split() + return (dist.read_text("top_level.txt") or "").split() def _top_level_inferred(dist): return { - f.parts[0] if len(f.parts) > 1 else f.with_suffix('').name + f.parts[0] if len(f.parts) > 1 else f.with_suffix("").name for f in always_iterable(dist.files) if f.suffix == ".py" } diff --git a/metaflow/_vendor/v3_6/importlib_metadata/_adapters.py b/metaflow/_vendor/v3_6/importlib_metadata/_adapters.py index aa460d3eda5..49cfa02e666 100644 --- a/metaflow/_vendor/v3_6/importlib_metadata/_adapters.py +++ b/metaflow/_vendor/v3_6/importlib_metadata/_adapters.py @@ -10,16 +10,16 @@ class Message(email.message.Message): map( FoldedCase, [ - 'Classifier', - 'Obsoletes-Dist', - 'Platform', - 'Project-URL', - 'Provides-Dist', - 'Provides-Extra', - 'Requires-Dist', - 'Requires-External', - 'Supported-Platform', - 'Dynamic', + "Classifier", + "Obsoletes-Dist", + "Platform", + "Project-URL", + "Provides-Dist", + "Provides-Extra", + "Requires-Dist", + "Requires-External", + "Supported-Platform", + "Dynamic", ], ) ) @@ -42,13 +42,13 @@ def __iter__(self): def _repair_headers(self): def redent(value): "Correct for RFC822 indentation" - if not value or '\n' not in value: + if not value or "\n" not in value: return value - return textwrap.dedent(' ' * 8 + value) + return textwrap.dedent(" " * 8 + value) - headers = [(key, redent(value)) for key, value in vars(self)['_headers']] + headers = [(key, redent(value)) for key, value in vars(self)["_headers"]] if self._payload: - headers.append(('Description', self.get_payload())) + headers.append(("Description", self.get_payload())) return headers @property @@ -60,9 +60,9 @@ def json(self): def transform(key): value = self.get_all(key) if key in self.multiple_use_keys else self[key] - if key == 'Keywords': - value = re.split(r'\s+', value) - tk = key.lower().replace('-', '_') + if key == "Keywords": + value = re.split(r"\s+", value) + tk = key.lower().replace("-", "_") return tk, value return dict(map(transform, map(FoldedCase, self))) diff --git a/metaflow/_vendor/v3_6/importlib_metadata/_collections.py b/metaflow/_vendor/v3_6/importlib_metadata/_collections.py index cf0954e1a30..895678a23c3 100644 --- a/metaflow/_vendor/v3_6/importlib_metadata/_collections.py +++ b/metaflow/_vendor/v3_6/importlib_metadata/_collections.py @@ -18,13 +18,13 @@ class FreezableDefaultDict(collections.defaultdict): """ def __missing__(self, key): - return getattr(self, '_frozen', super().__missing__)(key) + return getattr(self, "_frozen", super().__missing__)(key) def freeze(self): self._frozen = lambda key: self.default_factory() -class Pair(collections.namedtuple('Pair', 'name value')): +class Pair(collections.namedtuple("Pair", "name value")): @classmethod def parse(cls, text): return cls(*map(str.strip, text.split("=", 1))) diff --git a/metaflow/_vendor/v3_6/importlib_metadata/_compat.py 
b/metaflow/_vendor/v3_6/importlib_metadata/_compat.py index 3680940f0b0..eaaab2ffa62 100644 --- a/metaflow/_vendor/v3_6/importlib_metadata/_compat.py +++ b/metaflow/_vendor/v3_6/importlib_metadata/_compat.py @@ -2,7 +2,7 @@ import platform -__all__ = ['install', 'NullFinder', 'Protocol'] +__all__ = ["install", "NullFinder", "Protocol"] try: @@ -35,8 +35,8 @@ def disable_stdlib_finder(): def matches(finder): return getattr( - finder, '__module__', None - ) == '_frozen_importlib_external' and hasattr(finder, 'find_distributions') + finder, "__module__", None + ) == "_frozen_importlib_external" and hasattr(finder, "find_distributions") for finder in filter(matches, sys.meta_path): # pragma: nocover del finder.find_distributions @@ -67,5 +67,5 @@ def pypy_partial(val): Workaround for #327. """ - is_pypy = platform.python_implementation() == 'PyPy' + is_pypy = platform.python_implementation() == "PyPy" return val + is_pypy diff --git a/metaflow/_vendor/v3_6/importlib_metadata/_meta.py b/metaflow/_vendor/v3_6/importlib_metadata/_meta.py index 37ee43e6ef4..bc502995bd7 100644 --- a/metaflow/_vendor/v3_6/importlib_metadata/_meta.py +++ b/metaflow/_vendor/v3_6/importlib_metadata/_meta.py @@ -6,17 +6,13 @@ class PackageMetadata(Protocol): - def __len__(self) -> int: - ... # pragma: no cover + def __len__(self) -> int: ... # pragma: no cover - def __contains__(self, item: str) -> bool: - ... # pragma: no cover + def __contains__(self, item: str) -> bool: ... # pragma: no cover - def __getitem__(self, key: str) -> str: - ... # pragma: no cover + def __getitem__(self, key: str) -> str: ... # pragma: no cover - def __iter__(self) -> Iterator[str]: - ... # pragma: no cover + def __iter__(self) -> Iterator[str]: ... # pragma: no cover def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]: """ @@ -35,14 +31,10 @@ class SimplePath(Protocol): A minimal subset of pathlib.Path required by PathDistribution. """ - def joinpath(self) -> 'SimplePath': - ... # pragma: no cover + def joinpath(self) -> "SimplePath": ... # pragma: no cover - def __truediv__(self) -> 'SimplePath': - ... # pragma: no cover + def __truediv__(self) -> "SimplePath": ... # pragma: no cover - def parent(self) -> 'SimplePath': - ... # pragma: no cover + def parent(self) -> "SimplePath": ... # pragma: no cover - def read_text(self) -> str: - ... # pragma: no cover + def read_text(self) -> str: ... # pragma: no cover diff --git a/metaflow/_vendor/v3_6/importlib_metadata/_text.py b/metaflow/_vendor/v3_6/importlib_metadata/_text.py index c88cfbb2349..376210d7096 100644 --- a/metaflow/_vendor/v3_6/importlib_metadata/_text.py +++ b/metaflow/_vendor/v3_6/importlib_metadata/_text.py @@ -94,6 +94,6 @@ def lower(self): def index(self, sub): return self.lower().index(sub.lower()) - def split(self, splitter=' ', maxsplit=0): + def split(self, splitter=" ", maxsplit=0): pattern = re.compile(re.escape(splitter), re.I) return pattern.split(self, maxsplit) diff --git a/metaflow/_vendor/v3_6/typing_extensions.py b/metaflow/_vendor/v3_6/typing_extensions.py index 43c05bdcd22..071d0436c41 100644 --- a/metaflow/_vendor/v3_6/typing_extensions.py +++ b/metaflow/_vendor/v3_6/typing_extensions.py @@ -21,58 +21,54 @@ # Please keep __all__ alphabetized within each category. __all__ = [ # Super-special typing primitives. 
- 'ClassVar', - 'Concatenate', - 'Final', - 'LiteralString', - 'ParamSpec', - 'Self', - 'Type', - 'TypeVarTuple', - 'Unpack', - + "ClassVar", + "Concatenate", + "Final", + "LiteralString", + "ParamSpec", + "Self", + "Type", + "TypeVarTuple", + "Unpack", # ABCs (from collections.abc). - 'Awaitable', - 'AsyncIterator', - 'AsyncIterable', - 'Coroutine', - 'AsyncGenerator', - 'AsyncContextManager', - 'ChainMap', - + "Awaitable", + "AsyncIterator", + "AsyncIterable", + "Coroutine", + "AsyncGenerator", + "AsyncContextManager", + "ChainMap", # Concrete collection types. - 'ContextManager', - 'Counter', - 'Deque', - 'DefaultDict', - 'OrderedDict', - 'TypedDict', - + "ContextManager", + "Counter", + "Deque", + "DefaultDict", + "OrderedDict", + "TypedDict", # Structural checks, a.k.a. protocols. - 'SupportsIndex', - + "SupportsIndex", # One-off things. - 'Annotated', - 'assert_never', - 'dataclass_transform', - 'final', - 'IntVar', - 'is_typeddict', - 'Literal', - 'NewType', - 'overload', - 'Protocol', - 'reveal_type', - 'runtime', - 'runtime_checkable', - 'Text', - 'TypeAlias', - 'TypeGuard', - 'TYPE_CHECKING', - 'Never', - 'NoReturn', - 'Required', - 'NotRequired', + "Annotated", + "assert_never", + "dataclass_transform", + "final", + "IntVar", + "is_typeddict", + "Literal", + "NewType", + "overload", + "Protocol", + "reveal_type", + "runtime", + "runtime_checkable", + "Text", + "TypeAlias", + "TypeGuard", + "TYPE_CHECKING", + "Never", + "NoReturn", + "Required", + "NotRequired", ] if PEP_560: @@ -84,8 +80,8 @@ def _no_slots_copy(dct): dict_copy = dict(dct) - if '__slots__' in dict_copy: - for slot in dict_copy['__slots__']: + if "__slots__" in dict_copy: + for slot in dict_copy["__slots__"]: dict_copy.pop(slot, None) return dict_copy @@ -110,19 +106,26 @@ def _check_generic(cls, parameters, elen=_marker): num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): return - raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};" - f" actual {alen}, expected {elen}") + raise TypeError( + f"Too {'many' if alen > elen else 'few'} parameters for {cls};" + f" actual {alen}, expected {elen}" + ) if sys.version_info >= (3, 10): + def _should_collect_from_parameters(t): return isinstance( t, (typing._GenericAlias, _types.GenericAlias, _types.UnionType) ) + elif sys.version_info >= (3, 9): + def _should_collect_from_parameters(t): return isinstance(t, (typing._GenericAlias, _types.GenericAlias)) + else: + def _should_collect_from_parameters(t): return isinstance(t, typing._GenericAlias) and not t._special @@ -137,11 +140,7 @@ def _collect_type_vars(types, typevar_types=None): typevar_types = typing.TypeVar tvars = [] for t in types: - if ( - isinstance(t, typevar_types) and - t not in tvars and - not _is_unpack(t) - ): + if isinstance(t, typevar_types) and t not in tvars and not _is_unpack(t): tvars.append(t) if _should_collect_from_parameters(t): tvars.extend([t for t in t.__parameters__ if t not in tvars]) @@ -149,10 +148,11 @@ def _collect_type_vars(types, typevar_types=None): # 3.6.2+ -if hasattr(typing, 'NoReturn'): +if hasattr(typing, "NoReturn"): NoReturn = typing.NoReturn # 3.6.0-3.6.1 else: + class _NoReturn(typing._FinalTypingBase, _root=True): """Special type indicating functions that never return. Example:: @@ -165,6 +165,7 @@ def stop() -> NoReturn: This type is invalid in other positions, e.g., ``List[NoReturn]`` will fail in static type checkers. 
""" + __slots__ = () def __instancecheck__(self, obj): @@ -177,32 +178,35 @@ def __subclasscheck__(self, cls): # Some unconstrained type variables. These are used by the container types. # (These are not for export.) -T = typing.TypeVar('T') # Any type. -KT = typing.TypeVar('KT') # Key type. -VT = typing.TypeVar('VT') # Value type. -T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers. -T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant. +T = typing.TypeVar("T") # Any type. +KT = typing.TypeVar("KT") # Key type. +VT = typing.TypeVar("VT") # Value type. +T_co = typing.TypeVar("T_co", covariant=True) # Any type covariant containers. +T_contra = typing.TypeVar("T_contra", contravariant=True) # Ditto contravariant. ClassVar = typing.ClassVar # On older versions of typing there is an internal class named "Final". # 3.8+ -if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7): +if hasattr(typing, "Final") and sys.version_info[:2] >= (3, 7): Final = typing.Final # 3.7 elif sys.version_info[:2] >= (3, 7): + class _FinalForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only single type') + item = typing._type_check( + parameters, f"{self._name} accepts only single type" + ) return typing._GenericAlias(self, (item,)) - Final = _FinalForm('Final', - doc="""A special typing construct to indicate that a name + Final = _FinalForm( + "Final", + doc="""A special typing construct to indicate that a name cannot be re-assigned or overridden in a subclass. For example: @@ -214,9 +218,11 @@ class Connection: class FastConnector(Connection): TIMEOUT = 1 # Error reported by type checker - There is no runtime checking of these properties.""") + There is no runtime checking of these properties.""", + ) # 3.6 else: + class _Final(typing._FinalTypingBase, _root=True): """A special typing construct to indicate that a name cannot be re-assigned or overridden in a subclass. @@ -233,7 +239,7 @@ class FastConnector(Connection): There is no runtime checking of these properties. """ - __slots__ = ('__type__',) + __slots__ = ("__type__",) def __init__(self, tp=None, **kwds): self.__type__ = tp @@ -241,10 +247,13 @@ def __init__(self, tp=None, **kwds): def __getitem__(self, item): cls = type(self) if self.__type__ is None: - return cls(typing._type_check(item, - f'{cls.__name__[1:]} accepts only single type.'), - _root=True) - raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') + return cls( + typing._type_check( + item, f"{cls.__name__[1:]} accepts only single type." + ), + _root=True, + ) + raise TypeError(f"{cls.__name__[1:]} cannot be further subscripted") def _eval_type(self, globalns, localns): new_tp = typing._eval_type(self.__type__, globalns, localns) @@ -255,7 +264,7 @@ def _eval_type(self, globalns, localns): def __repr__(self): r = super().__repr__() if self.__type__ is not None: - r += f'[{typing._type_repr(self.__type__)}]' + r += f"[{typing._type_repr(self.__type__)}]" return r def __hash__(self): @@ -314,20 +323,22 @@ def IntVar(name): # 3.8+: -if hasattr(typing, 'Literal'): +if hasattr(typing, "Literal"): Literal = typing.Literal # 3.7: elif sys.version_info[:2] >= (3, 7): + class _LiteralForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." 
+ self._name def __getitem__(self, parameters): return typing._GenericAlias(self, parameters) - Literal = _LiteralForm('Literal', - doc="""A type that can be used to indicate to type checkers + Literal = _LiteralForm( + "Literal", + doc="""A type that can be used to indicate to type checkers that the corresponding value has a value literally equivalent to the provided parameter. For example: @@ -338,9 +349,11 @@ def __getitem__(self, parameters): Literal[...] cannot be subclassed. There is no runtime checking verifying that the parameter is actually a value - instead of a type.""") + instead of a type.""", + ) # 3.6: else: + class _Literal(typing._FinalTypingBase, _root=True): """A type that can be used to indicate to type checkers that the corresponding value has a value literally equivalent to the @@ -355,7 +368,7 @@ class _Literal(typing._FinalTypingBase, _root=True): verifying that the parameter is actually a value instead of a type. """ - __slots__ = ('__values__',) + __slots__ = ("__values__",) def __init__(self, values=None, **kwds): self.__values__ = values @@ -366,7 +379,7 @@ def __getitem__(self, values): if not isinstance(values, tuple): values = (values,) return cls(values, _root=True) - raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') + raise TypeError(f"{cls.__name__[1:]} cannot be further subscripted") def _eval_type(self, globalns, localns): return self @@ -409,9 +422,11 @@ def __subclasscheck__(self, subclass): versions of Python, see https://github.com/python/typing/issues/501. """ if self.__origin__ is not None: - if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: - raise TypeError("Parameterized generics cannot be used with class " - "or instance checks") + if sys._getframe(1).f_globals["__name__"] not in ["abc", "functools"]: + raise TypeError( + "Parameterized generics cannot be used with class " + "or instance checks" + ) return False if not self.__extra__: return super().__subclasscheck__(subclass) @@ -434,13 +449,17 @@ def __subclasscheck__(self, subclass): AsyncIterator = typing.AsyncIterator # 3.6.1+ -if hasattr(typing, 'Deque'): +if hasattr(typing, "Deque"): Deque = typing.Deque # 3.6.0 else: - class Deque(collections.deque, typing.MutableSequence[T], - metaclass=_ExtensionsGenericMeta, - extra=collections.deque): + + class Deque( + collections.deque, + typing.MutableSequence[T], + metaclass=_ExtensionsGenericMeta, + extra=collections.deque, + ): __slots__ = () def __new__(cls, *args, **kwds): @@ -448,9 +467,10 @@ def __new__(cls, *args, **kwds): return collections.deque(*args, **kwds) return typing._generic_new(collections.deque, cls, *args, **kwds) + ContextManager = typing.ContextManager # 3.6.2+ -if hasattr(typing, 'AsyncContextManager'): +if hasattr(typing, "AsyncContextManager"): AsyncContextManager = typing.AsyncContextManager # 3.6.0-3.6.1 else: @@ -472,19 +492,24 @@ def __subclasshook__(cls, C): return _check_methods_in_mro(C, "__aenter__", "__aexit__") return NotImplemented + DefaultDict = typing.DefaultDict # 3.7.2+ -if hasattr(typing, 'OrderedDict'): +if hasattr(typing, "OrderedDict"): OrderedDict = typing.OrderedDict # 3.7.0-3.7.2 elif (3, 7, 0) <= sys.version_info[:3] < (3, 7, 2): OrderedDict = typing._alias(collections.OrderedDict, (KT, VT)) # 3.6 else: - class OrderedDict(collections.OrderedDict, typing.MutableMapping[KT, VT], - metaclass=_ExtensionsGenericMeta, - extra=collections.OrderedDict): + + class OrderedDict( + collections.OrderedDict, + typing.MutableMapping[KT, VT], + 
metaclass=_ExtensionsGenericMeta, + extra=collections.OrderedDict, + ): __slots__ = () @@ -493,14 +518,19 @@ def __new__(cls, *args, **kwds): return collections.OrderedDict(*args, **kwds) return typing._generic_new(collections.OrderedDict, cls, *args, **kwds) + # 3.6.2+ -if hasattr(typing, 'Counter'): +if hasattr(typing, "Counter"): Counter = typing.Counter # 3.6.0-3.6.1 else: - class Counter(collections.Counter, - typing.Dict[T, int], - metaclass=_ExtensionsGenericMeta, extra=collections.Counter): + + class Counter( + collections.Counter, + typing.Dict[T, int], + metaclass=_ExtensionsGenericMeta, + extra=collections.Counter, + ): __slots__ = () @@ -509,13 +539,18 @@ def __new__(cls, *args, **kwds): return collections.Counter(*args, **kwds) return typing._generic_new(collections.Counter, cls, *args, **kwds) + # 3.6.1+ -if hasattr(typing, 'ChainMap'): +if hasattr(typing, "ChainMap"): ChainMap = typing.ChainMap -elif hasattr(collections, 'ChainMap'): - class ChainMap(collections.ChainMap, typing.MutableMapping[KT, VT], - metaclass=_ExtensionsGenericMeta, - extra=collections.ChainMap): +elif hasattr(collections, "ChainMap"): + + class ChainMap( + collections.ChainMap, + typing.MutableMapping[KT, VT], + metaclass=_ExtensionsGenericMeta, + extra=collections.ChainMap, + ): __slots__ = () @@ -524,16 +559,22 @@ def __new__(cls, *args, **kwds): return collections.ChainMap(*args, **kwds) return typing._generic_new(collections.ChainMap, cls, *args, **kwds) + # 3.6.1+ -if hasattr(typing, 'AsyncGenerator'): +if hasattr(typing, "AsyncGenerator"): AsyncGenerator = typing.AsyncGenerator # 3.6.0 else: - class AsyncGenerator(AsyncIterator[T_co], typing.Generic[T_co, T_contra], - metaclass=_ExtensionsGenericMeta, - extra=collections.abc.AsyncGenerator): + + class AsyncGenerator( + AsyncIterator[T_co], + typing.Generic[T_co, T_contra], + metaclass=_ExtensionsGenericMeta, + extra=collections.abc.AsyncGenerator, + ): __slots__ = () + NewType = typing.NewType Text = typing.Text TYPE_CHECKING = typing.TYPE_CHECKING @@ -542,34 +583,60 @@ class AsyncGenerator(AsyncIterator[T_co], typing.Generic[T_co, T_contra], def _gorg(cls): """This function exists for compatibility with old typing versions.""" assert isinstance(cls, GenericMeta) - if hasattr(cls, '_gorg'): + if hasattr(cls, "_gorg"): return cls._gorg while cls.__origin__ is not None: cls = cls.__origin__ return cls -_PROTO_WHITELIST = ['Callable', 'Awaitable', - 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', - 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', - 'ContextManager', 'AsyncContextManager'] +_PROTO_WHITELIST = [ + "Callable", + "Awaitable", + "Iterable", + "Iterator", + "AsyncIterable", + "AsyncIterator", + "Hashable", + "Sized", + "Container", + "Collection", + "Reversible", + "ContextManager", + "AsyncContextManager", +] def _get_protocol_attrs(cls): attrs = set() for base in cls.__mro__[:-1]: # without object - if base.__name__ in ('Protocol', 'Generic'): + if base.__name__ in ("Protocol", "Generic"): continue - annotations = getattr(base, '__annotations__', {}) + annotations = getattr(base, "__annotations__", {}) for attr in list(base.__dict__.keys()) + list(annotations.keys()): - if (not attr.startswith('_abc_') and attr not in ( - '__abstractmethods__', '__annotations__', '__weakref__', - '_is_protocol', '_is_runtime_protocol', '__dict__', - '__args__', '__slots__', - '__next_in_mro__', '__parameters__', '__origin__', - '__orig_bases__', '__extra__', '__tree_hash__', - '__doc__', '__subclasshook__', '__init__', '__new__', 
- '__module__', '_MutableMapping__marker', '_gorg')): + if not attr.startswith("_abc_") and attr not in ( + "__abstractmethods__", + "__annotations__", + "__weakref__", + "_is_protocol", + "_is_runtime_protocol", + "__dict__", + "__args__", + "__slots__", + "__next_in_mro__", + "__parameters__", + "__origin__", + "__orig_bases__", + "__extra__", + "__tree_hash__", + "__doc__", + "__subclasshook__", + "__init__", + "__new__", + "__module__", + "_MutableMapping__marker", + "_gorg", + ): attrs.add(attr) return attrs @@ -579,14 +646,14 @@ def _is_callable_members_only(cls): # 3.8+ -if hasattr(typing, 'Protocol'): +if hasattr(typing, "Protocol"): Protocol = typing.Protocol # 3.7 elif PEP_560: def _no_init(self, *args, **kwargs): if type(self)._is_protocol: - raise TypeError('Protocols cannot be instantiated') + raise TypeError("Protocols cannot be instantiated") class _ProtocolMeta(abc.ABCMeta): # This metaclass is a bit unfortunate and exists only because of the lack @@ -594,15 +661,20 @@ class _ProtocolMeta(abc.ABCMeta): def __instancecheck__(cls, instance): # We need this method for situations where attributes are # assigned in __init__. - if ((not getattr(cls, '_is_protocol', False) or - _is_callable_members_only(cls)) and - issubclass(instance.__class__, cls)): + if ( + not getattr(cls, "_is_protocol", False) + or _is_callable_members_only(cls) + ) and issubclass(instance.__class__, cls): return True if cls._is_protocol: - if all(hasattr(instance, attr) and - (not callable(getattr(cls, attr, None)) or - getattr(instance, attr) is not None) - for attr in _get_protocol_attrs(cls)): + if all( + hasattr(instance, attr) + and ( + not callable(getattr(cls, attr, None)) + or getattr(instance, attr) is not None + ) + for attr in _get_protocol_attrs(cls) + ): return True return super().__instancecheck__(instance) @@ -643,8 +715,10 @@ def meth(self) -> T: def __new__(cls, *args, **kwds): if cls is Protocol: - raise TypeError("Type Protocol cannot be instantiated; " - "it can only be used as a base class") + raise TypeError( + "Type Protocol cannot be instantiated; " + "it can only be used as a base class" + ) return super().__new__(cls) @typing._tp_cache @@ -653,7 +727,8 @@ def __class_getitem__(cls, params): params = (params,) if not params and cls is not typing.Tuple: raise TypeError( - f"Parameter list to {cls.__qualname__}[...] cannot be empty") + f"Parameter list to {cls.__qualname__}[...] cannot be empty" + ) msg = "Parameters to generic types must be types." params = tuple(typing._type_check(p, msg) for p in params) # noqa if cls is Protocol: @@ -664,10 +739,10 @@ def __class_getitem__(cls, params): i += 1 raise TypeError( "Parameters to Protocol[...] must all be type variables." - f" Parameter {i + 1} is {params[i]}") + f" Parameter {i + 1} is {params[i]}" + ) if len(set(params)) != len(params): - raise TypeError( - "Parameters to Protocol[...] must all be unique") + raise TypeError("Parameters to Protocol[...] must all be unique") else: # Subscripting a regular Generic subclass. 
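# Illustrative sketch (not part of the vendored patch): the
# __class_getitem__ validation above is what accepts a generic protocol
# definition like this one, collecting T into __parameters__.
from typing import TypeVar
from typing_extensions import Protocol

T = TypeVar("T")

class Loader(Protocol[T]):
    def load(self) -> T: ...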
_check_generic(cls, params, len(cls.__parameters__)) @@ -675,13 +750,13 @@ def __class_getitem__(cls, params): def __init_subclass__(cls, *args, **kwargs): tvars = [] - if '__orig_bases__' in cls.__dict__: + if "__orig_bases__" in cls.__dict__: error = typing.Generic in cls.__orig_bases__ else: error = typing.Generic in cls.__bases__ if error: raise TypeError("Cannot inherit from plain Generic") - if '__orig_bases__' in cls.__dict__: + if "__orig_bases__" in cls.__dict__: tvars = typing._collect_type_vars(cls.__orig_bases__) # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. # If found, tvars must be a subset of it. @@ -690,14 +765,17 @@ def __init_subclass__(cls, *args, **kwargs): # and reject multiple Generic[...] and/or Protocol[...]. gvars = None for base in cls.__orig_bases__: - if (isinstance(base, typing._GenericAlias) and - base.__origin__ in (typing.Generic, Protocol)): + if isinstance(base, typing._GenericAlias) and base.__origin__ in ( + typing.Generic, + Protocol, + ): # for error messages the_base = base.__origin__.__name__ if gvars is not None: raise TypeError( "Cannot inherit from Generic[...]" - " and/or Protocol[...] multiple types.") + " and/or Protocol[...] multiple types." + ) gvars = base.__parameters__ if gvars is None: gvars = tvars @@ -705,50 +783,59 @@ def __init_subclass__(cls, *args, **kwargs): tvarset = set(tvars) gvarset = set(gvars) if not tvarset <= gvarset: - s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) - s_args = ', '.join(str(g) for g in gvars) - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in {the_base}[{s_args}]") + s_vars = ", ".join(str(t) for t in tvars if t not in gvarset) + s_args = ", ".join(str(g) for g in gvars) + raise TypeError( + f"Some type variables ({s_vars}) are" + f" not listed in {the_base}[{s_args}]" + ) tvars = gvars cls.__parameters__ = tuple(tvars) # Determine if this is a protocol or a concrete subclass. - if not cls.__dict__.get('_is_protocol', None): + if not cls.__dict__.get("_is_protocol", None): cls._is_protocol = any(b is Protocol for b in cls.__bases__) # Set (or override) the protocol subclass hook. 
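# Illustrative sketch (not part of the vendored patch) of the subclass hook
# installed below: issubclass() against a protocol succeeds structurally,
# but only for @runtime_checkable protocols whose members are all methods.
from typing_extensions import Protocol, runtime_checkable

@runtime_checkable
class Closable(Protocol):
    def close(self) -> None: ...

class Stream:
    def close(self) -> None:
        pass

assert issubclass(Stream, Closable)  # no inheritance required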
def _proto_hook(other): - if not cls.__dict__.get('_is_protocol', None): + if not cls.__dict__.get("_is_protocol", None): return NotImplemented - if not getattr(cls, '_is_runtime_protocol', False): - if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: + if not getattr(cls, "_is_runtime_protocol", False): + if sys._getframe(2).f_globals["__name__"] in ["abc", "functools"]: return NotImplemented - raise TypeError("Instance and class checks can only be used with" - " @runtime protocols") + raise TypeError( + "Instance and class checks can only be used with" + " @runtime protocols" + ) if not _is_callable_members_only(cls): - if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: + if sys._getframe(2).f_globals["__name__"] in ["abc", "functools"]: return NotImplemented - raise TypeError("Protocols with non-method members" - " don't support issubclass()") + raise TypeError( + "Protocols with non-method members" + " don't support issubclass()" + ) if not isinstance(other, type): # Same error as for issubclass(1, int) - raise TypeError('issubclass() arg 1 must be a class') + raise TypeError("issubclass() arg 1 must be a class") for attr in _get_protocol_attrs(cls): for base in other.__mro__: if attr in base.__dict__: if base.__dict__[attr] is None: return NotImplemented break - annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, typing.Mapping) and - attr in annotations and - isinstance(other, _ProtocolMeta) and - other._is_protocol): + annotations = getattr(base, "__annotations__", {}) + if ( + isinstance(annotations, typing.Mapping) + and attr in annotations + and isinstance(other, _ProtocolMeta) + and other._is_protocol + ): break else: return NotImplemented return True - if '__subclasshook__' not in cls.__dict__: + + if "__subclasshook__" not in cls.__dict__: cls.__subclasshook__ = _proto_hook # We have nothing more to do for non-protocols. @@ -757,20 +844,27 @@ def _proto_hook(other): # Check consistency of bases. for base in cls.__bases__: - if not (base in (object, typing.Generic) or - base.__module__ == 'collections.abc' and - base.__name__ in _PROTO_WHITELIST or - isinstance(base, _ProtocolMeta) and base._is_protocol): - raise TypeError('Protocols can only inherit from other' - f' protocols, got {repr(base)}') + if not ( + base in (object, typing.Generic) + or base.__module__ == "collections.abc" + and base.__name__ in _PROTO_WHITELIST + or isinstance(base, _ProtocolMeta) + and base._is_protocol + ): + raise TypeError( + "Protocols can only inherit from other" + f" protocols, got {repr(base)}" + ) cls.__init__ = _no_init + + # 3.6 else: from typing import _next_in_mro, _type_check # noqa def _no_init(self, *args, **kwargs): if type(self)._is_protocol: - raise TypeError('Protocols cannot be instantiated') + raise TypeError("Protocols cannot be instantiated") class _ProtocolMeta(GenericMeta): """Internal metaclass for Protocol. @@ -778,8 +872,18 @@ class _ProtocolMeta(GenericMeta): This exists so Protocol classes can be generic without deriving from Generic. """ - def __new__(cls, name, bases, namespace, - tvars=None, args=None, origin=None, extra=None, orig_bases=None): + + def __new__( + cls, + name, + bases, + namespace, + tvars=None, + args=None, + origin=None, + extra=None, + orig_bases=None, + ): # This is just a version copied from GenericMeta.__new__ that # includes "Protocol" special treatment. (Comments removed for brevity.) 
assert extra is None # Protocols should not have extra @@ -792,12 +896,15 @@ def __new__(cls, name, bases, namespace, for base in bases: if base is typing.Generic: raise TypeError("Cannot inherit from plain Generic") - if (isinstance(base, GenericMeta) and - base.__origin__ in (typing.Generic, Protocol)): + if isinstance(base, GenericMeta) and base.__origin__ in ( + typing.Generic, + Protocol, + ): if gvars is not None: raise TypeError( "Cannot inherit from Generic[...] or" - " Protocol[...] multiple times.") + " Protocol[...] multiple times." + ) gvars = base.__parameters__ if gvars is None: gvars = tvars @@ -807,122 +914,166 @@ def __new__(cls, name, bases, namespace, if not tvarset <= gvarset: s_vars = ", ".join(str(t) for t in tvars if t not in gvarset) s_args = ", ".join(str(g) for g in gvars) - cls_name = "Generic" if any(b.__origin__ is typing.Generic - for b in bases) else "Protocol" - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in {cls_name}[{s_args}]") + cls_name = ( + "Generic" + if any(b.__origin__ is typing.Generic for b in bases) + else "Protocol" + ) + raise TypeError( + f"Some type variables ({s_vars}) are" + f" not listed in {cls_name}[{s_args}]" + ) tvars = gvars initial_bases = bases - if (extra is not None and type(extra) is abc.ABCMeta and - extra not in bases): + if extra is not None and type(extra) is abc.ABCMeta and extra not in bases: bases = (extra,) + bases - bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b - for b in bases) - if any(isinstance(b, GenericMeta) and b is not typing.Generic for b in bases): + bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases) + if any( + isinstance(b, GenericMeta) and b is not typing.Generic for b in bases + ): bases = tuple(b for b in bases if b is not typing.Generic) - namespace.update({'__origin__': origin, '__extra__': extra}) - self = super(GenericMeta, cls).__new__(cls, name, bases, namespace, - _root=True) - super(GenericMeta, self).__setattr__('_gorg', - self if not origin else - _gorg(origin)) + namespace.update({"__origin__": origin, "__extra__": extra}) + self = super(GenericMeta, cls).__new__( + cls, name, bases, namespace, _root=True + ) + super(GenericMeta, self).__setattr__( + "_gorg", self if not origin else _gorg(origin) + ) self.__parameters__ = tvars - self.__args__ = tuple(... if a is typing._TypingEllipsis else - () if a is typing._TypingEmpty else - a for a in args) if args else None + self.__args__ = ( + tuple( + ( + ... 
+ if a is typing._TypingEllipsis + else () if a is typing._TypingEmpty else a + ) + for a in args + ) + if args + else None + ) self.__next_in_mro__ = _next_in_mro(self) if orig_bases is None: self.__orig_bases__ = initial_bases elif origin is not None: self._abc_registry = origin._abc_registry self._abc_cache = origin._abc_cache - if hasattr(self, '_subs_tree'): - self.__tree_hash__ = (hash(self._subs_tree()) if origin else - super(GenericMeta, self).__hash__()) + if hasattr(self, "_subs_tree"): + self.__tree_hash__ = ( + hash(self._subs_tree()) + if origin + else super(GenericMeta, self).__hash__() + ) return self def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) - if not cls.__dict__.get('_is_protocol', None): - cls._is_protocol = any(b is Protocol or - isinstance(b, _ProtocolMeta) and - b.__origin__ is Protocol - for b in cls.__bases__) + if not cls.__dict__.get("_is_protocol", None): + cls._is_protocol = any( + b is Protocol + or isinstance(b, _ProtocolMeta) + and b.__origin__ is Protocol + for b in cls.__bases__ + ) if cls._is_protocol: for base in cls.__mro__[1:]: - if not (base in (object, typing.Generic) or - base.__module__ == 'collections.abc' and - base.__name__ in _PROTO_WHITELIST or - isinstance(base, typing.TypingMeta) and base._is_protocol or - isinstance(base, GenericMeta) and - base.__origin__ is typing.Generic): - raise TypeError(f'Protocols can only inherit from other' - f' protocols, got {repr(base)}') + if not ( + base in (object, typing.Generic) + or base.__module__ == "collections.abc" + and base.__name__ in _PROTO_WHITELIST + or isinstance(base, typing.TypingMeta) + and base._is_protocol + or isinstance(base, GenericMeta) + and base.__origin__ is typing.Generic + ): + raise TypeError( + f"Protocols can only inherit from other" + f" protocols, got {repr(base)}" + ) cls.__init__ = _no_init def _proto_hook(other): - if not cls.__dict__.get('_is_protocol', None): + if not cls.__dict__.get("_is_protocol", None): return NotImplemented if not isinstance(other, type): # Same error as for issubclass(1, int) - raise TypeError('issubclass() arg 1 must be a class') + raise TypeError("issubclass() arg 1 must be a class") for attr in _get_protocol_attrs(cls): for base in other.__mro__: if attr in base.__dict__: if base.__dict__[attr] is None: return NotImplemented break - annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, typing.Mapping) and - attr in annotations and - isinstance(other, _ProtocolMeta) and - other._is_protocol): + annotations = getattr(base, "__annotations__", {}) + if ( + isinstance(annotations, typing.Mapping) + and attr in annotations + and isinstance(other, _ProtocolMeta) + and other._is_protocol + ): break else: return NotImplemented return True - if '__subclasshook__' not in cls.__dict__: + + if "__subclasshook__" not in cls.__dict__: cls.__subclasshook__ = _proto_hook def __instancecheck__(self, instance): # We need this method for situations where attributes are # assigned in __init__. 
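# Illustrative sketch (not part of the vendored patch) of the per-attribute
# fallback implemented below: isinstance() still succeeds when a data member
# only appears on instances after __init__, not on the class itself.
from typing_extensions import Protocol, runtime_checkable

@runtime_checkable
class HasName(Protocol):
    name: str

class User:
    def __init__(self):
        self.name = "ada"  # visible only on instances

assert isinstance(User(), HasName)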
- if ((not getattr(self, '_is_protocol', False) or - _is_callable_members_only(self)) and - issubclass(instance.__class__, self)): + if ( + not getattr(self, "_is_protocol", False) + or _is_callable_members_only(self) + ) and issubclass(instance.__class__, self): return True if self._is_protocol: - if all(hasattr(instance, attr) and - (not callable(getattr(self, attr, None)) or - getattr(instance, attr) is not None) - for attr in _get_protocol_attrs(self)): + if all( + hasattr(instance, attr) + and ( + not callable(getattr(self, attr, None)) + or getattr(instance, attr) is not None + ) + for attr in _get_protocol_attrs(self) + ): return True return super(GenericMeta, self).__instancecheck__(instance) def __subclasscheck__(self, cls): if self.__origin__ is not None: - if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: - raise TypeError("Parameterized generics cannot be used with class " - "or instance checks") + if sys._getframe(1).f_globals["__name__"] not in ["abc", "functools"]: + raise TypeError( + "Parameterized generics cannot be used with class " + "or instance checks" + ) return False - if (self.__dict__.get('_is_protocol', None) and - not self.__dict__.get('_is_runtime_protocol', None)): - if sys._getframe(1).f_globals['__name__'] in ['abc', - 'functools', - 'typing']: + if self.__dict__.get("_is_protocol", None) and not self.__dict__.get( + "_is_runtime_protocol", None + ): + if sys._getframe(1).f_globals["__name__"] in [ + "abc", + "functools", + "typing", + ]: return False - raise TypeError("Instance and class checks can only be used with" - " @runtime protocols") - if (self.__dict__.get('_is_runtime_protocol', None) and - not _is_callable_members_only(self)): - if sys._getframe(1).f_globals['__name__'] in ['abc', - 'functools', - 'typing']: + raise TypeError( + "Instance and class checks can only be used with" + " @runtime protocols" + ) + if self.__dict__.get( + "_is_runtime_protocol", None + ) and not _is_callable_members_only(self): + if sys._getframe(1).f_globals["__name__"] in [ + "abc", + "functools", + "typing", + ]: return super(GenericMeta, self).__subclasscheck__(cls) - raise TypeError("Protocols with non-method members" - " don't support issubclass()") + raise TypeError( + "Protocols with non-method members" " don't support issubclass()" + ) return super(GenericMeta, self).__subclasscheck__(cls) @typing._tp_cache @@ -933,16 +1084,19 @@ def __getitem__(self, params): params = (params,) if not params and _gorg(self) is not typing.Tuple: raise TypeError( - f"Parameter list to {self.__qualname__}[...] cannot be empty") + f"Parameter list to {self.__qualname__}[...] cannot be empty" + ) msg = "Parameters to generic types must be types." params = tuple(_type_check(p, msg) for p in params) if self in (typing.Generic, Protocol): if not all(isinstance(p, typing.TypeVar) for p in params): raise TypeError( - f"Parameters to {repr(self)}[...] must all be type variables") + f"Parameters to {repr(self)}[...] must all be type variables" + ) if len(set(params)) != len(params): raise TypeError( - f"Parameters to {repr(self)}[...] must all be unique") + f"Parameters to {repr(self)}[...] 
must all be unique" + ) tvars = params args = params elif self in (typing.Tuple, typing.Callable): @@ -956,14 +1110,16 @@ def __getitem__(self, params): args = params prepend = (self,) if self.__origin__ is None else () - return self.__class__(self.__name__, - prepend + self.__bases__, - _no_slots_copy(self.__dict__), - tvars=tvars, - args=args, - origin=self, - extra=self.__extra__, - orig_bases=self.__orig_bases__) + return self.__class__( + self.__name__, + prepend + self.__bases__, + _no_slots_copy(self.__dict__), + tvars=tvars, + args=args, + origin=self, + extra=self.__extra__, + orig_bases=self.__orig_bases__, + ) class Protocol(metaclass=_ProtocolMeta): """Base class for protocol classes. Protocol classes are defined as:: @@ -994,21 +1150,25 @@ class GenProto(Protocol[T]): def meth(self) -> T: ... """ + __slots__ = () _is_protocol = True def __new__(cls, *args, **kwds): if _gorg(cls) is Protocol: - raise TypeError("Type Protocol cannot be instantiated; " - "it can be used only as a base class") + raise TypeError( + "Type Protocol cannot be instantiated; " + "it can be used only as a base class" + ) return typing._generic_new(cls.__next_in_mro__, cls, *args, **kwds) # 3.8+ -if hasattr(typing, 'runtime_checkable'): +if hasattr(typing, "runtime_checkable"): runtime_checkable = typing.runtime_checkable # 3.6-3.7 else: + def runtime_checkable(cls): """Mark a protocol class as a runtime protocol, so that it can be used with isinstance() and issubclass(). Raise TypeError @@ -1018,8 +1178,10 @@ def runtime_checkable(cls): one-offs in collections.abc such as Hashable. """ if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: - raise TypeError('@runtime_checkable can be only applied to protocol classes,' - f' got {cls!r}') + raise TypeError( + "@runtime_checkable can be only applied to protocol classes," + f" got {cls!r}" + ) cls._is_runtime_protocol = True return cls @@ -1029,10 +1191,11 @@ def runtime_checkable(cls): # 3.8+ -if hasattr(typing, 'SupportsIndex'): +if hasattr(typing, "SupportsIndex"): SupportsIndex = typing.SupportsIndex # 3.6-3.7 else: + @runtime_checkable class SupportsIndex(Protocol): __slots__ = () @@ -1053,71 +1216,91 @@ def __index__(self) -> int: _TypedDictMeta = typing._TypedDictMeta is_typeddict = typing.is_typeddict else: + def _check_fails(cls, other): try: - if sys._getframe(1).f_globals['__name__'] not in ['abc', - 'functools', - 'typing']: + if sys._getframe(1).f_globals["__name__"] not in [ + "abc", + "functools", + "typing", + ]: # Typed dicts are only for static structural subtyping. 
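# Illustrative sketch (not part of the vendored patch): _check_fails above
# is why isinstance()/issubclass() against a TypedDict raise TypeError —
# TypedDict instances are plain dicts at runtime.
from typing_extensions import TypedDict

class Movie(TypedDict):
    title: str
    year: int

m: Movie = {"title": "Blade Runner", "year": 1982}
assert type(m) is dict
# isinstance(m, Movie) would raise TypeError here.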
- raise TypeError('TypedDict does not support instance and class checks') + raise TypeError("TypedDict does not support instance and class checks") except (AttributeError, ValueError): pass return False def _dict_new(*args, **kwargs): if not args: - raise TypeError('TypedDict.__new__(): not enough arguments') + raise TypeError("TypedDict.__new__(): not enough arguments") _, args = args[0], args[1:] # allow the "cls" keyword be passed return dict(*args, **kwargs) - _dict_new.__text_signature__ = '($cls, _typename, _fields=None, /, **kwargs)' + _dict_new.__text_signature__ = "($cls, _typename, _fields=None, /, **kwargs)" def _typeddict_new(*args, total=True, **kwargs): if not args: - raise TypeError('TypedDict.__new__(): not enough arguments') + raise TypeError("TypedDict.__new__(): not enough arguments") _, args = args[0], args[1:] # allow the "cls" keyword be passed if args: - typename, args = args[0], args[1:] # allow the "_typename" keyword be passed - elif '_typename' in kwargs: - typename = kwargs.pop('_typename') + typename, args = ( + args[0], + args[1:], + ) # allow the "_typename" keyword be passed + elif "_typename" in kwargs: + typename = kwargs.pop("_typename") import warnings - warnings.warn("Passing '_typename' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) + + warnings.warn( + "Passing '_typename' as keyword argument is deprecated", + DeprecationWarning, + stacklevel=2, + ) else: - raise TypeError("TypedDict.__new__() missing 1 required positional " - "argument: '_typename'") + raise TypeError( + "TypedDict.__new__() missing 1 required positional " + "argument: '_typename'" + ) if args: try: - fields, = args # allow the "_fields" keyword be passed + (fields,) = args # allow the "_fields" keyword be passed except ValueError: - raise TypeError('TypedDict.__new__() takes from 2 to 3 ' - f'positional arguments but {len(args) + 2} ' - 'were given') - elif '_fields' in kwargs and len(kwargs) == 1: - fields = kwargs.pop('_fields') + raise TypeError( + "TypedDict.__new__() takes from 2 to 3 " + f"positional arguments but {len(args) + 2} " + "were given" + ) + elif "_fields" in kwargs and len(kwargs) == 1: + fields = kwargs.pop("_fields") import warnings - warnings.warn("Passing '_fields' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) + + warnings.warn( + "Passing '_fields' as keyword argument is deprecated", + DeprecationWarning, + stacklevel=2, + ) else: fields = None if fields is None: fields = kwargs elif kwargs: - raise TypeError("TypedDict takes either a dict or keyword arguments," - " but not both") + raise TypeError( + "TypedDict takes either a dict or keyword arguments," " but not both" + ) - ns = {'__annotations__': dict(fields)} + ns = {"__annotations__": dict(fields)} try: # Setting correct module is necessary to make typed dict classes pickleable. - ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__') + ns["__module__"] = sys._getframe(1).f_globals.get("__name__", "__main__") except (AttributeError, ValueError): pass return _TypedDictMeta(typename, (), ns, total=total) - _typeddict_new.__text_signature__ = ('($cls, _typename, _fields=None,' - ' /, *, total=True, **kwargs)') + _typeddict_new.__text_signature__ = ( + "($cls, _typename, _fields=None," " /, *, total=True, **kwargs)" + ) class _TypedDictMeta(type): def __init__(cls, name, bases, ns, total=True): @@ -1130,11 +1313,11 @@ def __new__(cls, name, bases, ns, total=True): # TypedDict supports all three syntaxes described in its docstring. 
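# Illustrative sketch (not part of the vendored patch) of the three creation
# syntaxes the comment above refers to; the keyword form was deprecated in
# later upstream releases but is accepted by this vintage.
from typing_extensions import TypedDict

Point2D = TypedDict("Point2D", {"x": int, "y": int})  # functional, dict form
Point2DKw = TypedDict("Point2DKw", x=int, y=int)      # functional, kwargs form

class Point2DCls(TypedDict):                          # class form
    x: int
    y: int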
# Subclasses and instances of TypedDict return actual dictionaries # via _dict_new. - ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new + ns["__new__"] = _typeddict_new if name == "TypedDict" else _dict_new tp_dict = super().__new__(cls, name, (dict,), ns) annotations = {} - own_annotations = ns.get('__annotations__', {}) + own_annotations = ns.get("__annotations__", {}) msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" own_annotations = { n: typing._type_check(tp, msg) for n, tp in own_annotations.items() @@ -1143,9 +1326,9 @@ def __new__(cls, name, bases, ns, total=True): optional_keys = set() for base in bases: - annotations.update(base.__dict__.get('__annotations__', {})) - required_keys.update(base.__dict__.get('__required_keys__', ())) - optional_keys.update(base.__dict__.get('__optional_keys__', ())) + annotations.update(base.__dict__.get("__annotations__", {})) + required_keys.update(base.__dict__.get("__required_keys__", ())) + optional_keys.update(base.__dict__.get("__optional_keys__", ())) annotations.update(own_annotations) if PEP_560: @@ -1175,16 +1358,15 @@ def __new__(cls, name, bases, ns, total=True): tp_dict.__annotations__ = annotations tp_dict.__required_keys__ = frozenset(required_keys) tp_dict.__optional_keys__ = frozenset(optional_keys) - if not hasattr(tp_dict, '__total__'): + if not hasattr(tp_dict, "__total__"): tp_dict.__total__ = total return tp_dict __instancecheck__ = __subclasscheck__ = _check_fails - TypedDict = _TypedDictMeta('TypedDict', (dict,), {}) + TypedDict = _TypedDictMeta("TypedDict", (dict,), {}) TypedDict.__module__ = __name__ - TypedDict.__doc__ = \ - """A simple typed name space. At runtime it is equivalent to a plain dict. + TypedDict.__doc__ = """A simple typed name space. At runtime it is equivalent to a plain dict. TypedDict creates a dictionary type that expects all of its instances to have a certain set of keys, with each key @@ -1231,6 +1413,7 @@ class Film(TypedDict): """ return isinstance(tp, tuple(_TYPEDDICT_TYPES)) + if hasattr(typing, "Required"): get_type_hints = typing.get_type_hints elif PEP_560: @@ -1306,13 +1489,14 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): # Python 3.9+ has PEP 593 (Annotated) -if hasattr(typing, 'Annotated'): +if hasattr(typing, "Annotated"): Annotated = typing.Annotated # Not exported and not a public API, but needed for get_origin() and get_args() # to work. _AnnotatedAlias = typing._AnnotatedAlias # 3.7-3.8 elif PEP_560: + class _AnnotatedAlias(typing._GenericAlias, _root=True): """Runtime representation of an annotated type. @@ -1321,6 +1505,7 @@ class _AnnotatedAlias(typing._GenericAlias, _root=True): instantiating is the same as instantiating the underlying type, binding it to types is also the same. 
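# Illustrative sketch (not part of the vendored patch; MaxLen is a
# hypothetical marker): Annotated carries arbitrary metadata alongside a
# type, and get_type_hints() strips it unless include_extras=True. The exact
# runtime object differs slightly on the 3.6 code path further below.
from typing_extensions import Annotated, get_type_hints

class MaxLen:
    def __init__(self, n):
        self.n = n

def greet(name: Annotated[str, MaxLen(10)]) -> None: ...

hints = get_type_hints(greet, include_extras=True)
assert hints["name"].__metadata__[0].n == 10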
""" + def __init__(self, origin, metadata): if isinstance(origin, _AnnotatedAlias): metadata = origin.__metadata__ + metadata @@ -1334,13 +1519,13 @@ def copy_with(self, params): return _AnnotatedAlias(new_type, self.__metadata__) def __repr__(self): - return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " - f"{', '.join(repr(a) for a in self.__metadata__)}]") + return ( + f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " + f"{', '.join(repr(a) for a in self.__metadata__)}]" + ) def __reduce__(self): - return operator.getitem, ( - Annotated, (self.__origin__,) + self.__metadata__ - ) + return operator.getitem, (Annotated, (self.__origin__,) + self.__metadata__) def __eq__(self, other): if not isinstance(other, _AnnotatedAlias): @@ -1393,9 +1578,11 @@ def __new__(cls, *args, **kwargs): @typing._tp_cache def __class_getitem__(cls, params): if not isinstance(params, tuple) or len(params) < 2: - raise TypeError("Annotated[...] should be used " - "with at least two arguments (a type and an " - "annotation).") + raise TypeError( + "Annotated[...] should be used " + "with at least two arguments (a type and an " + "annotation)." + ) allowed_special_forms = (ClassVar, Final) if get_origin(params[0]) in allowed_special_forms: origin = params[0] @@ -1406,15 +1593,15 @@ def __class_getitem__(cls, params): return _AnnotatedAlias(origin, metadata) def __init_subclass__(cls, *args, **kwargs): - raise TypeError( - f"Cannot subclass {cls.__module__}.Annotated" - ) + raise TypeError(f"Cannot subclass {cls.__module__}.Annotated") + + # 3.6 else: def _is_dunder(name): """Returns True if name is a __dunder_variable_name__.""" - return len(name) > 4 and name.startswith('__') and name.endswith('__') + return len(name) > 4 and name.startswith("__") and name.endswith("__") # Prior to Python 3.7 types did not have `copy_with`. A lot of the equality # checks, argument expansion etc. are done on the _subs_tre. As a result we @@ -1439,7 +1626,7 @@ def _tree_repr(self, tree): else: tp_repr = origin[0]._tree_repr(origin) metadata_reprs = ", ".join(repr(arg) for arg in metadata) - return f'{cls}[{tp_repr}, {metadata_reprs}]' + return f"{cls}[{tp_repr}, {metadata_reprs}]" def _subs_tree(self, tvars=None, args=None): # noqa if self is Annotated: @@ -1455,8 +1642,10 @@ def _subs_tree(self, tvars=None, args=None): # noqa def _get_cons(self): """Return the class used to create instance of this type.""" if self.__origin__ is None: - raise TypeError("Cannot get the underlying type of a " - "non-specialized Annotated type.") + raise TypeError( + "Cannot get the underlying type of a " + "non-specialized Annotated type." + ) tree = self._subs_tree() while isinstance(tree, tuple) and tree[0] is Annotated: tree = tree[1] @@ -1472,13 +1661,15 @@ def __getitem__(self, params): if self.__origin__ is not None: # specializing an instantiated type return super().__getitem__(params) elif not isinstance(params, tuple) or len(params) < 2: - raise TypeError("Annotated[...] should be instantiated " - "with at least two arguments (a type and an " - "annotation).") + raise TypeError( + "Annotated[...] should be instantiated " + "with at least two arguments (a type and an " + "annotation)." 
+ ) else: if ( - isinstance(params[0], typing._TypingBase) and - type(params[0]).__name__ == "_ClassVar" + isinstance(params[0], typing._TypingBase) + and type(params[0]).__name__ == "_ClassVar" ): tp = params[0] else: @@ -1511,7 +1702,7 @@ def __getattr__(self, attr): raise AttributeError(attr) def __setattr__(self, attr, value): - if _is_dunder(attr) or attr.startswith('_abc_'): + if _is_dunder(attr) or attr.startswith("_abc_"): super().__setattr__(attr, value) elif self.__origin__ is None: raise AttributeError(attr) @@ -1556,6 +1747,7 @@ class Annotated(metaclass=AnnotatedMeta): OptimizedList[int] == Annotated[List[int], runtime.Optimize()] """ + # Python 3.8 has get_origin() and get_args() but those implementations aren't # Annotated-aware, so we can't use those. Python 3.9's versions don't support # ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do. @@ -1592,8 +1784,16 @@ def get_origin(tp): """ if isinstance(tp, _AnnotatedAlias): return Annotated - if isinstance(tp, (typing._GenericAlias, GenericAlias, _BaseGenericAlias, - ParamSpecArgs, ParamSpecKwargs)): + if isinstance( + tp, + ( + typing._GenericAlias, + GenericAlias, + _BaseGenericAlias, + ParamSpecArgs, + ParamSpecKwargs, + ), + ): return tp.__origin__ if tp is typing.Generic: return typing.Generic @@ -1623,13 +1823,14 @@ def get_args(tp): # 3.10+ -if hasattr(typing, 'TypeAlias'): +if hasattr(typing, "TypeAlias"): TypeAlias = typing.TypeAlias # 3.9 elif sys.version_info[:2] >= (3, 9): + class _TypeAliasForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name @_TypeAliasForm def TypeAlias(self, parameters): @@ -1644,14 +1845,18 @@ def TypeAlias(self, parameters): It's invalid when used anywhere except as in the example above. """ raise TypeError(f"{self} is not subscriptable") + + # 3.7-3.8 elif sys.version_info[:2] >= (3, 7): + class _TypeAliasForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name - TypeAlias = _TypeAliasForm('TypeAlias', - doc="""Special marker indicating that an assignment should + TypeAlias = _TypeAliasForm( + "TypeAlias", + doc="""Special marker indicating that an assignment should be recognized as a proper type alias definition by type checkers. @@ -1660,14 +1865,16 @@ def __repr__(self): Predicate: TypeAlias = Callable[..., bool] It's invalid when used anywhere except as in the example - above.""") + above.""", + ) # 3.6 else: + class _TypeAliasMeta(typing.TypingMeta): """Metaclass for TypeAlias""" def __repr__(self): - return 'typing_extensions.TypeAlias' + return "typing_extensions.TypeAlias" class _TypeAliasBase(typing._FinalTypingBase, metaclass=_TypeAliasMeta, _root=True): """Special marker indicating that an assignment should @@ -1680,6 +1887,7 @@ class _TypeAliasBase(typing._FinalTypingBase, metaclass=_TypeAliasMeta, _root=Tr It's invalid when used anywhere except as in the example above. 
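# Illustrative sketch (not part of the vendored patch) matching the docstring
# above: TypeAlias only marks the assignment for static checkers; at runtime
# the alias is simply its right-hand side.
from typing import Callable
from typing_extensions import TypeAlias

Predicate: TypeAlias = Callable[..., bool]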
""" + __slots__ = () def __instancecheck__(self, obj): @@ -1689,19 +1897,21 @@ def __subclasscheck__(self, cls): raise TypeError("TypeAlias cannot be used with issubclass().") def __repr__(self): - return 'typing_extensions.TypeAlias' + return "typing_extensions.TypeAlias" TypeAlias = _TypeAliasBase(_root=True) # Python 3.10+ has PEP 612 -if hasattr(typing, 'ParamSpecArgs'): +if hasattr(typing, "ParamSpecArgs"): ParamSpecArgs = typing.ParamSpecArgs ParamSpecKwargs = typing.ParamSpecKwargs # 3.6-3.9 else: + class _Immutable: """Mixin to indicate that object should not be copied.""" + __slots__ = () def __copy__(self): @@ -1722,6 +1932,7 @@ class ParamSpecArgs(_Immutable): This type is meant for runtime introspection and has no special meaning to static type checkers. """ + def __init__(self, origin): self.__origin__ = origin @@ -1745,6 +1956,7 @@ class ParamSpecKwargs(_Immutable): This type is meant for runtime introspection and has no special meaning to static type checkers. """ + def __init__(self, origin): self.__origin__ = origin @@ -1756,8 +1968,9 @@ def __eq__(self, other): return NotImplemented return self.__origin__ == other.__origin__ + # 3.10+ -if hasattr(typing, 'ParamSpec'): +if hasattr(typing, "ParamSpec"): ParamSpec = typing.ParamSpec # 3.6-3.9 else: @@ -1827,25 +2040,25 @@ def __init__(self, name, *, bound=None, covariant=False, contravariant=False): self.__covariant__ = bool(covariant) self.__contravariant__ = bool(contravariant) if bound: - self.__bound__ = typing._type_check(bound, 'Bound must be a type.') + self.__bound__ = typing._type_check(bound, "Bound must be a type.") else: self.__bound__ = None # for pickling: try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + def_mod = sys._getframe(1).f_globals.get("__name__", "__main__") except (AttributeError, ValueError): def_mod = None - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod def __repr__(self): if self.__covariant__: - prefix = '+' + prefix = "+" elif self.__contravariant__: - prefix = '-' + prefix = "-" else: - prefix = '~' + prefix = "~" return prefix + self.__name__ def __hash__(self): @@ -1869,7 +2082,7 @@ def _get_type_vars(self, tvars): # 3.6-3.9 -if not hasattr(typing, 'Concatenate'): +if not hasattr(typing, "Concatenate"): # Inherits from list as a workaround for Callable checks in Python < 3.9.2. class _ConcatenateGenericAlias(list): @@ -1891,8 +2104,10 @@ def __init__(self, origin, args): def __repr__(self): _type_repr = typing._type_repr - return (f'{_type_repr(self.__origin__)}' - f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]') + return ( + f"{_type_repr(self.__origin__)}" + f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]' + ) def __hash__(self): return hash((self.__origin__, self.__args__)) @@ -1904,7 +2119,9 @@ def __call__(self, *args, **kwargs): @property def __parameters__(self): return tuple( - tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) + tp + for tp in self.__args__ + if isinstance(tp, (typing.TypeVar, ParamSpec)) ) if not PEP_560: @@ -1922,19 +2139,21 @@ def _concatenate_getitem(self, parameters): if not isinstance(parameters, tuple): parameters = (parameters,) if not isinstance(parameters[-1], ParamSpec): - raise TypeError("The last parameter to Concatenate should be a " - "ParamSpec variable.") + raise TypeError( + "The last parameter to Concatenate should be a " "ParamSpec variable." + ) msg = "Concatenate[arg, ...]: each arg must be a type." 
parameters = tuple(typing._type_check(p, msg) for p in parameters) return _ConcatenateGenericAlias(self, parameters) # 3.10+ -if hasattr(typing, 'Concatenate'): +if hasattr(typing, "Concatenate"): Concatenate = typing.Concatenate - _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa + _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa # 3.9 elif sys.version_info[:2] >= (3, 9): + @_TypeAliasForm def Concatenate(self, parameters): """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a @@ -1948,17 +2167,20 @@ def Concatenate(self, parameters): See PEP 612 for detailed information. """ return _concatenate_getitem(self, parameters) + + # 3.7-8 elif sys.version_info[:2] >= (3, 7): + class _ConcatenateForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name def __getitem__(self, parameters): return _concatenate_getitem(self, parameters) Concatenate = _ConcatenateForm( - 'Concatenate', + "Concatenate", doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a higher order function which adds, removes or transforms parameters of a callable. @@ -1968,18 +2190,20 @@ def __getitem__(self, parameters): Callable[Concatenate[int, P], int] See PEP 612 for detailed information. - """) + """, + ) # 3.6 else: + class _ConcatenateAliasMeta(typing.TypingMeta): """Metaclass for Concatenate.""" def __repr__(self): - return 'typing_extensions.Concatenate' + return "typing_extensions.Concatenate" - class _ConcatenateAliasBase(typing._FinalTypingBase, - metaclass=_ConcatenateAliasMeta, - _root=True): + class _ConcatenateAliasBase( + typing._FinalTypingBase, metaclass=_ConcatenateAliasMeta, _root=True + ): """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a higher order function which adds, removes or transforms parameters of a callable. @@ -1990,6 +2214,7 @@ class _ConcatenateAliasBase(typing._FinalTypingBase, See PEP 612 for detailed information. """ + __slots__ = () def __instancecheck__(self, obj): @@ -1999,7 +2224,7 @@ def __subclasscheck__(self, cls): raise TypeError("Concatenate cannot be used with issubclass().") def __repr__(self): - return 'typing_extensions.Concatenate' + return "typing_extensions.Concatenate" def __getitem__(self, parameters): return _concatenate_getitem(self, parameters) @@ -2007,13 +2232,14 @@ def __getitem__(self, parameters): Concatenate = _ConcatenateAliasBase(_root=True) # 3.10+ -if hasattr(typing, 'TypeGuard'): +if hasattr(typing, "TypeGuard"): TypeGuard = typing.TypeGuard # 3.9 elif sys.version_info[:2] >= (3, 9): + class _TypeGuardForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name @_TypeGuardForm def TypeGuard(self, parameters): @@ -2059,22 +2285,26 @@ def is_str(val: Union[str, float]): ``TypeGuard`` also works with type variables. For more information, see PEP 647 (User-Defined Type Guards). """ - item = typing._type_check(parameters, f'{self} accepts only single type.') + item = typing._type_check(parameters, f"{self} accepts only single type.") return typing._GenericAlias(self, (item,)) + + # 3.7-3.8 elif sys.version_info[:2] >= (3, 7): + class _TypeGuardForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." 
+ self._name def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type" + ) return typing._GenericAlias(self, (item,)) TypeGuard = _TypeGuardForm( - 'TypeGuard', + "TypeGuard", doc="""Special typing form used to annotate the return type of a user-defined type guard function. ``TypeGuard`` only accepts a single type argument. At runtime, functions marked this way should return a boolean. @@ -2116,9 +2346,11 @@ def is_str(val: Union[str, float]): ``TypeGuard`` also works with type variables. For more information, see PEP 647 (User-Defined Type Guards). - """) + """, + ) # 3.6 else: + class _TypeGuard(typing._FinalTypingBase, _root=True): """Special typing form used to annotate the return type of a user-defined type guard function. ``TypeGuard`` only accepts a single type argument. @@ -2163,7 +2395,7 @@ def is_str(val: Union[str, float]): PEP 647 (User-Defined Type Guards). """ - __slots__ = ('__type__',) + __slots__ = ("__type__",) def __init__(self, tp=None, **kwds): self.__type__ = tp @@ -2171,10 +2403,13 @@ def __init__(self, tp=None, **kwds): def __getitem__(self, item): cls = type(self) if self.__type__ is None: - return cls(typing._type_check(item, - f'{cls.__name__[1:]} accepts only a single type.'), - _root=True) - raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted') + return cls( + typing._type_check( + item, f"{cls.__name__[1:]} accepts only a single type." + ), + _root=True, + ) + raise TypeError(f"{cls.__name__[1:]} cannot be further subscripted") def _eval_type(self, globalns, localns): new_tp = typing._eval_type(self.__type__, globalns, localns) @@ -2185,7 +2420,7 @@ def _eval_type(self, globalns, localns): def __repr__(self): r = super().__repr__() if self.__type__ is not None: - r += f'[{typing._type_repr(self.__type__)}]' + r += f"[{typing._type_repr(self.__type__)}]" return r def __hash__(self): @@ -2204,7 +2439,7 @@ def __eq__(self, other): if sys.version_info[:2] >= (3, 7): # Vendored from cpython typing._SpecialFrom class _SpecialForm(typing._Final, _root=True): - __slots__ = ('_name', '__doc__', '_getitem') + __slots__ = ("_name", "__doc__", "_getitem") def __init__(self, getitem): self._getitem = getitem @@ -2212,7 +2447,7 @@ def __init__(self, getitem): self.__doc__ = getitem.__doc__ def __getattr__(self, item): - if item in {'__name__', '__qualname__'}: + if item in {"__name__", "__qualname__"}: return self._name raise AttributeError(item) @@ -2221,7 +2456,7 @@ def __mro_entries__(self, bases): raise TypeError(f"Cannot subclass {self!r}") def __repr__(self): - return f'typing_extensions.{self._name}' + return f"typing_extensions.{self._name}" def __reduce__(self): return self._name @@ -2249,6 +2484,7 @@ def __getitem__(self, parameters): if hasattr(typing, "LiteralString"): LiteralString = typing.LiteralString elif sys.version_info[:2] >= (3, 7): + @_SpecialForm def LiteralString(self, params): """Represents an arbitrary literal string. @@ -2267,7 +2503,9 @@ def query(sql: LiteralString) -> ...: """ raise TypeError(f"{self} is not subscriptable") + else: + class _LiteralString(typing._FinalTypingBase, _root=True): """Represents an arbitrary literal string. @@ -2299,6 +2537,7 @@ def __subclasscheck__(self, cls): if hasattr(typing, "Self"): Self = typing.Self elif sys.version_info[:2] >= (3, 7): + @_SpecialForm def Self(self, params): """Used to spell the type of "self" in classes. 
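# Illustrative sketch (not part of the vendored patch) continuing the
# docstring above: Self lets a base class declare that a method returns
# whatever subclass it is called on.
from typing_extensions import Self

class Parser:
    def parse(self, data: bytes) -> Self:
        # ... populate self from `data` ...
        return self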
@@ -2315,7 +2554,9 @@ def parse(self, data: bytes) -> Self: """ raise TypeError(f"{self} is not subscriptable") + else: + class _Self(typing._FinalTypingBase, _root=True): """Used to spell the type of "self" in classes. @@ -2344,6 +2585,7 @@ def __subclasscheck__(self, cls): if hasattr(typing, "Never"): Never = typing.Never elif sys.version_info[:2] >= (3, 7): + @_SpecialForm def Never(self, params): """The bottom type, a type that has no members. @@ -2369,7 +2611,9 @@ def int_or_str(arg: int | str) -> None: """ raise TypeError(f"{self} is not subscriptable") + else: + class _Never(typing._FinalTypingBase, _root=True): """The bottom type, a type that has no members. @@ -2404,13 +2648,14 @@ def __subclasscheck__(self, cls): Never = _Never(_root=True) -if hasattr(typing, 'Required'): +if hasattr(typing, "Required"): Required = typing.Required NotRequired = typing.NotRequired elif sys.version_info[:2] >= (3, 9): + class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name @_ExtensionsSpecialForm def Required(self, parameters): @@ -2429,7 +2674,7 @@ class Movie(TypedDict, total=False): There is no runtime checking that a required key is actually provided when instantiating a related TypedDict. """ - item = typing._type_check(parameters, f'{self._name} accepts only single type') + item = typing._type_check(parameters, f"{self._name} accepts only single type") return typing._GenericAlias(self, (item,)) @_ExtensionsSpecialForm @@ -2446,21 +2691,23 @@ class Movie(TypedDict): year=1999, ) """ - item = typing._type_check(parameters, f'{self._name} accepts only single type') + item = typing._type_check(parameters, f"{self._name} accepts only single type") return typing._GenericAlias(self, (item,)) elif sys.version_info[:2] >= (3, 7): + class _RequiredForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name def __getitem__(self, parameters): - item = typing._type_check(parameters, - '{} accepts only single type'.format(self._name)) + item = typing._type_check( + parameters, "{} accepts only single type".format(self._name) + ) return typing._GenericAlias(self, (item,)) Required = _RequiredForm( - 'Required', + "Required", doc="""A special typing construct to mark a key of a total=False TypedDict as required. For example: @@ -2475,9 +2722,10 @@ class Movie(TypedDict, total=False): There is no runtime checking that a required key is actually provided when instantiating a related TypedDict. - """) + """, + ) NotRequired = _RequiredForm( - 'NotRequired', + "NotRequired", doc="""A special typing construct to mark a key of a TypedDict as potentially missing. 
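# Illustrative sketch (not part of the vendored patch) of the construct
# documented here: NotRequired marks a single key of an otherwise total
# TypedDict as omittable.
from typing_extensions import TypedDict, NotRequired

class Film(TypedDict):
    title: str
    year: NotRequired[int]

f: Film = {"title": "The Matrix"}  # "year" may be left out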
For example: @@ -2489,11 +2737,12 @@ class Movie(TypedDict): title='The Matrix', # typechecker error if key is omitted year=1999, ) - """) + """, + ) else: # NOTE: Modeled after _Final's implementation when _FinalTypingBase available class _MaybeRequired(typing._FinalTypingBase, _root=True): - __slots__ = ('__type__',) + __slots__ = ("__type__",) def __init__(self, tp=None, **kwds): self.__type__ = tp @@ -2501,11 +2750,13 @@ def __init__(self, tp=None, **kwds): def __getitem__(self, item): cls = type(self) if self.__type__ is None: - return cls(typing._type_check(item, - '{} accepts only single type.'.format(cls.__name__[1:])), - _root=True) - raise TypeError('{} cannot be further subscripted' - .format(cls.__name__[1:])) + return cls( + typing._type_check( + item, "{} accepts only single type.".format(cls.__name__[1:]) + ), + _root=True, + ) + raise TypeError("{} cannot be further subscripted".format(cls.__name__[1:])) def _eval_type(self, globalns, localns): new_tp = typing._eval_type(self.__type__, globalns, localns) @@ -2516,7 +2767,7 @@ def _eval_type(self, globalns, localns): def __repr__(self): r = super().__repr__() if self.__type__ is not None: - r += '[{}]'.format(typing._type_repr(self.__type__)) + r += "[{}]".format(typing._type_repr(self.__type__)) return r def __hash__(self): @@ -2565,9 +2816,10 @@ class Movie(TypedDict): if sys.version_info[:2] >= (3, 9): + class _UnpackSpecialForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name class _UnpackAlias(typing._GenericAlias, _root=True): __class__ = typing.TypeVar @@ -2576,35 +2828,37 @@ class _UnpackAlias(typing._GenericAlias, _root=True): def Unpack(self, parameters): """A special typing construct to unpack a variadic type. For example: - Shape = TypeVarTuple('Shape') - Batch = NewType('Batch', int) + Shape = TypeVarTuple('Shape') + Batch = NewType('Batch', int) - def add_batch_axis( - x: Array[Unpack[Shape]] - ) -> Array[Batch, Unpack[Shape]]: ... + def add_batch_axis( + x: Array[Unpack[Shape]] + ) -> Array[Batch, Unpack[Shape]]: ... """ - item = typing._type_check(parameters, f'{self._name} accepts only single type') + item = typing._type_check(parameters, f"{self._name} accepts only single type") return _UnpackAlias(self, (item,)) def _is_unpack(obj): return isinstance(obj, _UnpackAlias) elif sys.version_info[:2] >= (3, 7): + class _UnpackAlias(typing._GenericAlias, _root=True): __class__ = typing.TypeVar class _UnpackForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only single type') + item = typing._type_check( + parameters, f"{self._name} accepts only single type" + ) return _UnpackAlias(self, (item,)) Unpack = _UnpackForm( - 'Unpack', + "Unpack", doc="""A special typing construct to unpack a variadic type. For example: Shape = TypeVarTuple('Shape') @@ -2614,7 +2868,8 @@ def add_batch_axis( x: Array[Unpack[Shape]] ) -> Array[Batch, Unpack[Shape]]: ... - """) + """, + ) def _is_unpack(obj): return isinstance(obj, _UnpackAlias) @@ -2624,15 +2879,16 @@ def _is_unpack(obj): class _Unpack(typing._FinalTypingBase, _root=True): """A special typing construct to unpack a variadic type. 
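# Illustrative sketch (not part of the vendored patch): Unpack splices a
# TypeVarTuple into a variadic signature; the annotation below evaluates at
# runtime through the form defined above.
from typing_extensions import TypeVarTuple, Unpack

Ts = TypeVarTuple("Ts")

def first(*args: Unpack[Ts]):
    # Return the first positional argument, if any.
    return args[0] if args else None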
For example: - Shape = TypeVarTuple('Shape') - Batch = NewType('Batch', int) + Shape = TypeVarTuple('Shape') + Batch = NewType('Batch', int) - def add_batch_axis( - x: Array[Unpack[Shape]] - ) -> Array[Batch, Unpack[Shape]]: ... + def add_batch_axis( + x: Array[Unpack[Shape]] + ) -> Array[Batch, Unpack[Shape]]: ... """ - __slots__ = ('__type__',) + + __slots__ = ("__type__",) __class__ = typing.TypeVar def __init__(self, tp=None, **kwds): @@ -2641,10 +2897,11 @@ def __init__(self, tp=None, **kwds): def __getitem__(self, item): cls = type(self) if self.__type__ is None: - return cls(typing._type_check(item, - 'Unpack accepts only single type.'), - _root=True) - raise TypeError('Unpack cannot be further subscripted') + return cls( + typing._type_check(item, "Unpack accepts only single type."), + _root=True, + ) + raise TypeError("Unpack cannot be further subscripted") def _eval_type(self, globalns, localns): new_tp = typing._eval_type(self.__type__, globalns, localns) @@ -2655,7 +2912,7 @@ def _eval_type(self, globalns, localns): def __repr__(self): r = super().__repr__() if self.__type__ is not None: - r += '[{}]'.format(typing._type_repr(self.__type__)) + r += "[{}]".format(typing._type_repr(self.__type__)) return r def __hash__(self): @@ -2733,10 +2990,10 @@ def __init__(self, name): # for pickling: try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + def_mod = sys._getframe(1).f_globals.get("__name__", "__main__") except (AttributeError, ValueError): def_mod = None - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod self.__unpacked__ = Unpack[self] @@ -2754,7 +3011,7 @@ def __reduce__(self): return self.__name__ def __init_subclass__(self, *args, **kwds): - if '_root' not in kwds: + if "_root" not in kwds: raise TypeError("Cannot subclass special typing classes") if not PEP_560: @@ -2767,6 +3024,7 @@ def _get_type_vars(self, tvars): if hasattr(typing, "reveal_type"): reveal_type = typing.reveal_type else: + def reveal_type(__obj: T) -> T: """Reveal the inferred type of a variable. @@ -2790,6 +3048,7 @@ def reveal_type(__obj: T) -> T: if hasattr(typing, "assert_never"): assert_never = typing.assert_never else: + def assert_never(__arg: Never) -> Never: """Assert to the type checker that a line of code is unreachable. @@ -2813,17 +3072,17 @@ def int_or_str(arg: int | str) -> None: raise AssertionError("Expected code to be unreachable") -if hasattr(typing, 'dataclass_transform'): +if hasattr(typing, "dataclass_transform"): dataclass_transform = typing.dataclass_transform else: + def dataclass_transform( *, eq_default: bool = True, order_default: bool = False, kw_only_default: bool = False, field_descriptors: typing.Tuple[ - typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], - ... + typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], ... ] = (), ) -> typing.Callable[[T], T]: """Decorator that marks a function, class, or metaclass as providing @@ -2885,6 +3144,7 @@ class CustomerModel(ModelBase): See PEP 681 for details. 
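# Illustrative sketch (not part of the vendored patch; create_model is a
# hypothetical decorator): the backport below only records its arguments on
# the decorated object for static checkers and returns it unchanged.
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
def create_model(cls):
    return cls

@create_model
class CustomerModel:
    id: int
    name: str

assert create_model.__dataclass_transform__["kw_only_default"] is True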
""" + def decorator(cls_or_fn): cls_or_fn.__dataclass_transform__ = { "eq_default": eq_default, @@ -2893,6 +3153,7 @@ def decorator(cls_or_fn): "field_descriptors": field_descriptors, } return cls_or_fn + return decorator diff --git a/metaflow/_vendor/v3_6/zipp.py b/metaflow/_vendor/v3_6/zipp.py index 26b723c1fd3..72632b0b773 100644 --- a/metaflow/_vendor/v3_6/zipp.py +++ b/metaflow/_vendor/v3_6/zipp.py @@ -12,7 +12,7 @@ OrderedDict = dict -__all__ = ['Path'] +__all__ = ["Path"] def _parents(path): @@ -93,7 +93,7 @@ def resolve_dir(self, name): as a directory (with the trailing slash). """ names = self._name_set() - dirname = name + '/' + dirname = name + "/" dir_match = name not in names and dirname in names return dirname if dir_match else name @@ -110,7 +110,7 @@ def make(cls, source): return cls(_pathlib_compat(source)) # Only allow for FastLookup when supplied zipfile is read-only - if 'r' not in source.mode: + if "r" not in source.mode: cls = CompleteDirs source.__class__ = cls @@ -240,7 +240,7 @@ def __init__(self, root, at=""): self.root = FastLookup.make(root) self.at = at - def open(self, mode='r', *args, pwd=None, **kwargs): + def open(self, mode="r", *args, pwd=None, **kwargs): """ Open this entry as text or binary following the semantics of ``pathlib.Path.open()`` by passing arguments through @@ -249,10 +249,10 @@ def open(self, mode='r', *args, pwd=None, **kwargs): if self.is_dir(): raise IsADirectoryError(self) zip_mode = mode[0] - if not self.exists() and zip_mode == 'r': + if not self.exists() and zip_mode == "r": raise FileNotFoundError(self) stream = self.root.open(self.at, zip_mode, pwd=pwd) - if 'b' in mode: + if "b" in mode: if args or kwargs: raise ValueError("encoding args invalid for binary operation") return stream @@ -279,11 +279,11 @@ def filename(self): return pathlib.Path(self.root.filename).joinpath(self.at) def read_text(self, *args, **kwargs): - with self.open('r', *args, **kwargs) as strm: + with self.open("r", *args, **kwargs) as strm: return strm.read() def read_bytes(self): - with self.open('rb') as strm: + with self.open("rb") as strm: return strm.read() def _is_child(self, path): @@ -323,7 +323,7 @@ def joinpath(self, *other): def parent(self): if not self.at: return self.filename.parent - parent_at = posixpath.dirname(self.at.rstrip('/')) + parent_at = posixpath.dirname(self.at.rstrip("/")) if parent_at: - parent_at += '/' + parent_at += "/" return self._next(parent_at) diff --git a/metaflow/_vendor/v3_7/__init__.py b/metaflow/_vendor/v3_7/__init__.py index 22ae0c5f40e..932b79829cf 100644 --- a/metaflow/_vendor/v3_7/__init__.py +++ b/metaflow/_vendor/v3_7/__init__.py @@ -1 +1 @@ -# Empty file \ No newline at end of file +# Empty file diff --git a/metaflow/_vendor/v3_7/importlib_metadata/__init__.py b/metaflow/_vendor/v3_7/importlib_metadata/__init__.py index 443f4763c00..2a29aa5a039 100644 --- a/metaflow/_vendor/v3_7/importlib_metadata/__init__.py +++ b/metaflow/_vendor/v3_7/importlib_metadata/__init__.py @@ -33,18 +33,18 @@ __all__ = [ - 'Distribution', - 'DistributionFinder', - 'PackageMetadata', - 'PackageNotFoundError', - 'distribution', - 'distributions', - 'entry_points', - 'files', - 'metadata', - 'packages_distributions', - 'requires', - 'version', + "Distribution", + "DistributionFinder", + "PackageMetadata", + "PackageNotFoundError", + "distribution", + "distributions", + "entry_points", + "files", + "metadata", + "packages_distributions", + "requires", + "version", ] @@ -114,15 +114,15 @@ def read(text, filter_=None): lines = 
filter(filter_, map(str.strip, text.splitlines())) name = None for value in lines: - section_match = value.startswith('[') and value.endswith(']') + section_match = value.startswith("[") and value.endswith("]") if section_match: - name = value.strip('[]') + name = value.strip("[]") continue yield Pair(name, value) @staticmethod def valid(line): - return line and not line.startswith('#') + return line and not line.startswith("#") class DeprecatedTuple: @@ -160,9 +160,9 @@ class EntryPoint(DeprecatedTuple): """ pattern = re.compile( - r'(?P<module>[\w.]+)\s*' - r'(:\s*(?P<attr>[\w.]+))?\s*' - r'(?P<extras>\[.*\])?\s*$' + r"(?P<module>[\w.]+)\s*" + r"(:\s*(?P<attr>[\w.]+))?\s*" + r"(?P<extras>\[.*\])?\s*$" ) """ A regular expression describing the syntax for an entry point, @@ -180,7 +180,7 @@ class EntryPoint(DeprecatedTuple): following the attr, and following any extras. """ - dist: Optional['Distribution'] = None + dist: Optional["Distribution"] = None def __init__(self, name, value, group): vars(self).update(name=name, value=value, group=group) @@ -191,24 +191,24 @@ def load(self): return the named object. """ match = self.pattern.match(self.value) - module = import_module(match.group('module')) - attrs = filter(None, (match.group('attr') or '').split('.')) + module = import_module(match.group("module")) + attrs = filter(None, (match.group("attr") or "").split(".")) return functools.reduce(getattr, attrs, module) @property def module(self): match = self.pattern.match(self.value) - return match.group('module') + return match.group("module") @property def attr(self): match = self.pattern.match(self.value) - return match.group('attr') + return match.group("attr") @property def extras(self): match = self.pattern.match(self.value) - return list(re.finditer(r'\w+', match.group('extras') or '')) + return list(re.finditer(r"\w+", match.group("extras") or "")) def _for(self, dist): vars(self).update(dist=dist) @@ -243,8 +243,8 @@ def __setattr__(self, name, value): def __repr__(self): return ( - f'EntryPoint(name={self.name!r}, value={self.value!r}, ' - f'group={self.group!r})' + f"EntryPoint(name={self.name!r}, value={self.value!r}, " + f"group={self.group!r})" ) def __hash__(self): @@ -298,16 +298,16 @@ def wrapped(self, *args, **kwargs): return wrapped for method_name in [ - '__setitem__', - '__delitem__', - 'append', - 'reverse', - 'extend', - 'pop', - 'remove', - '__iadd__', - 'insert', - 'sort', + "__setitem__", + "__delitem__", + "append", + "reverse", + "extend", + "pop", + "remove", + "__iadd__", + "insert", + "sort", ]: locals()[method_name] = _wrap_deprecated_method(method_name) @@ -382,7 +382,7 @@ def _from_text_for(cls, text, dist): def _from_text(text): return ( EntryPoint(name=item.value.name, value=item.value.value, group=item.name) - for item in Sectioned.section_pairs(text or '') + for item in Sectioned.section_pairs(text or "") ) @@ -449,7 +449,7 @@ class SelectableGroups(Deprecated, dict): @classmethod def load(cls, eps): - by_group = operator.attrgetter('group') + by_group = operator.attrgetter("group") ordered = sorted(eps, key=by_group) grouped = itertools.groupby(ordered, by_group) return cls((group, EntryPoints(eps)) for group, eps in grouped) @@ -484,12 +484,12 @@ def select(self, **params): class PackagePath(pathlib.PurePosixPath): """A reference to a path in a package""" - def read_text(self, encoding='utf-8'): + def read_text(self, encoding="utf-8"): with self.locate().open(encoding=encoding) as stream: return stream.read() def read_binary(self): - with self.locate().open('rb') as stream: + with
self.locate().open("rb") as stream: return stream.read() def locate(self): @@ -499,10 +499,10 @@ class FileHash: def __init__(self, spec): - self.mode, _, self.value = spec.partition('=') + self.mode, _, self.value = spec.partition("=") def __repr__(self): - return f'<FileHash mode: {self.mode} value: {self.value}>' + return f"<FileHash mode: {self.mode} value: {self.value}>" class Distribution: @@ -551,7 +551,7 @@ def discover(cls, **kwargs): :context: A ``DistributionFinder.Context`` object. :return: Iterable of Distribution objects for all packages. """ - context = kwargs.pop('context', None) + context = kwargs.pop("context", None) if context and kwargs: raise ValueError("cannot accept context and kwargs") context = context or DistributionFinder.Context(**kwargs) @@ -572,12 +572,12 @@ def at(path): def _discover_resolvers(): """Search the meta_path for resolvers.""" declared = ( - getattr(finder, 'find_distributions', None) for finder in sys.meta_path + getattr(finder, "find_distributions", None) for finder in sys.meta_path ) return filter(None, declared) @classmethod - def _local(cls, root='.'): + def _local(cls, root="."): from pep517 import build, meta system = build.compat_system(root) @@ -596,19 +596,19 @@ def metadata(self) -> _meta.PackageMetadata: metadata. See PEP 566 for details. """ text = ( - self.read_text('METADATA') - or self.read_text('PKG-INFO') + self.read_text("METADATA") + or self.read_text("PKG-INFO") # This last clause is here to support old egg-info files. Its # effect is to just end up using the PathDistribution's self._path # (which points to the egg-info file) attribute unchanged. - or self.read_text('') + or self.read_text("") ) return _adapters.Message(email.message_from_string(text)) @property def name(self): """Return the 'Name' metadata for the distribution package.""" - return self.metadata['Name'] + return self.metadata["Name"] @property def _normalized_name(self): @@ -618,11 +618,11 @@ def _normalized_name(self): @property def version(self): """Return the 'Version' metadata for the distribution package.""" - return self.metadata['Version'] + return self.metadata["Version"] @property def entry_points(self): - return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self) + return EntryPoints._from_text_for(self.read_text("entry_points.txt"), self) @property def files(self): @@ -653,7 +653,7 @@ def _read_files_distinfo(self): """ Read the lines of RECORD """ - text = self.read_text('RECORD') + text = self.read_text("RECORD") return text and text.splitlines() def _read_files_egginfo(self): @@ -661,7 +661,7 @@ SOURCES.txt might contain literal commas, so wrap each line in quotes.
""" - text = self.read_text('SOURCES.txt') + text = self.read_text("SOURCES.txt") return text and map('"{}"'.format, text.splitlines()) @property @@ -671,10 +671,10 @@ def requires(self): return reqs and list(reqs) def _read_dist_info_reqs(self): - return self.metadata.get_all('Requires-Dist') + return self.metadata.get_all("Requires-Dist") def _read_egg_info_reqs(self): - source = self.read_text('requires.txt') + source = self.read_text("requires.txt") return source and self._deps_from_requires_text(source) @classmethod @@ -697,12 +697,12 @@ def make_condition(name): return name and f'extra == "{name}"' def quoted_marker(section): - section = section or '' - extra, sep, markers = section.partition(':') + section = section or "" + extra, sep, markers = section.partition(":") if extra and markers: - markers = f'({markers})' + markers = f"({markers})" conditions = list(filter(None, [markers, make_condition(extra)])) - return '; ' + ' and '.join(conditions) if conditions else '' + return "; " + " and ".join(conditions) if conditions else "" def url_req_space(req): """ @@ -710,7 +710,7 @@ def url_req_space(req): Ref python/importlib_metadata#357. """ # '@' is uniquely indicative of a url_req. - return ' ' * ('@' in req) + return " " * ("@" in req) for section in sections: space = url_req_space(section.value) @@ -752,7 +752,7 @@ def path(self): Typically refers to Python installed package paths such as "site-packages" directories and defaults to ``sys.path``. """ - return vars(self).get('path', sys.path) + return vars(self).get("path", sys.path) @abc.abstractmethod def find_distributions(self, context=Context()): @@ -786,7 +786,7 @@ def joinpath(self, child): def children(self): with suppress(Exception): - return os.listdir(self.root or '.') + return os.listdir(self.root or ".") with suppress(Exception): return self.zip_children() return [] @@ -868,7 +868,7 @@ def normalize(name): """ PEP 503 normalization plus dashes as underscores. """ - return re.sub(r"[-_.]+", "-", name).lower().replace('-', '_') + return re.sub(r"[-_.]+", "-", name).lower().replace("-", "_") @staticmethod def legacy_normalize(name): @@ -876,7 +876,7 @@ def legacy_normalize(name): Normalize the package name as found in the convention in older packaging tools versions and specs. """ - return name.lower().replace('-', '_') + return name.lower().replace("-", "_") def __bool__(self): return bool(self.name) @@ -930,7 +930,7 @@ def read_text(self, filename): NotADirectoryError, PermissionError, ): - return self._path.joinpath(filename).read_text(encoding='utf-8') + return self._path.joinpath(filename).read_text(encoding="utf-8") read_text.__doc__ = Distribution.read_text.__doc__ @@ -948,9 +948,9 @@ def _normalized_name(self): def _name_from_stem(self, stem): name, ext = os.path.splitext(stem) - if ext not in ('.dist-info', '.egg-info'): + if ext not in (".dist-info", ".egg-info"): return - name, sep, rest = stem.partition('-') + name, sep, rest = stem.partition("-") return name @@ -1007,7 +1007,7 @@ def entry_points(**params) -> Union[EntryPoints, SelectableGroups]: :return: EntryPoints or SelectableGroups for all installed packages. 
""" - norm_name = operator.attrgetter('_normalized_name') + norm_name = operator.attrgetter("_normalized_name") unique = functools.partial(unique_everseen, key=norm_name) eps = itertools.chain.from_iterable( dist.entry_points for dist in unique(distributions()) @@ -1047,17 +1047,17 @@ def packages_distributions() -> Mapping[str, List[str]]: pkg_to_dist = collections.defaultdict(list) for dist in distributions(): for pkg in _top_level_declared(dist) or _top_level_inferred(dist): - pkg_to_dist[pkg].append(dist.metadata['Name']) + pkg_to_dist[pkg].append(dist.metadata["Name"]) return dict(pkg_to_dist) def _top_level_declared(dist): - return (dist.read_text('top_level.txt') or '').split() + return (dist.read_text("top_level.txt") or "").split() def _top_level_inferred(dist): return { - f.parts[0] if len(f.parts) > 1 else f.with_suffix('').name + f.parts[0] if len(f.parts) > 1 else f.with_suffix("").name for f in always_iterable(dist.files) if f.suffix == ".py" } diff --git a/metaflow/_vendor/v3_7/importlib_metadata/_adapters.py b/metaflow/_vendor/v3_7/importlib_metadata/_adapters.py index aa460d3eda5..49cfa02e666 100644 --- a/metaflow/_vendor/v3_7/importlib_metadata/_adapters.py +++ b/metaflow/_vendor/v3_7/importlib_metadata/_adapters.py @@ -10,16 +10,16 @@ class Message(email.message.Message): map( FoldedCase, [ - 'Classifier', - 'Obsoletes-Dist', - 'Platform', - 'Project-URL', - 'Provides-Dist', - 'Provides-Extra', - 'Requires-Dist', - 'Requires-External', - 'Supported-Platform', - 'Dynamic', + "Classifier", + "Obsoletes-Dist", + "Platform", + "Project-URL", + "Provides-Dist", + "Provides-Extra", + "Requires-Dist", + "Requires-External", + "Supported-Platform", + "Dynamic", ], ) ) @@ -42,13 +42,13 @@ def __iter__(self): def _repair_headers(self): def redent(value): "Correct for RFC822 indentation" - if not value or '\n' not in value: + if not value or "\n" not in value: return value - return textwrap.dedent(' ' * 8 + value) + return textwrap.dedent(" " * 8 + value) - headers = [(key, redent(value)) for key, value in vars(self)['_headers']] + headers = [(key, redent(value)) for key, value in vars(self)["_headers"]] if self._payload: - headers.append(('Description', self.get_payload())) + headers.append(("Description", self.get_payload())) return headers @property @@ -60,9 +60,9 @@ def json(self): def transform(key): value = self.get_all(key) if key in self.multiple_use_keys else self[key] - if key == 'Keywords': - value = re.split(r'\s+', value) - tk = key.lower().replace('-', '_') + if key == "Keywords": + value = re.split(r"\s+", value) + tk = key.lower().replace("-", "_") return tk, value return dict(map(transform, map(FoldedCase, self))) diff --git a/metaflow/_vendor/v3_7/importlib_metadata/_collections.py b/metaflow/_vendor/v3_7/importlib_metadata/_collections.py index cf0954e1a30..895678a23c3 100644 --- a/metaflow/_vendor/v3_7/importlib_metadata/_collections.py +++ b/metaflow/_vendor/v3_7/importlib_metadata/_collections.py @@ -18,13 +18,13 @@ class FreezableDefaultDict(collections.defaultdict): """ def __missing__(self, key): - return getattr(self, '_frozen', super().__missing__)(key) + return getattr(self, "_frozen", super().__missing__)(key) def freeze(self): self._frozen = lambda key: self.default_factory() -class Pair(collections.namedtuple('Pair', 'name value')): +class Pair(collections.namedtuple("Pair", "name value")): @classmethod def parse(cls, text): return cls(*map(str.strip, text.split("=", 1))) diff --git a/metaflow/_vendor/v3_7/importlib_metadata/_compat.py 
b/metaflow/_vendor/v3_7/importlib_metadata/_compat.py index 173eebe017c..b1886410e88 100644 --- a/metaflow/_vendor/v3_7/importlib_metadata/_compat.py +++ b/metaflow/_vendor/v3_7/importlib_metadata/_compat.py @@ -2,7 +2,7 @@ import platform -__all__ = ['install', 'NullFinder', 'Protocol'] +__all__ = ["install", "NullFinder", "Protocol"] try: @@ -35,8 +35,8 @@ def disable_stdlib_finder(): def matches(finder): return getattr( - finder, '__module__', None - ) == '_frozen_importlib_external' and hasattr(finder, 'find_distributions') + finder, "__module__", None + ) == "_frozen_importlib_external" and hasattr(finder, "find_distributions") for finder in filter(matches, sys.meta_path): # pragma: nocover del finder.find_distributions @@ -67,5 +67,5 @@ def pypy_partial(val): Workaround for #327. """ - is_pypy = platform.python_implementation() == 'PyPy' + is_pypy = platform.python_implementation() == "PyPy" return val + is_pypy diff --git a/metaflow/_vendor/v3_7/importlib_metadata/_meta.py b/metaflow/_vendor/v3_7/importlib_metadata/_meta.py index 37ee43e6ef4..bc502995bd7 100644 --- a/metaflow/_vendor/v3_7/importlib_metadata/_meta.py +++ b/metaflow/_vendor/v3_7/importlib_metadata/_meta.py @@ -6,17 +6,13 @@ class PackageMetadata(Protocol): - def __len__(self) -> int: - ... # pragma: no cover + def __len__(self) -> int: ... # pragma: no cover - def __contains__(self, item: str) -> bool: - ... # pragma: no cover + def __contains__(self, item: str) -> bool: ... # pragma: no cover - def __getitem__(self, key: str) -> str: - ... # pragma: no cover + def __getitem__(self, key: str) -> str: ... # pragma: no cover - def __iter__(self) -> Iterator[str]: - ... # pragma: no cover + def __iter__(self) -> Iterator[str]: ... # pragma: no cover def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]: """ @@ -35,14 +31,10 @@ class SimplePath(Protocol): A minimal subset of pathlib.Path required by PathDistribution. """ - def joinpath(self) -> 'SimplePath': - ... # pragma: no cover + def joinpath(self) -> "SimplePath": ... # pragma: no cover - def __truediv__(self) -> 'SimplePath': - ... # pragma: no cover + def __truediv__(self) -> "SimplePath": ... # pragma: no cover - def parent(self) -> 'SimplePath': - ... # pragma: no cover + def parent(self) -> "SimplePath": ... # pragma: no cover - def read_text(self) -> str: - ... # pragma: no cover + def read_text(self) -> str: ... # pragma: no cover diff --git a/metaflow/_vendor/v3_7/importlib_metadata/_text.py b/metaflow/_vendor/v3_7/importlib_metadata/_text.py index c88cfbb2349..376210d7096 100644 --- a/metaflow/_vendor/v3_7/importlib_metadata/_text.py +++ b/metaflow/_vendor/v3_7/importlib_metadata/_text.py @@ -94,6 +94,6 @@ def lower(self): def index(self, sub): return self.lower().index(sub.lower()) - def split(self, splitter=' ', maxsplit=0): + def split(self, splitter=" ", maxsplit=0): pattern = re.compile(re.escape(splitter), re.I) return pattern.split(self, maxsplit) diff --git a/metaflow/_vendor/v3_7/typeguard/_decorators.py b/metaflow/_vendor/v3_7/typeguard/_decorators.py index 53f254f7080..cf3253351fe 100644 --- a/metaflow/_vendor/v3_7/typeguard/_decorators.py +++ b/metaflow/_vendor/v3_7/typeguard/_decorators.py @@ -133,13 +133,11 @@ def typechecked( typecheck_fail_callback: TypeCheckFailCallback | Unset = unset, collection_check_strategy: CollectionCheckStrategy | Unset = unset, debug_instrumentation: bool | Unset = unset, -) -> Callable[[T_CallableOrType], T_CallableOrType]: - ... +) -> Callable[[T_CallableOrType], T_CallableOrType]: ... 
@overload -def typechecked(target: T_CallableOrType) -> T_CallableOrType: - ... +def typechecked(target: T_CallableOrType) -> T_CallableOrType: ... def typechecked( @@ -215,9 +213,9 @@ def typechecked( return target # Find either the first Python wrapper or the actual function - wrapper_class: type[classmethod[Any, Any, Any]] | type[ - staticmethod[Any, Any] - ] | None = None + wrapper_class: ( + type[classmethod[Any, Any, Any]] | type[staticmethod[Any, Any]] | None + ) = None if isinstance(target, (classmethod, staticmethod)): wrapper_class = target.__class__ target = target.__func__ diff --git a/metaflow/_vendor/v3_7/typeguard/_functions.py b/metaflow/_vendor/v3_7/typeguard/_functions.py index 6c64bd19c42..ad0130e5ca8 100644 --- a/metaflow/_vendor/v3_7/typeguard/_functions.py +++ b/metaflow/_vendor/v3_7/typeguard/_functions.py @@ -32,8 +32,7 @@ def check_type( forward_ref_policy: ForwardRefPolicy = ..., typecheck_fail_callback: TypeCheckFailCallback | None = ..., collection_check_strategy: CollectionCheckStrategy = ..., -) -> T: - ... +) -> T: ... @overload @@ -44,8 +43,7 @@ def check_type( forward_ref_policy: ForwardRefPolicy = ..., typecheck_fail_callback: TypeCheckFailCallback | None = ..., collection_check_strategy: CollectionCheckStrategy = ..., -) -> Any: - ... +) -> Any: ... def check_type( @@ -53,7 +51,7 @@ def check_type( expected_type: Any, *, forward_ref_policy: ForwardRefPolicy = TypeCheckConfiguration().forward_ref_policy, - typecheck_fail_callback: (TypeCheckFailCallback | None) = ( + typecheck_fail_callback: TypeCheckFailCallback | None = ( TypeCheckConfiguration().typecheck_fail_callback ), collection_check_strategy: CollectionCheckStrategy = ( diff --git a/metaflow/_vendor/v3_7/typeguard/_memo.py b/metaflow/_vendor/v3_7/typeguard/_memo.py index 2eb8e62efae..b20291b3f4f 100644 --- a/metaflow/_vendor/v3_7/typeguard/_memo.py +++ b/metaflow/_vendor/v3_7/typeguard/_memo.py @@ -2,7 +2,10 @@ from typing import Any -from metaflow._vendor.v3_7.typeguard._config import TypeCheckConfiguration, global_config +from metaflow._vendor.v3_7.typeguard._config import ( + TypeCheckConfiguration, + global_config, +) class TypeCheckMemo: diff --git a/metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py b/metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py index fc7650bc9a9..6d6000c20d0 100644 --- a/metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py +++ b/metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py @@ -5,7 +5,11 @@ from pytest import Config, Parser -from metaflow._vendor.v3_7.typeguard._config import CollectionCheckStrategy, ForwardRefPolicy, global_config +from metaflow._vendor.v3_7.typeguard._config import ( + CollectionCheckStrategy, + ForwardRefPolicy, + global_config, +) from metaflow._vendor.v3_7.typeguard._exceptions import InstrumentationWarning from metaflow._vendor.v3_7.typeguard._importhook import install_import_hook from metaflow._vendor.v3_7.typeguard._utils import qualified_name, resolve_reference diff --git a/metaflow/_vendor/v3_7/typeguard/_suppression.py b/metaflow/_vendor/v3_7/typeguard/_suppression.py index 44f5c4088c8..23876ea6770 100644 --- a/metaflow/_vendor/v3_7/typeguard/_suppression.py +++ b/metaflow/_vendor/v3_7/typeguard/_suppression.py @@ -20,17 +20,15 @@ @overload -def suppress_type_checks(func: Callable[P, T]) -> Callable[P, T]: - ... +def suppress_type_checks(func: Callable[P, T]) -> Callable[P, T]: ... @overload -def suppress_type_checks() -> ContextManager[None]: - ... +def suppress_type_checks() -> ContextManager[None]: ... 
def suppress_type_checks( - func: Callable[P, T] | None = None + func: Callable[P, T] | None = None, ) -> Callable[P, T] | ContextManager[None]: """ Temporarily suppress all type checking. diff --git a/metaflow/_vendor/v3_7/typeguard/_transformer.py b/metaflow/_vendor/v3_7/typeguard/_transformer.py index 24090b19b00..0df7dab8354 100644 --- a/metaflow/_vendor/v3_7/typeguard/_transformer.py +++ b/metaflow/_vendor/v3_7/typeguard/_transformer.py @@ -577,12 +577,10 @@ def _get_import(self, module: str, name: str) -> Name: return memo.get_import(module, name) @overload - def _convert_annotation(self, annotation: None) -> None: - ... + def _convert_annotation(self, annotation: None) -> None: ... @overload - def _convert_annotation(self, annotation: expr) -> expr: - ... + def _convert_annotation(self, annotation: expr) -> expr: ... def _convert_annotation(self, annotation: expr | None) -> expr | None: if annotation is None: diff --git a/metaflow/_vendor/v3_7/typeguard/_union_transformer.py b/metaflow/_vendor/v3_7/typeguard/_union_transformer.py index fcd6349d35a..19617e6af5a 100644 --- a/metaflow/_vendor/v3_7/typeguard/_union_transformer.py +++ b/metaflow/_vendor/v3_7/typeguard/_union_transformer.py @@ -2,6 +2,7 @@ Transforms lazily evaluated PEP 604 unions into typing.Unions, for compatibility with Python versions older than 3.10. """ + from __future__ import annotations from ast import ( diff --git a/metaflow/_vendor/v3_7/typing_extensions.py b/metaflow/_vendor/v3_7/typing_extensions.py index 6b7dc6cc103..a1ca7df5acb 100644 --- a/metaflow/_vendor/v3_7/typing_extensions.py +++ b/metaflow/_vendor/v3_7/typing_extensions.py @@ -11,121 +11,116 @@ __all__ = [ # Super-special typing primitives. - 'Any', - 'ClassVar', - 'Concatenate', - 'Final', - 'LiteralString', - 'ParamSpec', - 'ParamSpecArgs', - 'ParamSpecKwargs', - 'Self', - 'Type', - 'TypeVar', - 'TypeVarTuple', - 'Unpack', - + "Any", + "ClassVar", + "Concatenate", + "Final", + "LiteralString", + "ParamSpec", + "ParamSpecArgs", + "ParamSpecKwargs", + "Self", + "Type", + "TypeVar", + "TypeVarTuple", + "Unpack", # ABCs (from collections.abc). - 'Awaitable', - 'AsyncIterator', - 'AsyncIterable', - 'Coroutine', - 'AsyncGenerator', - 'AsyncContextManager', - 'Buffer', - 'ChainMap', - + "Awaitable", + "AsyncIterator", + "AsyncIterable", + "Coroutine", + "AsyncGenerator", + "AsyncContextManager", + "Buffer", + "ChainMap", # Concrete collection types. - 'ContextManager', - 'Counter', - 'Deque', - 'DefaultDict', - 'NamedTuple', - 'OrderedDict', - 'TypedDict', - + "ContextManager", + "Counter", + "Deque", + "DefaultDict", + "NamedTuple", + "OrderedDict", + "TypedDict", # Structural checks, a.k.a. protocols. - 'SupportsAbs', - 'SupportsBytes', - 'SupportsComplex', - 'SupportsFloat', - 'SupportsIndex', - 'SupportsInt', - 'SupportsRound', - + "SupportsAbs", + "SupportsBytes", + "SupportsComplex", + "SupportsFloat", + "SupportsIndex", + "SupportsInt", + "SupportsRound", # One-off things. 
- 'Annotated', - 'assert_never', - 'assert_type', - 'clear_overloads', - 'dataclass_transform', - 'deprecated', - 'get_overloads', - 'final', - 'get_args', - 'get_origin', - 'get_original_bases', - 'get_protocol_members', - 'get_type_hints', - 'IntVar', - 'is_protocol', - 'is_typeddict', - 'Literal', - 'NewType', - 'overload', - 'override', - 'Protocol', - 'reveal_type', - 'runtime', - 'runtime_checkable', - 'Text', - 'TypeAlias', - 'TypeAliasType', - 'TypeGuard', - 'TYPE_CHECKING', - 'Never', - 'NoReturn', - 'Required', - 'NotRequired', - + "Annotated", + "assert_never", + "assert_type", + "clear_overloads", + "dataclass_transform", + "deprecated", + "get_overloads", + "final", + "get_args", + "get_origin", + "get_original_bases", + "get_protocol_members", + "get_type_hints", + "IntVar", + "is_protocol", + "is_typeddict", + "Literal", + "NewType", + "overload", + "override", + "Protocol", + "reveal_type", + "runtime", + "runtime_checkable", + "Text", + "TypeAlias", + "TypeAliasType", + "TypeGuard", + "TYPE_CHECKING", + "Never", + "NoReturn", + "Required", + "NotRequired", # Pure aliases, have always been in typing - 'AbstractSet', - 'AnyStr', - 'BinaryIO', - 'Callable', - 'Collection', - 'Container', - 'Dict', - 'ForwardRef', - 'FrozenSet', - 'Generator', - 'Generic', - 'Hashable', - 'IO', - 'ItemsView', - 'Iterable', - 'Iterator', - 'KeysView', - 'List', - 'Mapping', - 'MappingView', - 'Match', - 'MutableMapping', - 'MutableSequence', - 'MutableSet', - 'Optional', - 'Pattern', - 'Reversible', - 'Sequence', - 'Set', - 'Sized', - 'TextIO', - 'Tuple', - 'Union', - 'ValuesView', - 'cast', - 'no_type_check', - 'no_type_check_decorator', + "AbstractSet", + "AnyStr", + "BinaryIO", + "Callable", + "Collection", + "Container", + "Dict", + "ForwardRef", + "FrozenSet", + "Generator", + "Generic", + "Hashable", + "IO", + "ItemsView", + "Iterable", + "Iterator", + "KeysView", + "List", + "Mapping", + "MappingView", + "Match", + "MutableMapping", + "MutableSequence", + "MutableSet", + "Optional", + "Pattern", + "Reversible", + "Sequence", + "Set", + "Sized", + "TextIO", + "Tuple", + "Union", + "ValuesView", + "cast", + "no_type_check", + "no_type_check_decorator", ] # for backward compatibility @@ -161,19 +156,26 @@ def _check_generic(cls, parameters, elen=_marker): num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): return - raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};" - f" actual {alen}, expected {elen}") + raise TypeError( + f"Too {'many' if alen > elen else 'few'} parameters for {cls};" + f" actual {alen}, expected {elen}" + ) if sys.version_info >= (3, 10): + def _should_collect_from_parameters(t): return isinstance( t, (typing._GenericAlias, _types.GenericAlias, _types.UnionType) ) + elif sys.version_info >= (3, 9): + def _should_collect_from_parameters(t): return isinstance(t, (typing._GenericAlias, _types.GenericAlias)) + else: + def _should_collect_from_parameters(t): return isinstance(t, typing._GenericAlias) and not t._special @@ -188,11 +190,7 @@ def _collect_type_vars(types, typevar_types=None): typevar_types = typing.TypeVar tvars = [] for t in types: - if ( - isinstance(t, typevar_types) and - t not in tvars and - not _is_unpack(t) - ): + if isinstance(t, typevar_types) and t not in tvars and not _is_unpack(t): tvars.append(t) if _should_collect_from_parameters(t): tvars.extend([t for t in t.__parameters__ if t not in tvars]) @@ -203,11 +201,11 @@ def _collect_type_vars(types, 
typevar_types=None): # Some unconstrained type variables. These are used by the container types. # (These are not for export.) -T = typing.TypeVar('T') # Any type. -KT = typing.TypeVar('KT') # Key type. -VT = typing.TypeVar('VT') # Value type. -T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers. -T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant. +T = typing.TypeVar("T") # Any type. +KT = typing.TypeVar("KT") # Key type. +VT = typing.TypeVar("VT") # Value type. +T_co = typing.TypeVar("T_co", covariant=True) # Any type covariant containers. +T_contra = typing.TypeVar("T_contra", contravariant=True) # Ditto contravariant. if sys.version_info >= (3, 11): @@ -217,7 +215,9 @@ def _collect_type_vars(types, typevar_types=None): class _AnyMeta(type): def __instancecheck__(self, obj): if self is Any: - raise TypeError("typing_extensions.Any cannot be used with isinstance()") + raise TypeError( + "typing_extensions.Any cannot be used with isinstance()" + ) return super().__instancecheck__(obj) def __repr__(self): @@ -234,6 +234,7 @@ class Any(metaclass=_AnyMeta): static type checkers. At runtime, Any should not be used with instance checks. """ + def __new__(cls, *args, **kwargs): if cls is Any: raise TypeError("Any cannot be instantiated") @@ -245,23 +246,26 @@ def __new__(cls, *args, **kwargs): class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): def __repr__(self): - return 'typing_extensions.' + self._name + return "typing_extensions." + self._name # On older versions of typing there is an internal class named "Final". # 3.8+ -if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7): +if hasattr(typing, "Final") and sys.version_info[:2] >= (3, 7): Final = typing.Final # 3.7 else: + class _FinalForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return typing._GenericAlias(self, (item,)) - Final = _FinalForm('Final', - doc="""A special typing construct to indicate that a name + Final = _FinalForm( + "Final", + doc="""A special typing construct to indicate that a name cannot be re-assigned or overridden in a subclass. For example: @@ -273,7 +277,8 @@ class Connection: class FastConnector(Connection): TIMEOUT = 1 # Error reported by type checker - There is no runtime checking of these properties.""") + There is no runtime checking of these properties.""", + ) if sys.version_info >= (3, 11): final = typing.final @@ -321,6 +326,7 @@ def IntVar(name): if sys.version_info >= (3, 10, 1): Literal = typing.Literal else: + def _flatten_literal_params(parameters): """An internal helper for Literal creation: flatten Literals among parameters""" params = [] @@ -348,7 +354,7 @@ def __hash__(self): class _LiteralForm(_ExtensionsSpecialForm, _root=True): def __init__(self, doc: str): - self._name = 'Literal' + self._name = "Literal" self._doc = self.__doc__ = doc def __getitem__(self, parameters): @@ -376,7 +382,8 @@ def __getitem__(self, parameters): return _LiteralGenericAlias(self, parameters) - Literal = _LiteralForm(doc="""\ + Literal = _LiteralForm( + doc="""\ A type that can be used to indicate to type checkers that the corresponding value has a value literally equivalent to the provided parameter. For example: @@ -388,7 +395,8 @@ def __getitem__(self, parameters): Literal[...] cannot be subclassed. 
There is no runtime checking verifying that the parameter is actually a value - instead of a type.""") + instead of a type.""" + ) _overload_dummy = typing._overload_dummy @@ -477,7 +485,7 @@ def clear_overloads(): DefaultDict = typing.DefaultDict # 3.7.2+ -if hasattr(typing, 'OrderedDict'): +if hasattr(typing, "OrderedDict"): OrderedDict = typing.OrderedDict # 3.7.0-3.7.2 else: @@ -491,27 +499,53 @@ def clear_overloads(): _PROTO_ALLOWLIST = { - 'collections.abc': [ - 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', - 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', 'Buffer', + "collections.abc": [ + "Callable", + "Awaitable", + "Iterable", + "Iterator", + "AsyncIterable", + "Hashable", + "Sized", + "Container", + "Collection", + "Reversible", + "Buffer", ], - 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'], - 'typing_extensions': ['Buffer'], + "contextlib": ["AbstractContextManager", "AbstractAsyncContextManager"], + "typing_extensions": ["Buffer"], } _EXCLUDED_ATTRS = { - "__abstractmethods__", "__annotations__", "__weakref__", "_is_protocol", - "_is_runtime_protocol", "__dict__", "__slots__", "__parameters__", - "__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__", - "__subclasshook__", "__orig_class__", "__init__", "__new__", - "__protocol_attrs__", "__callable_proto_members_only__", + "__abstractmethods__", + "__annotations__", + "__weakref__", + "_is_protocol", + "_is_runtime_protocol", + "__dict__", + "__slots__", + "__parameters__", + "__orig_bases__", + "__module__", + "_MutableMapping__marker", + "__doc__", + "__subclasshook__", + "__orig_class__", + "__init__", + "__new__", + "__protocol_attrs__", + "__callable_proto_members_only__", } if sys.version_info < (3, 8): _EXCLUDED_ATTRS |= { - "_gorg", "__next_in_mro__", "__extra__", "__tree_hash__", "__args__", - "__origin__" + "_gorg", + "__next_in_mro__", + "__extra__", + "__tree_hash__", + "__args__", + "__origin__", } if sys.version_info >= (3, 9): @@ -526,11 +560,11 @@ def clear_overloads(): def _get_protocol_attrs(cls): attrs = set() for base in cls.__mro__[:-1]: # without object - if base.__name__ in {'Protocol', 'Generic'}: + if base.__name__ in {"Protocol", "Generic"}: continue - annotations = getattr(base, '__annotations__', {}) + annotations = getattr(base, "__annotations__", {}) for attr in (*base.__dict__, *annotations): - if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS): + if not attr.startswith("_abc_") and attr not in _EXCLUDED_ATTRS: attrs.add(attr) return attrs @@ -543,7 +577,7 @@ def _maybe_adjust_parameters(cls): on the CPython main branch. """ tvars = [] - if '__orig_bases__' in cls.__dict__: + if "__orig_bases__" in cls.__dict__: tvars = _collect_type_vars(cls.__orig_bases__) # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. # If found, tvars must be a subset of it. @@ -552,14 +586,17 @@ def _maybe_adjust_parameters(cls): # and reject multiple Generic[...] and/or Protocol[...]. gvars = None for base in cls.__orig_bases__: - if (isinstance(base, typing._GenericAlias) and - base.__origin__ in (typing.Generic, Protocol)): + if isinstance(base, typing._GenericAlias) and base.__origin__ in ( + typing.Generic, + Protocol, + ): # for error messages the_base = base.__origin__.__name__ if gvars is not None: raise TypeError( "Cannot inherit from Generic[...]" - " and/or Protocol[...] multiple types.") + " and/or Protocol[...] multiple types." 
+ ) gvars = base.__parameters__ if gvars is None: gvars = tvars @@ -567,17 +604,19 @@ def _maybe_adjust_parameters(cls): tvarset = set(tvars) gvarset = set(gvars) if not tvarset <= gvarset: - s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) - s_args = ', '.join(str(g) for g in gvars) - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in {the_base}[{s_args}]") + s_vars = ", ".join(str(t) for t in tvars if t not in gvarset) + s_args = ", ".join(str(g) for g in gvars) + raise TypeError( + f"Some type variables ({s_vars}) are" + f" not listed in {the_base}[{s_args}]" + ) tvars = gvars cls.__parameters__ = tuple(tvars) def _caller(depth=2): try: - return sys._getframe(depth).f_globals.get('__name__', '__main__') + return sys._getframe(depth).f_globals.get("__name__", "__main__") except (AttributeError, ValueError): # For platforms without _getframe() return None @@ -587,16 +626,17 @@ def _caller(depth=2): if sys.version_info >= (3, 12): Protocol = typing.Protocol else: + def _allow_reckless_class_checks(depth=3): """Allow instance and class checks for special stdlib modules. The abc and functools modules indiscriminately call isinstance() and issubclass() on the whole MRO of a user class, which may contain protocols. """ - return _caller(depth) in {'abc', 'functools', None} + return _caller(depth) in {"abc", "functools", None} def _no_init(self, *args, **kwargs): if type(self)._is_protocol: - raise TypeError('Protocols cannot be instantiated') + raise TypeError("Protocols cannot be instantiated") if sys.version_info >= (3, 8): # Inheriting from typing._ProtocolMeta isn't actually desirable, @@ -638,19 +678,20 @@ def __init__(cls, *args, **kwargs): # PEP 544 prohibits using issubclass() # with protocols that have non-method members. cls.__callable_proto_members_only__ = all( - callable(getattr(cls, attr, None)) for attr in cls.__protocol_attrs__ + callable(getattr(cls, attr, None)) + for attr in cls.__protocol_attrs__ ) def __subclasscheck__(cls, other): if cls is Protocol: return type.__subclasscheck__(cls, other) if ( - getattr(cls, '_is_protocol', False) + getattr(cls, "_is_protocol", False) and not _allow_reckless_class_checks() ): if not isinstance(other, type): # Same error message as for issubclass(1, int). 
- raise TypeError('issubclass() arg 1 must be a class') + raise TypeError("issubclass() arg 1 must be a class") if ( not cls.__callable_proto_members_only__ and cls.__dict__.get("__subclasshook__") is _proto_hook @@ -658,7 +699,7 @@ def __subclasscheck__(cls, other): raise TypeError( "Protocols with non-method members don't support issubclass()" ) - if not getattr(cls, '_is_runtime_protocol', False): + if not getattr(cls, "_is_runtime_protocol", False): raise TypeError( "Instance and class checks can only be used with " "@runtime_checkable protocols" @@ -675,11 +716,13 @@ def __instancecheck__(cls, instance): return abc.ABCMeta.__instancecheck__(cls, instance) if ( - not getattr(cls, '_is_runtime_protocol', False) and - not _allow_reckless_class_checks() + not getattr(cls, "_is_runtime_protocol", False) + and not _allow_reckless_class_checks() ): - raise TypeError("Instance and class checks can only be used with" - " @runtime_checkable protocols") + raise TypeError( + "Instance and class checks can only be used with" + " @runtime_checkable protocols" + ) if abc.ABCMeta.__instancecheck__(cls, instance): return True @@ -702,9 +745,7 @@ def __eq__(cls, other): # as equivalent to typing.Protocol on Python 3.8+ if abc.ABCMeta.__eq__(cls, other) is True: return True - return ( - cls is Protocol and other is getattr(typing, "Protocol", object()) - ) + return cls is Protocol and other is getattr(typing, "Protocol", object()) # This has to be defined, or the abc-module cache # complains about classes with this metaclass being unhashable, @@ -714,7 +755,7 @@ def __hash__(cls) -> int: @classmethod def _proto_hook(cls, other): - if not cls.__dict__.get('_is_protocol', False): + if not cls.__dict__.get("_is_protocol", False): return NotImplemented for attr in cls.__protocol_attrs__: @@ -726,7 +767,7 @@ def _proto_hook(cls, other): break # ...or in annotations, if it is a sub-protocol. - annotations = getattr(base, '__annotations__', {}) + annotations = getattr(base, "__annotations__", {}) if ( isinstance(annotations, collections.abc.Mapping) and attr in annotations @@ -738,6 +779,7 @@ def _proto_hook(cls, other): return True if sys.version_info >= (3, 8): + class Protocol(typing.Generic, metaclass=_ProtocolMeta): __doc__ = typing.Protocol.__doc__ __slots__ = () @@ -748,11 +790,11 @@ def __init_subclass__(cls, *args, **kwargs): super().__init_subclass__(*args, **kwargs) # Determine if this is a protocol or a concrete subclass. - if not cls.__dict__.get('_is_protocol', False): + if not cls.__dict__.get("_is_protocol", False): cls._is_protocol = any(b is Protocol for b in cls.__bases__) # Set (or override) the protocol subclass hook. - if '__subclasshook__' not in cls.__dict__: + if "__subclasshook__" not in cls.__dict__: cls.__subclasshook__ = _proto_hook # Prohibit instantiation for protocol classes @@ -760,6 +802,7 @@ def __init_subclass__(cls, *args, **kwargs): cls.__init__ = _no_init else: + class Protocol(metaclass=_ProtocolMeta): # There is quite a lot of overlapping code with typing.Generic. 
# Unfortunately it is hard to avoid this on Python <3.8, @@ -799,8 +842,10 @@ def meth(self) -> T: def __new__(cls, *args, **kwds): if cls is Protocol: - raise TypeError("Type Protocol cannot be instantiated; " - "it can only be used as a base class") + raise TypeError( + "Type Protocol cannot be instantiated; " + "it can only be used as a base class" + ) return super().__new__(cls) @typing._tp_cache @@ -809,7 +854,8 @@ def __class_getitem__(cls, params): params = (params,) if not params and cls is not typing.Tuple: raise TypeError( - f"Parameter list to {cls.__qualname__}[...] cannot be empty") + f"Parameter list to {cls.__qualname__}[...] cannot be empty" + ) msg = "Parameters to generic types must be types." params = tuple(typing._type_check(p, msg) for p in params) if cls is Protocol: @@ -820,17 +866,19 @@ def __class_getitem__(cls, params): i += 1 raise TypeError( "Parameters to Protocol[...] must all be type variables." - f" Parameter {i + 1} is {params[i]}") + f" Parameter {i + 1} is {params[i]}" + ) if len(set(params)) != len(params): raise TypeError( - "Parameters to Protocol[...] must all be unique") + "Parameters to Protocol[...] must all be unique" + ) else: # Subscripting a regular Generic subclass. _check_generic(cls, params, len(cls.__parameters__)) return typing._GenericAlias(cls, params) def __init_subclass__(cls, *args, **kwargs): - if '__orig_bases__' in cls.__dict__: + if "__orig_bases__" in cls.__dict__: error = typing.Generic in cls.__orig_bases__ else: error = typing.Generic in cls.__bases__ @@ -839,11 +887,11 @@ def __init_subclass__(cls, *args, **kwargs): _maybe_adjust_parameters(cls) # Determine if this is a protocol or a concrete subclass. - if not cls.__dict__.get('_is_protocol', None): + if not cls.__dict__.get("_is_protocol", None): cls._is_protocol = any(b is Protocol for b in cls.__bases__) # Set (or override) the protocol subclass hook. - if '__subclasshook__' not in cls.__dict__: + if "__subclasshook__" not in cls.__dict__: cls.__subclasshook__ = _proto_hook # Prohibit instantiation for protocol classes @@ -854,6 +902,7 @@ def __init_subclass__(cls, *args, **kwargs): if sys.version_info >= (3, 8): runtime_checkable = typing.runtime_checkable else: + def runtime_checkable(cls): """Mark a protocol class as a runtime protocol, so that it can be used with isinstance() and issubclass(). 
Raise TypeError @@ -866,8 +915,10 @@ def runtime_checkable(cls): (isinstance(cls, _ProtocolMeta) or issubclass(cls, typing.Generic)) and getattr(cls, "_is_protocol", False) ): - raise TypeError('@runtime_checkable can be only applied to protocol classes,' - f' got {cls!r}') + raise TypeError( + "@runtime_checkable can be only applied to protocol classes," + f" got {cls!r}" + ) cls._is_runtime_protocol = True return cls @@ -886,9 +937,11 @@ def runtime_checkable(cls): SupportsAbs = typing.SupportsAbs SupportsRound = typing.SupportsRound else: + @runtime_checkable class SupportsInt(Protocol): """An ABC with one abstract method __int__.""" + __slots__ = () @abc.abstractmethod @@ -898,6 +951,7 @@ def __int__(self) -> int: @runtime_checkable class SupportsFloat(Protocol): """An ABC with one abstract method __float__.""" + __slots__ = () @abc.abstractmethod @@ -907,6 +961,7 @@ def __float__(self) -> float: @runtime_checkable class SupportsComplex(Protocol): """An ABC with one abstract method __complex__.""" + __slots__ = () @abc.abstractmethod @@ -916,6 +971,7 @@ def __complex__(self) -> complex: @runtime_checkable class SupportsBytes(Protocol): """An ABC with one abstract method __bytes__.""" + __slots__ = () @abc.abstractmethod @@ -935,6 +991,7 @@ class SupportsAbs(Protocol[T_co]): """ An ABC with one abstract method __abs__ that is covariant in its return type. """ + __slots__ = () @abc.abstractmethod @@ -946,6 +1003,7 @@ class SupportsRound(Protocol[T_co]): """ An ABC with one abstract method __round__ that is covariant in its return type. """ + __slots__ = () @abc.abstractmethod @@ -958,13 +1016,14 @@ def inner(func): if sys.implementation.name == "pypy" and sys.version_info < (3, 9): cls_dict = { "__call__": staticmethod(func), - "__mro_entries__": staticmethod(mro_entries) + "__mro_entries__": staticmethod(mro_entries), } t = type(func.__name__, (), cls_dict) return functools.update_wrapper(t(), func) else: func.__mro_entries__ = mro_entries return func + return inner @@ -1002,8 +1061,10 @@ def __new__(cls, name, bases, ns, total=True): """ for base in bases: if type(base) is not _TypedDictMeta and base is not typing.Generic: - raise TypeError('cannot inherit from both a TypedDict type ' - 'and a non-TypedDict base class') + raise TypeError( + "cannot inherit from both a TypedDict type " + "and a non-TypedDict base class" + ) if any(issubclass(b, typing.Generic) for b in bases): generic_base = (typing.Generic,) @@ -1012,16 +1073,18 @@ def __new__(cls, name, bases, ns, total=True): # typing.py generally doesn't let you inherit from plain Generic, unless # the name of the class happens to be "Protocol" (or "_Protocol" on 3.7). 
- tp_dict = type.__new__(_TypedDictMeta, _fake_name, (*generic_base, dict), ns) + tp_dict = type.__new__( + _TypedDictMeta, _fake_name, (*generic_base, dict), ns + ) tp_dict.__name__ = name if tp_dict.__qualname__ == _fake_name: tp_dict.__qualname__ = name - if not hasattr(tp_dict, '__orig_bases__'): + if not hasattr(tp_dict, "__orig_bases__"): tp_dict.__orig_bases__ = bases annotations = {} - own_annotations = ns.get('__annotations__', {}) + own_annotations = ns.get("__annotations__", {}) msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" if _TAKES_MODULE: own_annotations = { @@ -1030,16 +1093,15 @@ def __new__(cls, name, bases, ns, total=True): } else: own_annotations = { - n: typing._type_check(tp, msg) - for n, tp in own_annotations.items() + n: typing._type_check(tp, msg) for n, tp in own_annotations.items() } required_keys = set() optional_keys = set() for base in bases: - annotations.update(base.__dict__.get('__annotations__', {})) - required_keys.update(base.__dict__.get('__required_keys__', ())) - optional_keys.update(base.__dict__.get('__optional_keys__', ())) + annotations.update(base.__dict__.get("__annotations__", {})) + required_keys.update(base.__dict__.get("__required_keys__", ())) + optional_keys.update(base.__dict__.get("__optional_keys__", ())) annotations.update(own_annotations) for annotation_key, annotation_type in own_annotations.items(): @@ -1062,7 +1124,7 @@ def __new__(cls, name, bases, ns, total=True): tp_dict.__annotations__ = annotations tp_dict.__required_keys__ = frozenset(required_keys) tp_dict.__optional_keys__ = frozenset(optional_keys) - if not hasattr(tp_dict, '__total__'): + if not hasattr(tp_dict, "__total__"): tp_dict.__total__ = total return tp_dict @@ -1070,11 +1132,11 @@ def __new__(cls, name, bases, ns, total=True): def __subclasscheck__(cls, other): # Typed dicts are only for static structural subtyping. - raise TypeError('TypedDict does not support instance and class checks') + raise TypeError("TypedDict does not support instance and class checks") __instancecheck__ = __subclasscheck__ - _TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {}) + _TypedDict = type.__new__(_TypedDictMeta, "TypedDict", (), {}) @_ensure_subclassable(lambda bases: (_TypedDict,)) def TypedDict(__typename, __fields=_marker, *, total=True, **kwargs): @@ -1132,15 +1194,20 @@ class Point2D(TypedDict): example = f"`{__typename} = TypedDict({__typename!r}, {{}})`" deprecation_msg = ( - f"{deprecated_thing} is deprecated and will be disallowed in " - "Python 3.15. To create a TypedDict class with 0 fields " - "using the functional syntax, pass an empty dictionary, e.g. " - ) + example + "." + ( + f"{deprecated_thing} is deprecated and will be disallowed in " + "Python 3.15. To create a TypedDict class with 0 fields " + "using the functional syntax, pass an empty dictionary, e.g. " + ) + + example + + "." + ) warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2) __fields = kwargs elif kwargs: - raise TypeError("TypedDict takes either a dict or keyword arguments," - " but not both") + raise TypeError( + "TypedDict takes either a dict or keyword arguments," " but not both" + ) if kwargs: warnings.warn( "The kwargs-based syntax for TypedDict definitions is deprecated " @@ -1150,11 +1217,11 @@ class Point2D(TypedDict): stacklevel=2, ) - ns = {'__annotations__': dict(__fields)} + ns = {"__annotations__": dict(__fields)} module = _caller() if module is not None: # Setting correct module is necessary to make typed dict classes pickleable. 
- ns['__module__'] = module + ns["__module__"] = module td = _TypedDictMeta(__typename, (), ns, total=total) td.__orig_bases__ = (TypedDict,) @@ -1186,6 +1253,7 @@ class Film(TypedDict): assert_type = typing.assert_type else: + def assert_type(__val, __typ): """Assert (to the type checker) that the value is of the given type. @@ -1274,13 +1342,14 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): # Python 3.9+ has PEP 593 (Annotated) -if hasattr(typing, 'Annotated'): +if hasattr(typing, "Annotated"): Annotated = typing.Annotated # Not exported and not a public API, but needed for get_origin() and get_args() # to work. _AnnotatedAlias = typing._AnnotatedAlias # 3.7-3.8 else: + class _AnnotatedAlias(typing._GenericAlias, _root=True): """Runtime representation of an annotated type. @@ -1289,6 +1358,7 @@ class _AnnotatedAlias(typing._GenericAlias, _root=True): instantiating is the same as instantiating the underlying type, binding it to types is also the same. """ + def __init__(self, origin, metadata): if isinstance(origin, _AnnotatedAlias): metadata = origin.__metadata__ + metadata @@ -1302,13 +1372,13 @@ def copy_with(self, params): return _AnnotatedAlias(new_type, self.__metadata__) def __repr__(self): - return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " - f"{', '.join(repr(a) for a in self.__metadata__)}]") + return ( + f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " + f"{', '.join(repr(a) for a in self.__metadata__)}]" + ) def __reduce__(self): - return operator.getitem, ( - Annotated, (self.__origin__,) + self.__metadata__ - ) + return operator.getitem, (Annotated, (self.__origin__,) + self.__metadata__) def __eq__(self, other): if not isinstance(other, _AnnotatedAlias): @@ -1361,9 +1431,11 @@ def __new__(cls, *args, **kwargs): @typing._tp_cache def __class_getitem__(cls, params): if not isinstance(params, tuple) or len(params) < 2: - raise TypeError("Annotated[...] should be used " - "with at least two arguments (a type and an " - "annotation).") + raise TypeError( + "Annotated[...] should be used " + "with at least two arguments (a type and an " + "annotation)." + ) allowed_special_forms = (ClassVar, Final) if get_origin(params[0]) in allowed_special_forms: origin = params[0] @@ -1374,9 +1446,8 @@ def __class_getitem__(cls, params): return _AnnotatedAlias(origin, metadata) def __init_subclass__(cls, *args, **kwargs): - raise TypeError( - f"Cannot subclass {cls.__module__}.Annotated" - ) + raise TypeError(f"Cannot subclass {cls.__module__}.Annotated") + # Python 3.8 has get_origin() and get_args() but those implementations aren't # Annotated-aware, so we can't use those. 
Python 3.9's versions don't support @@ -1414,8 +1485,16 @@ def get_origin(tp): """ if isinstance(tp, _AnnotatedAlias): return Annotated - if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias, _BaseGenericAlias, - ParamSpecArgs, ParamSpecKwargs)): + if isinstance( + tp, + ( + typing._GenericAlias, + _typing_GenericAlias, + _BaseGenericAlias, + ParamSpecArgs, + ParamSpecKwargs, + ), + ): return tp.__origin__ if tp is typing.Generic: return typing.Generic @@ -1445,10 +1524,11 @@ def get_args(tp): # 3.10+ -if hasattr(typing, 'TypeAlias'): +if hasattr(typing, "TypeAlias"): TypeAlias = typing.TypeAlias # 3.9 elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm def TypeAlias(self, parameters): """Special marker indicating that an assignment should @@ -1462,10 +1542,12 @@ def TypeAlias(self, parameters): It's invalid when used anywhere except as in the example above. """ raise TypeError(f"{self} is not subscriptable") + + # 3.7-3.8 else: TypeAlias = _ExtensionsSpecialForm( - 'TypeAlias', + "TypeAlias", doc="""Special marker indicating that an assignment should be recognized as a proper type alias definition by type checkers. @@ -1475,14 +1557,15 @@ def TypeAlias(self, parameters): Predicate: TypeAlias = Callable[..., bool] It's invalid when used anywhere except as in the example - above.""" + above.""", ) def _set_default(type_param, default): if isinstance(default, (tuple, list)): - type_param.__default__ = tuple((typing._type_check(d, "Default must be a type") - for d in default)) + type_param.__default__ = tuple( + (typing._type_check(d, "Default must be a type") for d in default) + ) elif default != _marker: type_param.__default__ = typing._type_check(default, "Default must be a type") else: @@ -1492,7 +1575,7 @@ def _set_default(type_param, default): def _set_module(typevarlike): # for pickling: def_mod = _caller(depth=3) - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": typevarlike.__module__ = def_mod @@ -1515,17 +1598,34 @@ class TypeVar(metaclass=_TypeVarLikeMeta): _backported_typevarlike = typing.TypeVar - def __new__(cls, name, *constraints, bound=None, - covariant=False, contravariant=False, - default=_marker, infer_variance=False): + def __new__( + cls, + name, + *constraints, + bound=None, + covariant=False, + contravariant=False, + default=_marker, + infer_variance=False, + ): if hasattr(typing, "TypeAliasType"): # PEP 695 implemented, can pass infer_variance to typing.TypeVar - typevar = typing.TypeVar(name, *constraints, bound=bound, - covariant=covariant, contravariant=contravariant, - infer_variance=infer_variance) + typevar = typing.TypeVar( + name, + *constraints, + bound=bound, + covariant=covariant, + contravariant=contravariant, + infer_variance=infer_variance, + ) else: - typevar = typing.TypeVar(name, *constraints, bound=bound, - covariant=covariant, contravariant=contravariant) + typevar = typing.TypeVar( + name, + *constraints, + bound=bound, + covariant=covariant, + contravariant=contravariant, + ) if infer_variance and (covariant or contravariant): raise ValueError("Variance cannot be specified with infer_variance.") typevar.__infer_variance__ = infer_variance @@ -1538,13 +1638,15 @@ def __init_subclass__(cls) -> None: # Python 3.10+ has PEP 612 -if hasattr(typing, 'ParamSpecArgs'): +if hasattr(typing, "ParamSpecArgs"): ParamSpecArgs = typing.ParamSpecArgs ParamSpecKwargs = typing.ParamSpecKwargs # 3.7-3.9 else: + class _Immutable: """Mixin to indicate that object should not be copied.""" + __slots__ = () def __copy__(self): 
@@ -1565,6 +1667,7 @@ class ParamSpecArgs(_Immutable): This type is meant for runtime introspection and has no special meaning to static type checkers. """ + def __init__(self, origin): self.__origin__ = origin @@ -1588,6 +1691,7 @@ class ParamSpecKwargs(_Immutable): This type is meant for runtime introspection and has no special meaning to static type checkers. """ + def __init__(self, origin): self.__origin__ = origin @@ -1599,8 +1703,9 @@ def __eq__(self, other): return NotImplemented return self.__origin__ == other.__origin__ + # 3.10+ -if hasattr(typing, 'ParamSpec'): +if hasattr(typing, "ParamSpec"): # Add default parameter - PEP 696 class ParamSpec(metaclass=_TypeVarLikeMeta): @@ -1608,19 +1713,29 @@ class ParamSpec(metaclass=_TypeVarLikeMeta): _backported_typevarlike = typing.ParamSpec - def __new__(cls, name, *, bound=None, - covariant=False, contravariant=False, - infer_variance=False, default=_marker): + def __new__( + cls, + name, + *, + bound=None, + covariant=False, + contravariant=False, + infer_variance=False, + default=_marker, + ): if hasattr(typing, "TypeAliasType"): # PEP 695 implemented, can pass infer_variance to typing.TypeVar - paramspec = typing.ParamSpec(name, bound=bound, - covariant=covariant, - contravariant=contravariant, - infer_variance=infer_variance) + paramspec = typing.ParamSpec( + name, + bound=bound, + covariant=covariant, + contravariant=contravariant, + infer_variance=infer_variance, + ) else: - paramspec = typing.ParamSpec(name, bound=bound, - covariant=covariant, - contravariant=contravariant) + paramspec = typing.ParamSpec( + name, bound=bound, covariant=covariant, contravariant=contravariant + ) paramspec.__infer_variance__ = infer_variance _set_default(paramspec, default) @@ -1628,7 +1743,10 @@ def __new__(cls, name, *, bound=None, return paramspec def __init_subclass__(cls) -> None: - raise TypeError(f"type '{__name__}.ParamSpec' is not an acceptable base type") + raise TypeError( + f"type '{__name__}.ParamSpec' is not an acceptable base type" + ) + # 3.7-3.9 else: @@ -1692,33 +1810,41 @@ def args(self): def kwargs(self): return ParamSpecKwargs(self) - def __init__(self, name, *, bound=None, covariant=False, contravariant=False, - infer_variance=False, default=_marker): + def __init__( + self, + name, + *, + bound=None, + covariant=False, + contravariant=False, + infer_variance=False, + default=_marker, + ): super().__init__([self]) self.__name__ = name self.__covariant__ = bool(covariant) self.__contravariant__ = bool(contravariant) self.__infer_variance__ = bool(infer_variance) if bound: - self.__bound__ = typing._type_check(bound, 'Bound must be a type.') + self.__bound__ = typing._type_check(bound, "Bound must be a type.") else: self.__bound__ = None _DefaultMixin.__init__(self, default) # for pickling: def_mod = _caller() - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod def __repr__(self): if self.__infer_variance__: - prefix = '' + prefix = "" elif self.__covariant__: - prefix = '+' + prefix = "+" elif self.__contravariant__: - prefix = '-' + prefix = "-" else: - prefix = '~' + prefix = "~" return prefix + self.__name__ def __hash__(self): @@ -1736,7 +1862,7 @@ def __call__(self, *args, **kwargs): # 3.7-3.9 -if not hasattr(typing, 'Concatenate'): +if not hasattr(typing, "Concatenate"): # Inherits from list as a workaround for Callable checks in Python < 3.9.2. 
class _ConcatenateGenericAlias(list): @@ -1753,8 +1879,10 @@ def __init__(self, origin, args): def __repr__(self): _type_repr = typing._type_repr - return (f'{_type_repr(self.__origin__)}' - f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]') + return ( + f"{_type_repr(self.__origin__)}" + f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]' + ) def __hash__(self): return hash((self.__origin__, self.__args__)) @@ -1766,7 +1894,9 @@ def __call__(self, *args, **kwargs): @property def __parameters__(self): return tuple( - tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) + tp + for tp in self.__args__ + if isinstance(tp, (typing.TypeVar, ParamSpec)) ) @@ -1778,19 +1908,21 @@ def _concatenate_getitem(self, parameters): if not isinstance(parameters, tuple): parameters = (parameters,) if not isinstance(parameters[-1], ParamSpec): - raise TypeError("The last parameter to Concatenate should be a " - "ParamSpec variable.") + raise TypeError( + "The last parameter to Concatenate should be a " "ParamSpec variable." + ) msg = "Concatenate[arg, ...]: each arg must be a type." parameters = tuple(typing._type_check(p, msg) for p in parameters) return _ConcatenateGenericAlias(self, parameters) # 3.10+ -if hasattr(typing, 'Concatenate'): +if hasattr(typing, "Concatenate"): Concatenate = typing.Concatenate _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa: F811 # 3.9 elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm def Concatenate(self, parameters): """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a @@ -1804,14 +1936,17 @@ def Concatenate(self, parameters): See PEP 612 for detailed information. """ return _concatenate_getitem(self, parameters) + + # 3.7-8 else: + class _ConcatenateForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): return _concatenate_getitem(self, parameters) Concatenate = _ConcatenateForm( - 'Concatenate', + "Concatenate", doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a higher order function which adds, removes or transforms parameters of a callable. @@ -1821,13 +1956,15 @@ def __getitem__(self, parameters): Callable[Concatenate[int, P], int] See PEP 612 for detailed information. - """) + """, + ) # 3.10+ -if hasattr(typing, 'TypeGuard'): +if hasattr(typing, "TypeGuard"): TypeGuard = typing.TypeGuard # 3.9 elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm def TypeGuard(self, parameters): """Special typing form used to annotate the return type of a user-defined @@ -1872,18 +2009,22 @@ def is_str(val: Union[str, float]): ``TypeGuard`` also works with type variables. For more information, see PEP 647 (User-Defined Type Guards). """ - item = typing._type_check(parameters, f'{self} accepts only a single type.') + item = typing._type_check(parameters, f"{self} accepts only a single type.") return typing._GenericAlias(self, (item,)) + + # 3.7-3.8 else: + class _TypeGuardForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type" + ) return typing._GenericAlias(self, (item,)) TypeGuard = _TypeGuardForm( - 'TypeGuard', + "TypeGuard", doc="""Special typing form used to annotate the return type of a user-defined type guard function. ``TypeGuard`` only accepts a single type argument. At runtime, functions marked this way should return a boolean. 
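Aside, not part of the diff: the rewrapped _TypeGuardForm above changes layout only; TypeGuard[X] remains the PEP 647 narrowing form described in the docstring. A usage sketch, assuming typing_extensions is importable and with illustrative names:

    from typing import List, Union
    from typing_extensions import TypeGuard

    def is_str_list(val: List[Union[str, int]]) -> TypeGuard[List[str]]:
        # Returning True tells a static checker that val narrows to List[str].
        return all(isinstance(x, str) for x in val)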
@@ -1925,12 +2066,13 @@ def is_str(val: Union[str, float]): ``TypeGuard`` also works with type variables. For more information, see PEP 647 (User-Defined Type Guards). - """) + """, + ) # Vendored from cpython typing._SpecialFrom class _SpecialForm(typing._Final, _root=True): - __slots__ = ('_name', '__doc__', '_getitem') + __slots__ = ("_name", "__doc__", "_getitem") def __init__(self, getitem): self._getitem = getitem @@ -1938,7 +2080,7 @@ def __init__(self, getitem): self.__doc__ = getitem.__doc__ def __getattr__(self, item): - if item in {'__name__', '__qualname__'}: + if item in {"__name__", "__qualname__"}: return self._name raise AttributeError(item) @@ -1947,7 +2089,7 @@ def __mro_entries__(self, bases): raise TypeError(f"Cannot subclass {self!r}") def __repr__(self): - return f'typing_extensions.{self._name}' + return f"typing_extensions.{self._name}" def __reduce__(self): return self._name @@ -1975,6 +2117,7 @@ def __getitem__(self, parameters): if hasattr(typing, "LiteralString"): LiteralString = typing.LiteralString else: + @_SpecialForm def LiteralString(self, params): """Represents an arbitrary literal string. @@ -1998,6 +2141,7 @@ def query(sql: LiteralString) -> ...: if hasattr(typing, "Self"): Self = typing.Self else: + @_SpecialForm def Self(self, params): """Used to spell the type of "self" in classes. @@ -2019,6 +2163,7 @@ def parse(self, data: bytes) -> Self: if hasattr(typing, "Never"): Never = typing.Never else: + @_SpecialForm def Never(self, params): """The bottom type, a type that has no members. @@ -2046,10 +2191,11 @@ def int_or_str(arg: int | str) -> None: raise TypeError(f"{self} is not subscriptable") -if hasattr(typing, 'Required'): +if hasattr(typing, "Required"): Required = typing.Required NotRequired = typing.NotRequired elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm def Required(self, parameters): """A special typing construct to mark a key of a total=False TypedDict @@ -2067,7 +2213,9 @@ class Movie(TypedDict, total=False): There is no runtime checking that a required key is actually provided when instantiating a related TypedDict. """ - item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return typing._GenericAlias(self, (item,)) @_ExtensionsSpecialForm @@ -2084,18 +2232,22 @@ class Movie(TypedDict): year=1999, ) """ - item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return typing._GenericAlias(self, (item,)) else: + class _RequiredForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return typing._GenericAlias(self, (item,)) Required = _RequiredForm( - 'Required', + "Required", doc="""A special typing construct to mark a key of a total=False TypedDict as required. For example: @@ -2110,9 +2262,10 @@ class Movie(TypedDict, total=False): There is no runtime checking that a required key is actually provided when instantiating a related TypedDict. - """) + """, + ) NotRequired = _RequiredForm( - 'NotRequired', + "NotRequired", doc="""A special typing construct to mark a key of a TypedDict as potentially missing. 
For example: @@ -2124,7 +2277,8 @@ class Movie(TypedDict): title='The Matrix', # typechecker error if key is omitted year=1999, ) - """) + """, + ) _UNPACK_DOC = """\ @@ -2176,6 +2330,7 @@ def _is_unpack(obj): return get_origin(obj) is Unpack elif sys.version_info[:2] >= (3, 9): + class _UnpackSpecialForm(_ExtensionsSpecialForm, _root=True): def __init__(self, getitem): super().__init__(getitem) @@ -2186,23 +2341,27 @@ class _UnpackAlias(typing._GenericAlias, _root=True): @_UnpackSpecialForm def Unpack(self, parameters): - item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return _UnpackAlias(self, (item,)) def _is_unpack(obj): return isinstance(obj, _UnpackAlias) else: + class _UnpackAlias(typing._GenericAlias, _root=True): __class__ = typing.TypeVar class _UnpackForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type.') + item = typing._type_check( + parameters, f"{self._name} accepts only a single type." + ) return _UnpackAlias(self, (item,)) - Unpack = _UnpackForm('Unpack', doc=_UNPACK_DOC) + Unpack = _UnpackForm("Unpack", doc=_UNPACK_DOC) def _is_unpack(obj): return isinstance(obj, _UnpackAlias) @@ -2226,6 +2385,7 @@ def __init_subclass__(self, *args, **kwds): raise TypeError("Cannot subclass special typing classes") else: + class TypeVarTuple(_DefaultMixin): """Type variable tuple. @@ -2282,7 +2442,7 @@ def __init__(self, name, *, default=_marker): # for pickling: def_mod = _caller() - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod self.__unpacked__ = Unpack[self] @@ -2300,13 +2460,14 @@ def __reduce__(self): return self.__name__ def __init_subclass__(self, *args, **kwds): - if '_root' not in kwds: + if "_root" not in kwds: raise TypeError("Cannot subclass special typing classes") if hasattr(typing, "reveal_type"): reveal_type = typing.reveal_type else: + def reveal_type(__obj: T) -> T: """Reveal the inferred type of a variable. @@ -2330,6 +2491,7 @@ def reveal_type(__obj: T) -> T: if hasattr(typing, "assert_never"): assert_never = typing.assert_never else: + def assert_never(__arg: Never) -> Never: """Assert to the type checker that a line of code is unreachable. @@ -2357,6 +2519,7 @@ def int_or_str(arg: int | str) -> None: # dataclass_transform exists in 3.11 but lacks the frozen_default parameter dataclass_transform = typing.dataclass_transform else: + def dataclass_transform( *, eq_default: bool = True, @@ -2364,8 +2527,7 @@ def dataclass_transform( kw_only_default: bool = False, frozen_default: bool = False, field_specifiers: typing.Tuple[ - typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], - ... + typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], ... ] = (), **kwargs: typing.Any, ) -> typing.Callable[[T], T]: @@ -2430,6 +2592,7 @@ class CustomerModel(ModelBase): See PEP 681 for details. """ + def decorator(cls_or_fn): cls_or_fn.__dataclass_transform__ = { "eq_default": eq_default, @@ -2440,6 +2603,7 @@ def decorator(cls_or_fn): "kwargs": kwargs, } return cls_or_fn + return decorator @@ -2533,6 +2697,7 @@ def g(x: str) -> int: ... See PEP 702 for details. 
""" + def decorator(__arg: _T) -> _T: if category is None: __arg.__deprecated__ = __msg @@ -2556,6 +2721,7 @@ def __new__(cls, *args, **kwargs): __arg.__deprecated__ = __new__.__deprecated__ = __msg return __arg elif callable(__arg): + @functools.wraps(__arg) def wrapper(*args, **kwargs): warnings.warn(__msg, category=category, stacklevel=stacklevel + 1) @@ -2592,12 +2758,14 @@ def wrapper(*args, **kwargs): if sys.version_info >= (3, 13): NamedTuple = typing.NamedTuple else: + def _make_nmtuple(name, types, module, defaults=()): fields = [n for n, t in types] - annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") - for n, t in types} - nm_tpl = collections.namedtuple(name, fields, - defaults=defaults, module=module) + annotations = { + n: typing._type_check(t, f"field {n} annotation must be a type") + for n, t in types + } + nm_tpl = collections.namedtuple(name, fields, defaults=defaults, module=module) nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = annotations # The `_field_types` attribute was removed in 3.9; # in earlier versions, it is the same as the `__annotations__` attribute @@ -2606,7 +2774,9 @@ def _make_nmtuple(name, types, module, defaults=()): return nm_tpl _prohibited_namedtuple_fields = typing._prohibited - _special_namedtuple_fields = frozenset({'__module__', '__name__', '__annotations__'}) + _special_namedtuple_fields = frozenset( + {"__module__", "__name__", "__annotations__"} + ) class _NamedTupleMeta(type): def __new__(cls, typename, bases, ns): @@ -2614,27 +2784,33 @@ def __new__(cls, typename, bases, ns): for base in bases: if base is not _NamedTuple and base is not typing.Generic: raise TypeError( - 'can only inherit from a NamedTuple type and Generic') + "can only inherit from a NamedTuple type and Generic" + ) bases = tuple(tuple if base is _NamedTuple else base for base in bases) - types = ns.get('__annotations__', {}) + types = ns.get("__annotations__", {}) default_names = [] for field_name in types: if field_name in ns: default_names.append(field_name) elif default_names: - raise TypeError(f"Non-default namedtuple field {field_name} " - f"cannot follow default field" - f"{'s' if len(default_names) > 1 else ''} " - f"{', '.join(default_names)}") + raise TypeError( + f"Non-default namedtuple field {field_name} " + f"cannot follow default field" + f"{'s' if len(default_names) > 1 else ''} " + f"{', '.join(default_names)}" + ) nm_tpl = _make_nmtuple( - typename, types.items(), + typename, + types.items(), defaults=[ns[n] for n in default_names], - module=ns['__module__'] + module=ns["__module__"], ) nm_tpl.__bases__ = bases if typing.Generic in bases: - if hasattr(typing, '_generic_class_getitem'): # 3.12+ - nm_tpl.__class_getitem__ = classmethod(typing._generic_class_getitem) + if hasattr(typing, "_generic_class_getitem"): # 3.12+ + nm_tpl.__class_getitem__ = classmethod( + typing._generic_class_getitem + ) else: class_getitem = typing.Generic.__class_getitem__.__func__ nm_tpl.__class_getitem__ = classmethod(class_getitem) @@ -2642,13 +2818,15 @@ def __new__(cls, typename, bases, ns): for key in ns: if key in _prohibited_namedtuple_fields: raise AttributeError("Cannot overwrite NamedTuple attribute " + key) - elif key not in _special_namedtuple_fields and key not in nm_tpl._fields: + elif ( + key not in _special_namedtuple_fields and key not in nm_tpl._fields + ): setattr(nm_tpl, key, ns[key]) if typing.Generic in bases: nm_tpl.__init_subclass__() return nm_tpl - _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) 
+ _NamedTuple = type.__new__(_NamedTupleMeta, "NamedTuple", (), {}) def _namedtuple_mro_entries(bases): assert NamedTuple in bases @@ -2686,11 +2864,15 @@ class Employee(NamedTuple): deprecated_thing = "Failing to pass a value for the 'fields' parameter" example = f"`{__typename} = NamedTuple({__typename!r}, [])`" deprecation_msg = ( - "{name} is deprecated and will be disallowed in Python {remove}. " - "To create a NamedTuple class with 0 fields " - "using the functional syntax, " - "pass an empty list, e.g. " - ) + example + "." + ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + + example + + "." + ) elif __fields is None: if kwargs: raise TypeError( @@ -2701,14 +2883,20 @@ class Employee(NamedTuple): deprecated_thing = "Passing `None` as the 'fields' parameter" example = f"`{__typename} = NamedTuple({__typename!r}, [])`" deprecation_msg = ( - "{name} is deprecated and will be disallowed in Python {remove}. " - "To create a NamedTuple class with 0 fields " - "using the functional syntax, " - "pass an empty list, e.g. " - ) + example + "." + ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + + example + + "." + ) elif kwargs: - raise TypeError("Either list of fields or keywords" - " can be provided to NamedTuple, not both") + raise TypeError( + "Either list of fields or keywords" + " can be provided to NamedTuple, not both" + ) if __fields is _marker or __fields is None: warnings.warn( deprecation_msg.format(name=deprecated_thing, remove="3.15"), @@ -2724,7 +2912,7 @@ class Employee(NamedTuple): # The signature of typing.NamedTuple on >=3.8 is invalid syntax in Python 3.7, # so just leave the signature as it is on 3.7. if sys.version_info >= (3, 8): - _new_signature = '(typename, fields=None, /, **kwargs)' + _new_signature = "(typename, fields=None, /, **kwargs)" if isinstance(NamedTuple, _types.FunctionType): NamedTuple.__text_signature__ = _new_signature else: @@ -2734,6 +2922,7 @@ class Employee(NamedTuple): if hasattr(collections.abc, "Buffer"): Buffer = collections.abc.Buffer else: + class Buffer(abc.ABC): """Base class for classes that implement the buffer protocol. @@ -2764,6 +2953,7 @@ class Buffer(abc.ABC): if hasattr(_types, "get_original_bases"): get_original_bases = _types.get_original_bases else: + def get_original_bases(__cls): """Return the class's "original" bases prior to modification by `__mro_entries__`. @@ -2792,7 +2982,7 @@ class Baz(list[str]): ... return __cls.__bases__ except AttributeError: raise TypeError( - f'Expected an instance of type, not {type(__cls).__name__!r}' + f"Expected an instance of type, not {type(__cls).__name__!r}" ) from None @@ -2801,6 +2991,7 @@ class Baz(list[str]): ... if sys.version_info >= (3, 11): NewType = typing.NewType else: + class NewType: """NewType creates simple unique types with almost zero runtime overhead. NewType(name, tp) is considered a subtype of tp @@ -2820,12 +3011,12 @@ def __call__(self, obj): def __init__(self, name, tp): self.__qualname__ = name - if '.' in name: - name = name.rpartition('.')[-1] + if "." 
in name: + name = name.rpartition(".")[-1] self.__name__ = name self.__supertype__ = tp def_mod = _caller() - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod def __mro_entries__(self, bases): @@ -2845,7 +3036,7 @@ def __init_subclass__(cls): return (Dummy,) def __repr__(self): - return f'{self.__module__}.{self.__qualname__}' + return f"{self.__module__}.{self.__qualname__}" def __reduce__(self): return self.__qualname__ @@ -2864,14 +3055,18 @@ def __ror__(self, other): if hasattr(typing, "TypeAliasType"): TypeAliasType = typing.TypeAliasType else: + def _is_unionable(obj): """Corresponds to is_unionable() in unionobject.c in CPython.""" - return obj is None or isinstance(obj, ( - type, - _types.GenericAlias, - _types.UnionType, - TypeAliasType, - )) + return obj is None or isinstance( + obj, + ( + type, + _types.GenericAlias, + _types.UnionType, + TypeAliasType, + ), + ) class TypeAliasType: """Create named, parameterized type aliases. @@ -2915,7 +3110,7 @@ def __init__(self, name: str, value, *, type_params=()): parameters.append(type_param) self.__parameters__ = tuple(parameters) def_mod = _caller() - if def_mod != 'typing_extensions': + if def_mod != "typing_extensions": self.__module__ = def_mod # Setting this attribute closes the TypeAliasType from further modification self.__name__ = name @@ -2932,7 +3127,12 @@ def _raise_attribute_error(self, name: str) -> Never: # Match the Python 3.12 error messages exactly if name == "__name__": raise AttributeError("readonly attribute") - elif name in {"__value__", "__type_params__", "__parameters__", "__module__"}: + elif name in { + "__value__", + "__type_params__", + "__parameters__", + "__module__", + }: raise AttributeError( f"attribute '{name}' of 'typing.TypeAliasType' objects " "is not writable" @@ -2950,7 +3150,7 @@ def __getitem__(self, parameters): parameters = (parameters,) parameters = [ typing._type_check( - item, f'Subscripting {self.__name__} requires a type.' + item, f"Subscripting {self.__name__} requires a type." ) for item in parameters ] @@ -2970,6 +3170,7 @@ def __call__(self): raise TypeError("Type alias is not callable") if sys.version_info >= (3, 10): + def __or__(self, right): # For forward compatibility with 3.12, reject Unions # that are not accepted by the built-in Union. @@ -2987,6 +3188,7 @@ def __ror__(self, left): is_protocol = typing.is_protocol get_protocol_members = typing.get_protocol_members else: + def is_protocol(__tp: type) -> bool: """Return True if the given type is a Protocol. @@ -3003,7 +3205,7 @@ def is_protocol(__tp: type) -> bool: """ return ( isinstance(__tp, type) - and getattr(__tp, '_is_protocol', False) + and getattr(__tp, "_is_protocol", False) and __tp is not Protocol and __tp is not getattr(typing, "Protocol", object()) ) @@ -3023,8 +3225,8 @@ def get_protocol_members(__tp: type) -> typing.FrozenSet[str]: Raise a TypeError for arguments that are not Protocols. 
""" if not is_protocol(__tp): - raise TypeError(f'{__tp!r} is not a Protocol') - if hasattr(__tp, '__protocol_attrs__'): + raise TypeError(f"{__tp!r} is not a Protocol") + if hasattr(__tp, "__protocol_attrs__"): return frozenset(__tp.__protocol_attrs__) return frozenset(_get_protocol_attrs(__tp)) diff --git a/metaflow/_vendor/v3_7/zipp.py b/metaflow/_vendor/v3_7/zipp.py index 26b723c1fd3..72632b0b773 100644 --- a/metaflow/_vendor/v3_7/zipp.py +++ b/metaflow/_vendor/v3_7/zipp.py @@ -12,7 +12,7 @@ OrderedDict = dict -__all__ = ['Path'] +__all__ = ["Path"] def _parents(path): @@ -93,7 +93,7 @@ def resolve_dir(self, name): as a directory (with the trailing slash). """ names = self._name_set() - dirname = name + '/' + dirname = name + "/" dir_match = name not in names and dirname in names return dirname if dir_match else name @@ -110,7 +110,7 @@ def make(cls, source): return cls(_pathlib_compat(source)) # Only allow for FastLookup when supplied zipfile is read-only - if 'r' not in source.mode: + if "r" not in source.mode: cls = CompleteDirs source.__class__ = cls @@ -240,7 +240,7 @@ def __init__(self, root, at=""): self.root = FastLookup.make(root) self.at = at - def open(self, mode='r', *args, pwd=None, **kwargs): + def open(self, mode="r", *args, pwd=None, **kwargs): """ Open this entry as text or binary following the semantics of ``pathlib.Path.open()`` by passing arguments through @@ -249,10 +249,10 @@ def open(self, mode='r', *args, pwd=None, **kwargs): if self.is_dir(): raise IsADirectoryError(self) zip_mode = mode[0] - if not self.exists() and zip_mode == 'r': + if not self.exists() and zip_mode == "r": raise FileNotFoundError(self) stream = self.root.open(self.at, zip_mode, pwd=pwd) - if 'b' in mode: + if "b" in mode: if args or kwargs: raise ValueError("encoding args invalid for binary operation") return stream @@ -279,11 +279,11 @@ def filename(self): return pathlib.Path(self.root.filename).joinpath(self.at) def read_text(self, *args, **kwargs): - with self.open('r', *args, **kwargs) as strm: + with self.open("r", *args, **kwargs) as strm: return strm.read() def read_bytes(self): - with self.open('rb') as strm: + with self.open("rb") as strm: return strm.read() def _is_child(self, path): @@ -323,7 +323,7 @@ def joinpath(self, *other): def parent(self): if not self.at: return self.filename.parent - parent_at = posixpath.dirname(self.at.rstrip('/')) + parent_at = posixpath.dirname(self.at.rstrip("/")) if parent_at: - parent_at += '/' + parent_at += "/" return self._next(parent_at) diff --git a/metaflow/_vendor/yaml/__init__.py b/metaflow/_vendor/yaml/__init__.py index 13d687c501c..26d168bae7f 100644 --- a/metaflow/_vendor/yaml/__init__.py +++ b/metaflow/_vendor/yaml/__init__.py @@ -1,4 +1,3 @@ - from .error import * from .tokens import * @@ -8,24 +7,26 @@ from .loader import * from .dumper import * -__version__ = '5.3.1' +__version__ = "5.3.1" try: from .cyaml import * + __with_libyaml__ = True except ImportError: __with_libyaml__ = False import io -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ # Warnings control -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ # 'Global' warnings state: _warnings_enabled = { - 'YAMLLoadWarning': True, + "YAMLLoadWarning": True, } + # Get or set global warnings' state def warnings(settings=None): if settings is None: @@ -36,12 
+37,14 @@ def warnings(settings=None): if key in _warnings_enabled: _warnings_enabled[key] = settings[key] + # Warn when load() is called without Loader=... class YAMLLoadWarning(RuntimeWarning): pass + def load_warning(method): - if _warnings_enabled['YAMLLoadWarning'] is False: + if _warnings_enabled["YAMLLoadWarning"] is False: return import warnings @@ -54,7 +57,8 @@ def load_warning(method): warnings.warn(message, YAMLLoadWarning, stacklevel=3) -#------------------------------------------------------------------------------ + +# ------------------------------------------------------------------------------ def scan(stream, Loader=Loader): """ Scan a YAML stream and produce scanning tokens. @@ -66,6 +70,7 @@ def scan(stream, Loader=Loader): finally: loader.dispose() + def parse(stream, Loader=Loader): """ Parse a YAML stream and produce parsing events. @@ -77,6 +82,7 @@ def parse(stream, Loader=Loader): finally: loader.dispose() + def compose(stream, Loader=Loader): """ Parse the first YAML document in a stream @@ -88,6 +94,7 @@ def compose(stream, Loader=Loader): finally: loader.dispose() + def compose_all(stream, Loader=Loader): """ Parse all YAML documents in a stream @@ -100,13 +107,14 @@ def compose_all(stream, Loader=Loader): finally: loader.dispose() + def load(stream, Loader=None): """ Parse the first YAML document in a stream and produce the corresponding Python object. """ if Loader is None: - load_warning('load') + load_warning("load") Loader = FullLoader loader = Loader(stream) @@ -115,13 +123,14 @@ def load(stream, Loader=None): finally: loader.dispose() + def load_all(stream, Loader=None): """ Parse all YAML documents in a stream and produce corresponding Python objects. """ if Loader is None: - load_warning('load_all') + load_warning("load_all") Loader = FullLoader loader = Loader(stream) @@ -131,6 +140,7 @@ def load_all(stream, Loader=None): finally: loader.dispose() + def full_load(stream): """ Parse the first YAML document in a stream @@ -141,6 +151,7 @@ def full_load(stream): """ return load(stream, FullLoader) + def full_load_all(stream): """ Parse all YAML documents in a stream @@ -151,6 +162,7 @@ def full_load_all(stream): """ return load_all(stream, FullLoader) + def safe_load(stream): """ Parse the first YAML document in a stream @@ -161,6 +173,7 @@ def safe_load(stream): """ return load(stream, SafeLoader) + def safe_load_all(stream): """ Parse all YAML documents in a stream @@ -171,6 +184,7 @@ def safe_load_all(stream): """ return load_all(stream, SafeLoader) + def unsafe_load(stream): """ Parse the first YAML document in a stream @@ -181,6 +195,7 @@ def unsafe_load(stream): """ return load(stream, UnsafeLoader) + def unsafe_load_all(stream): """ Parse all YAML documents in a stream @@ -191,9 +206,17 @@ def unsafe_load_all(stream): """ return load_all(stream, UnsafeLoader) -def emit(events, stream=None, Dumper=Dumper, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None): + +def emit( + events, + stream=None, + Dumper=Dumper, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, +): """ Emit YAML parsing events into a stream. If stream is None, return the produced string instead. 
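Illustrative note, not part of the patch: the hunks above only reflow the loader entry points, so load() without an explicit Loader still goes through load_warning("load") and falls back to FullLoader, while safe_load() pins SafeLoader and emits no warning. A small usage sketch against the vendored module:

    from metaflow._vendor import yaml

    data = yaml.safe_load("name: metaflow\nretries: 3\n")
    assert data == {"name": "metaflow", "retries": 3}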
@@ -202,8 +225,14 @@ def emit(events, stream=None, Dumper=Dumper, if stream is None: stream = io.StringIO() getvalue = stream.getvalue - dumper = Dumper(stream, canonical=canonical, indent=indent, width=width, - allow_unicode=allow_unicode, line_break=line_break) + dumper = Dumper( + stream, + canonical=canonical, + indent=indent, + width=width, + allow_unicode=allow_unicode, + line_break=line_break, + ) try: for event in events: dumper.emit(event) @@ -212,11 +241,22 @@ def emit(events, stream=None, Dumper=Dumper, if getvalue: return getvalue() -def serialize_all(nodes, stream=None, Dumper=Dumper, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None, - encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None): + +def serialize_all( + nodes, + stream=None, + Dumper=Dumper, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, +): """ Serialize a sequence of representation trees into a YAML stream. If stream is None, return the produced string instead. @@ -228,10 +268,19 @@ def serialize_all(nodes, stream=None, Dumper=Dumper, else: stream = io.BytesIO() getvalue = stream.getvalue - dumper = Dumper(stream, canonical=canonical, indent=indent, width=width, - allow_unicode=allow_unicode, line_break=line_break, - encoding=encoding, version=version, tags=tags, - explicit_start=explicit_start, explicit_end=explicit_end) + dumper = Dumper( + stream, + canonical=canonical, + indent=indent, + width=width, + allow_unicode=allow_unicode, + line_break=line_break, + encoding=encoding, + version=version, + tags=tags, + explicit_start=explicit_start, + explicit_end=explicit_end, + ) try: dumper.open() for node in nodes: @@ -242,6 +291,7 @@ def serialize_all(nodes, stream=None, Dumper=Dumper, if getvalue: return getvalue() + def serialize(node, stream=None, Dumper=Dumper, **kwds): """ Serialize a representation tree into a YAML stream. @@ -249,12 +299,25 @@ def serialize(node, stream=None, Dumper=Dumper, **kwds): """ return serialize_all([node], stream, Dumper=Dumper, **kwds) -def dump_all(documents, stream=None, Dumper=Dumper, - default_style=None, default_flow_style=False, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None, - encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None, sort_keys=True): + +def dump_all( + documents, + stream=None, + Dumper=Dumper, + default_style=None, + default_flow_style=False, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, + sort_keys=True, +): """ Serialize a sequence of Python objects into a YAML stream. If stream is None, return the produced string instead. 
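Likewise illustrative only: the exploded dump_all() signature above changes nothing at runtime; with stream=None it still returns the produced string, one document per input object separated by "---", and the result round-trips through safe_load_all():

    from metaflow._vendor import yaml

    text = yaml.dump_all([{"a": 1}, {"b": 2}], default_flow_style=False)
    assert list(yaml.safe_load_all(text)) == [{"a": 1}, {"b": 2}]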
@@ -266,12 +329,22 @@ def dump_all(documents, stream=None, Dumper=Dumper, else: stream = io.BytesIO() getvalue = stream.getvalue - dumper = Dumper(stream, default_style=default_style, - default_flow_style=default_flow_style, - canonical=canonical, indent=indent, width=width, - allow_unicode=allow_unicode, line_break=line_break, - encoding=encoding, version=version, tags=tags, - explicit_start=explicit_start, explicit_end=explicit_end, sort_keys=sort_keys) + dumper = Dumper( + stream, + default_style=default_style, + default_flow_style=default_flow_style, + canonical=canonical, + indent=indent, + width=width, + allow_unicode=allow_unicode, + line_break=line_break, + encoding=encoding, + version=version, + tags=tags, + explicit_start=explicit_start, + explicit_end=explicit_end, + sort_keys=sort_keys, + ) try: dumper.open() for data in documents: @@ -282,6 +355,7 @@ def dump_all(documents, stream=None, Dumper=Dumper, if getvalue: return getvalue() + def dump(data, stream=None, Dumper=Dumper, **kwds): """ Serialize a Python object into a YAML stream. @@ -289,6 +363,7 @@ def dump(data, stream=None, Dumper=Dumper, **kwds): """ return dump_all([data], stream, Dumper=Dumper, **kwds) + def safe_dump_all(documents, stream=None, **kwds): """ Serialize a sequence of Python objects into a YAML stream. @@ -297,6 +372,7 @@ def safe_dump_all(documents, stream=None, **kwds): """ return dump_all(documents, stream, Dumper=SafeDumper, **kwds) + def safe_dump(data, stream=None, **kwds): """ Serialize a Python object into a YAML stream. @@ -305,8 +381,8 @@ def safe_dump(data, stream=None, **kwds): """ return dump_all([data], stream, Dumper=SafeDumper, **kwds) -def add_implicit_resolver(tag, regexp, first=None, - Loader=None, Dumper=Dumper): + +def add_implicit_resolver(tag, regexp, first=None, Loader=None, Dumper=Dumper): """ Add an implicit scalar detector. If an implicit scalar value matches the given regexp, @@ -321,6 +397,7 @@ def add_implicit_resolver(tag, regexp, first=None, Loader.add_implicit_resolver(tag, regexp, first) Dumper.add_implicit_resolver(tag, regexp, first) + def add_path_resolver(tag, path, kind=None, Loader=None, Dumper=Dumper): """ Add a path based resolver for the given tag. @@ -336,6 +413,7 @@ def add_path_resolver(tag, path, kind=None, Loader=None, Dumper=Dumper): Loader.add_path_resolver(tag, path, kind) Dumper.add_path_resolver(tag, path, kind) + def add_constructor(tag, constructor, Loader=None): """ Add a constructor for the given tag. @@ -349,6 +427,7 @@ def add_constructor(tag, constructor, Loader=None): else: Loader.add_constructor(tag, constructor) + def add_multi_constructor(tag_prefix, multi_constructor, Loader=None): """ Add a multi-constructor for the given tag prefix. @@ -363,6 +442,7 @@ def add_multi_constructor(tag_prefix, multi_constructor, Loader=None): else: Loader.add_multi_constructor(tag_prefix, multi_constructor) + def add_representer(data_type, representer, Dumper=Dumper): """ Add a representer for the given type. @@ -372,6 +452,7 @@ def add_representer(data_type, representer, Dumper=Dumper): """ Dumper.add_representer(data_type, representer) + def add_multi_representer(data_type, multi_representer, Dumper=Dumper): """ Add a representer for the given type. @@ -381,13 +462,15 @@ def add_multi_representer(data_type, multi_representer, Dumper=Dumper): """ Dumper.add_multi_representer(data_type, multi_representer) + class YAMLObjectMetaclass(type): """ The metaclass for YAMLObject. 
""" + def __init__(cls, name, bases, kwds): super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds) - if 'yaml_tag' in kwds and kwds['yaml_tag'] is not None: + if "yaml_tag" in kwds and kwds["yaml_tag"] is not None: if isinstance(cls.yaml_loader, list): for loader in cls.yaml_loader: loader.add_constructor(cls.yaml_tag, cls.from_yaml) @@ -396,6 +479,7 @@ def __init__(cls, name, bases, kwds): cls.yaml_dumper.add_representer(cls, cls.to_yaml) + class YAMLObject(metaclass=YAMLObjectMetaclass): """ An object that can dump itself to a YAML stream @@ -422,6 +506,6 @@ def to_yaml(cls, dumper, data): """ Convert a Python object to a representation node. """ - return dumper.represent_yaml_object(cls.yaml_tag, data, cls, - flow_style=cls.yaml_flow_style) - + return dumper.represent_yaml_object( + cls.yaml_tag, data, cls, flow_style=cls.yaml_flow_style + ) diff --git a/metaflow/_vendor/yaml/composer.py b/metaflow/_vendor/yaml/composer.py index 6d15cb40e3b..a46f4849eaf 100644 --- a/metaflow/_vendor/yaml/composer.py +++ b/metaflow/_vendor/yaml/composer.py @@ -1,13 +1,14 @@ - -__all__ = ['Composer', 'ComposerError'] +__all__ = ["Composer", "ComposerError"] from .error import MarkedYAMLError from .events import * from .nodes import * + class ComposerError(MarkedYAMLError): pass + class Composer: def __init__(self): @@ -38,9 +39,12 @@ def get_single_node(self): # Ensure that the stream contains no more documents. if not self.check_event(StreamEndEvent): event = self.get_event() - raise ComposerError("expected a single document in the stream", - document.start_mark, "but found another document", - event.start_mark) + raise ComposerError( + "expected a single document in the stream", + document.start_mark, + "but found another document", + event.start_mark, + ) # Drop the STREAM-END event. 
self.get_event() @@ -65,16 +69,20 @@ def compose_node(self, parent, index): event = self.get_event() anchor = event.anchor if anchor not in self.anchors: - raise ComposerError(None, None, "found undefined alias %r" - % anchor, event.start_mark) + raise ComposerError( + None, None, "found undefined alias %r" % anchor, event.start_mark + ) return self.anchors[anchor] event = self.peek_event() anchor = event.anchor if anchor is not None: if anchor in self.anchors: - raise ComposerError("found duplicate anchor %r; first occurrence" - % anchor, self.anchors[anchor].start_mark, - "second occurrence", event.start_mark) + raise ComposerError( + "found duplicate anchor %r; first occurrence" % anchor, + self.anchors[anchor].start_mark, + "second occurrence", + event.start_mark, + ) self.descend_resolver(parent, index) if self.check_event(ScalarEvent): node = self.compose_scalar_node(anchor) @@ -88,10 +96,11 @@ def compose_node(self, parent, index): def compose_scalar_node(self, anchor): event = self.get_event() tag = event.tag - if tag is None or tag == '!': + if tag is None or tag == "!": tag = self.resolve(ScalarNode, event.value, event.implicit) - node = ScalarNode(tag, event.value, - event.start_mark, event.end_mark, style=event.style) + node = ScalarNode( + tag, event.value, event.start_mark, event.end_mark, style=event.style + ) if anchor is not None: self.anchors[anchor] = node return node @@ -99,11 +108,11 @@ def compose_scalar_node(self, anchor): def compose_sequence_node(self, anchor): start_event = self.get_event() tag = start_event.tag - if tag is None or tag == '!': + if tag is None or tag == "!": tag = self.resolve(SequenceNode, None, start_event.implicit) - node = SequenceNode(tag, [], - start_event.start_mark, None, - flow_style=start_event.flow_style) + node = SequenceNode( + tag, [], start_event.start_mark, None, flow_style=start_event.flow_style + ) if anchor is not None: self.anchors[anchor] = node index = 0 @@ -117,23 +126,22 @@ def compose_sequence_node(self, anchor): def compose_mapping_node(self, anchor): start_event = self.get_event() tag = start_event.tag - if tag is None or tag == '!': + if tag is None or tag == "!": tag = self.resolve(MappingNode, None, start_event.implicit) - node = MappingNode(tag, [], - start_event.start_mark, None, - flow_style=start_event.flow_style) + node = MappingNode( + tag, [], start_event.start_mark, None, flow_style=start_event.flow_style + ) if anchor is not None: self.anchors[anchor] = node while not self.check_event(MappingEndEvent): - #key_event = self.peek_event() + # key_event = self.peek_event() item_key = self.compose_node(node, None) - #if item_key in node.value: + # if item_key in node.value: # raise ComposerError("while composing a mapping", start_event.start_mark, # "found duplicate key", key_event.start_mark) item_value = self.compose_node(node, item_key) - #node.value[item_key] = item_value + # node.value[item_key] = item_value node.value.append((item_key, item_value)) end_event = self.get_event() node.end_mark = end_event.end_mark return node - diff --git a/metaflow/_vendor/yaml/constructor.py b/metaflow/_vendor/yaml/constructor.py index 1948b125c20..1ce61729c63 100644 --- a/metaflow/_vendor/yaml/constructor.py +++ b/metaflow/_vendor/yaml/constructor.py @@ -1,11 +1,10 @@ - __all__ = [ - 'BaseConstructor', - 'SafeConstructor', - 'FullConstructor', - 'UnsafeConstructor', - 'Constructor', - 'ConstructorError' + "BaseConstructor", + "SafeConstructor", + "FullConstructor", + "UnsafeConstructor", + "Constructor", + "ConstructorError", 
] from .error import * @@ -13,9 +12,11 @@ import collections.abc, datetime, base64, binascii, re, sys, types + class ConstructorError(MarkedYAMLError): pass + class BaseConstructor: yaml_constructors = {} @@ -36,8 +37,12 @@ def check_state_key(self, key): object, to prevent user-controlled methods from being called during deserialization""" if self.get_state_keys_blacklist_regexp().match(key): - raise ConstructorError(None, None, - "blacklisted key '%s' in instance state found" % (key,), None) + raise ConstructorError( + None, + None, + "blacklisted key '%s' in instance state found" % (key,), + None, + ) def get_data(self): # Construct and return the next document. @@ -71,8 +76,9 @@ def construct_object(self, node, deep=False): old_deep = self.deep_construct self.deep_construct = True if node in self.recursive_objects: - raise ConstructorError(None, None, - "found unconstructable recursive node", node.start_mark) + raise ConstructorError( + None, None, "found unconstructable recursive node", node.start_mark + ) self.recursive_objects[node] = None constructor = None tag_suffix = None @@ -81,7 +87,7 @@ def construct_object(self, node, deep=False): else: for tag_prefix in self.yaml_multi_constructors: if tag_prefix is not None and node.tag.startswith(tag_prefix): - tag_suffix = node.tag[len(tag_prefix):] + tag_suffix = node.tag[len(tag_prefix) :] constructor = self.yaml_multi_constructors[tag_prefix] break else: @@ -116,39 +122,54 @@ def construct_object(self, node, deep=False): def construct_scalar(self, node): if not isinstance(node, ScalarNode): - raise ConstructorError(None, None, - "expected a scalar node, but found %s" % node.id, - node.start_mark) + raise ConstructorError( + None, + None, + "expected a scalar node, but found %s" % node.id, + node.start_mark, + ) return node.value def construct_sequence(self, node, deep=False): if not isinstance(node, SequenceNode): - raise ConstructorError(None, None, - "expected a sequence node, but found %s" % node.id, - node.start_mark) - return [self.construct_object(child, deep=deep) - for child in node.value] + raise ConstructorError( + None, + None, + "expected a sequence node, but found %s" % node.id, + node.start_mark, + ) + return [self.construct_object(child, deep=deep) for child in node.value] def construct_mapping(self, node, deep=False): if not isinstance(node, MappingNode): - raise ConstructorError(None, None, - "expected a mapping node, but found %s" % node.id, - node.start_mark) + raise ConstructorError( + None, + None, + "expected a mapping node, but found %s" % node.id, + node.start_mark, + ) mapping = {} for key_node, value_node in node.value: key = self.construct_object(key_node, deep=deep) if not isinstance(key, collections.abc.Hashable): - raise ConstructorError("while constructing a mapping", node.start_mark, - "found unhashable key", key_node.start_mark) + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found unhashable key", + key_node.start_mark, + ) value = self.construct_object(value_node, deep=deep) mapping[key] = value return mapping def construct_pairs(self, node, deep=False): if not isinstance(node, MappingNode): - raise ConstructorError(None, None, - "expected a mapping node, but found %s" % node.id, - node.start_mark) + raise ConstructorError( + None, + None, + "expected a mapping node, but found %s" % node.id, + node.start_mark, + ) pairs = [] for key_node, value_node in node.value: key = self.construct_object(key_node, deep=deep) @@ -158,22 +179,23 @@ def construct_pairs(self, node, 
deep=False): @classmethod def add_constructor(cls, tag, constructor): - if not 'yaml_constructors' in cls.__dict__: + if not "yaml_constructors" in cls.__dict__: cls.yaml_constructors = cls.yaml_constructors.copy() cls.yaml_constructors[tag] = constructor @classmethod def add_multi_constructor(cls, tag_prefix, multi_constructor): - if not 'yaml_multi_constructors' in cls.__dict__: + if not "yaml_multi_constructors" in cls.__dict__: cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy() cls.yaml_multi_constructors[tag_prefix] = multi_constructor + class SafeConstructor(BaseConstructor): def construct_scalar(self, node): if isinstance(node, MappingNode): for key_node, value_node in node.value: - if key_node.tag == 'tag:yaml.org,2002:value': + if key_node.tag == "tag:yaml.org,2002:value": return self.construct_scalar(value_node) return super().construct_scalar(node) @@ -182,7 +204,7 @@ def flatten_mapping(self, node): index = 0 while index < len(node.value): key_node, value_node = node.value[index] - if key_node.tag == 'tag:yaml.org,2002:merge': + if key_node.tag == "tag:yaml.org,2002:merge": del node.value[index] if isinstance(value_node, MappingNode): self.flatten_mapping(value_node) @@ -191,21 +213,28 @@ def flatten_mapping(self, node): submerge = [] for subnode in value_node.value: if not isinstance(subnode, MappingNode): - raise ConstructorError("while constructing a mapping", - node.start_mark, - "expected a mapping for merging, but found %s" - % subnode.id, subnode.start_mark) + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "expected a mapping for merging, but found %s" + % subnode.id, + subnode.start_mark, + ) self.flatten_mapping(subnode) submerge.append(subnode.value) submerge.reverse() for value in submerge: merge.extend(value) else: - raise ConstructorError("while constructing a mapping", node.start_mark, - "expected a mapping or list of mappings for merging, but found %s" - % value_node.id, value_node.start_mark) - elif key_node.tag == 'tag:yaml.org,2002:value': - key_node.tag = 'tag:yaml.org,2002:str' + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "expected a mapping or list of mappings for merging, but found %s" + % value_node.id, + value_node.start_mark, + ) + elif key_node.tag == "tag:yaml.org,2002:value": + key_node.tag = "tag:yaml.org,2002:str" index += 1 else: index += 1 @@ -222,12 +251,12 @@ def construct_yaml_null(self, node): return None bool_values = { - 'yes': True, - 'no': False, - 'true': True, - 'false': False, - 'on': True, - 'off': False, + "yes": True, + "no": False, + "true": True, + "false": False, + "on": True, + "off": False, } def construct_yaml_bool(self, node): @@ -236,79 +265,83 @@ def construct_yaml_bool(self, node): def construct_yaml_int(self, node): value = self.construct_scalar(node) - value = value.replace('_', '') + value = value.replace("_", "") sign = +1 - if value[0] == '-': + if value[0] == "-": sign = -1 - if value[0] in '+-': + if value[0] in "+-": value = value[1:] - if value == '0': + if value == "0": return 0 - elif value.startswith('0b'): - return sign*int(value[2:], 2) - elif value.startswith('0x'): - return sign*int(value[2:], 16) - elif value[0] == '0': - return sign*int(value, 8) - elif ':' in value: - digits = [int(part) for part in value.split(':')] + elif value.startswith("0b"): + return sign * int(value[2:], 2) + elif value.startswith("0x"): + return sign * int(value[2:], 16) + elif value[0] == "0": + return sign * int(value, 8) + elif ":" in value: + 
digits = [int(part) for part in value.split(":")] digits.reverse() base = 1 value = 0 for digit in digits: - value += digit*base + value += digit * base base *= 60 - return sign*value + return sign * value else: - return sign*int(value) + return sign * int(value) inf_value = 1e300 - while inf_value != inf_value*inf_value: + while inf_value != inf_value * inf_value: inf_value *= inf_value - nan_value = -inf_value/inf_value # Trying to make a quiet NaN (like C99). + nan_value = -inf_value / inf_value # Trying to make a quiet NaN (like C99). def construct_yaml_float(self, node): value = self.construct_scalar(node) - value = value.replace('_', '').lower() + value = value.replace("_", "").lower() sign = +1 - if value[0] == '-': + if value[0] == "-": sign = -1 - if value[0] in '+-': + if value[0] in "+-": value = value[1:] - if value == '.inf': - return sign*self.inf_value - elif value == '.nan': + if value == ".inf": + return sign * self.inf_value + elif value == ".nan": return self.nan_value - elif ':' in value: - digits = [float(part) for part in value.split(':')] + elif ":" in value: + digits = [float(part) for part in value.split(":")] digits.reverse() base = 1 value = 0.0 for digit in digits: - value += digit*base + value += digit * base base *= 60 - return sign*value + return sign * value else: - return sign*float(value) + return sign * float(value) def construct_yaml_binary(self, node): try: - value = self.construct_scalar(node).encode('ascii') + value = self.construct_scalar(node).encode("ascii") except UnicodeEncodeError as exc: - raise ConstructorError(None, None, - "failed to convert base64 data into ascii: %s" % exc, - node.start_mark) + raise ConstructorError( + None, + None, + "failed to convert base64 data into ascii: %s" % exc, + node.start_mark, + ) try: - if hasattr(base64, 'decodebytes'): + if hasattr(base64, "decodebytes"): return base64.decodebytes(value) else: return base64.decodestring(value) except binascii.Error as exc: - raise ConstructorError(None, None, - "failed to decode base64 data: %s" % exc, node.start_mark) + raise ConstructorError( + None, None, "failed to decode base64 data: %s" % exc, node.start_mark + ) timestamp_regexp = re.compile( - r'''^(?P<year>[0-9][0-9][0-9][0-9]) + r"""^(?P<year>[0-9][0-9][0-9][0-9]) -(?P<month>[0-9][0-9]?) -(?P<day>[0-9][0-9]?) (?:(?:[Tt]|[ \t]+) @@ -317,38 +350,41 @@ def construct_yaml_binary(self, node): :(?P<second>[0-9][0-9]) (?:\.(?P<fraction>[0-9]*))? (?:[ \t]*(?P<tz>Z|(?P<tz_sign>[-+])(?P<tz_hour>[0-9][0-9]?)
- (?::(?P<tz_minute>[0-9][0-9]))?))?)?$''', re.X) + (?::(?P<tz_minute>[0-9][0-9]))?))?)?$""", + re.X, + ) def construct_yaml_timestamp(self, node): value = self.construct_scalar(node) match = self.timestamp_regexp.match(node.value) values = match.groupdict() - year = int(values['year']) - month = int(values['month']) - day = int(values['day']) - if not values['hour']: + year = int(values["year"]) + month = int(values["month"]) + day = int(values["day"]) + if not values["hour"]: return datetime.date(year, month, day) - hour = int(values['hour']) - minute = int(values['minute']) - second = int(values['second']) + hour = int(values["hour"]) + minute = int(values["minute"]) + second = int(values["second"]) fraction = 0 tzinfo = None - if values['fraction']: - fraction = values['fraction'][:6] + if values["fraction"]: + fraction = values["fraction"][:6] while len(fraction) < 6: - fraction += '0' + fraction += "0" fraction = int(fraction) - if values['tz_sign']: - tz_hour = int(values['tz_hour']) - tz_minute = int(values['tz_minute'] or 0) + if values["tz_sign"]: + tz_hour = int(values["tz_hour"]) + tz_minute = int(values["tz_minute"] or 0) delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute) - if values['tz_sign'] == '-': + if values["tz_sign"] == "-": delta = -delta tzinfo = datetime.timezone(delta) - elif values['tz']: + elif values["tz"]: tzinfo = datetime.timezone.utc - return datetime.datetime(year, month, day, hour, minute, second, fraction, - tzinfo=tzinfo) + return datetime.datetime( + year, month, day, hour, minute, second, fraction, tzinfo=tzinfo + ) def construct_yaml_omap(self, node): # Note: we do not check for duplicate keys, because it's too @@ -356,17 +392,28 @@ def construct_yaml_omap(self, node): omap = [] yield omap if not isinstance(node, SequenceNode): - raise ConstructorError("while constructing an ordered map", node.start_mark, - "expected a sequence, but found %s" % node.id, node.start_mark) + raise ConstructorError( + "while constructing an ordered map", + node.start_mark, + "expected a sequence, but found %s" % node.id, + node.start_mark, + ) for subnode in node.value: if not isinstance(subnode, MappingNode): - raise ConstructorError("while constructing an ordered map", node.start_mark, - "expected a mapping of length 1, but found %s" % subnode.id, - subnode.start_mark) + raise ConstructorError( + "while constructing an ordered map", + node.start_mark, + "expected a mapping of length 1, but found %s" % subnode.id, + subnode.start_mark, + ) if len(subnode.value) != 1: - raise ConstructorError("while constructing an ordered map", node.start_mark, - "expected a single mapping item, but found %d items" % len(subnode.value), - subnode.start_mark) + raise ConstructorError( + "while constructing an ordered map", + node.start_mark, + "expected a single mapping item, but found %d items" + % len(subnode.value), + subnode.start_mark, + ) key_node, value_node = subnode.value[0] key = self.construct_object(key_node) value = self.construct_object(value_node) @@ -377,17 +424,28 @@ def construct_yaml_pairs(self, node): pairs = [] yield pairs if not isinstance(node, SequenceNode): - raise ConstructorError("while constructing pairs", node.start_mark, - "expected a sequence, but found %s" % node.id, node.start_mark) + raise ConstructorError( + "while constructing pairs", + node.start_mark, + "expected a sequence, but found %s" % node.id, + node.start_mark, + ) for subnode in node.value: if not isinstance(subnode, MappingNode): - raise ConstructorError("while constructing pairs", node.start_mark, -
"expected a mapping of length 1, but found %s" % subnode.id, - subnode.start_mark) + raise ConstructorError( + "while constructing pairs", + node.start_mark, + "expected a mapping of length 1, but found %s" % subnode.id, + subnode.start_mark, + ) if len(subnode.value) != 1: - raise ConstructorError("while constructing pairs", node.start_mark, - "expected a single mapping item, but found %d items" % len(subnode.value), - subnode.start_mark) + raise ConstructorError( + "while constructing pairs", + node.start_mark, + "expected a single mapping item, but found %d items" + % len(subnode.value), + subnode.start_mark, + ) key_node, value_node = subnode.value[0] key = self.construct_object(key_node) value = self.construct_object(value_node) @@ -416,7 +474,7 @@ def construct_yaml_map(self, node): def construct_yaml_object(self, node, cls): data = cls.__new__(cls) yield data - if hasattr(data, '__setstate__'): + if hasattr(data, "__setstate__"): state = self.construct_mapping(node, deep=True) data.__setstate__(state) else: @@ -424,71 +482,77 @@ def construct_yaml_object(self, node, cls): data.__dict__.update(state) def construct_undefined(self, node): - raise ConstructorError(None, None, - "could not determine a constructor for the tag %r" % node.tag, - node.start_mark) + raise ConstructorError( + None, + None, + "could not determine a constructor for the tag %r" % node.tag, + node.start_mark, + ) + SafeConstructor.add_constructor( - 'tag:yaml.org,2002:null', - SafeConstructor.construct_yaml_null) + "tag:yaml.org,2002:null", SafeConstructor.construct_yaml_null +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:bool', - SafeConstructor.construct_yaml_bool) + "tag:yaml.org,2002:bool", SafeConstructor.construct_yaml_bool +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:int', - SafeConstructor.construct_yaml_int) + "tag:yaml.org,2002:int", SafeConstructor.construct_yaml_int +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:float', - SafeConstructor.construct_yaml_float) + "tag:yaml.org,2002:float", SafeConstructor.construct_yaml_float +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:binary', - SafeConstructor.construct_yaml_binary) + "tag:yaml.org,2002:binary", SafeConstructor.construct_yaml_binary +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:timestamp', - SafeConstructor.construct_yaml_timestamp) + "tag:yaml.org,2002:timestamp", SafeConstructor.construct_yaml_timestamp +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:omap', - SafeConstructor.construct_yaml_omap) + "tag:yaml.org,2002:omap", SafeConstructor.construct_yaml_omap +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:pairs', - SafeConstructor.construct_yaml_pairs) + "tag:yaml.org,2002:pairs", SafeConstructor.construct_yaml_pairs +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:set', - SafeConstructor.construct_yaml_set) + "tag:yaml.org,2002:set", SafeConstructor.construct_yaml_set +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:str', - SafeConstructor.construct_yaml_str) + "tag:yaml.org,2002:str", SafeConstructor.construct_yaml_str +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:seq', - SafeConstructor.construct_yaml_seq) + "tag:yaml.org,2002:seq", SafeConstructor.construct_yaml_seq +) SafeConstructor.add_constructor( - 'tag:yaml.org,2002:map', - SafeConstructor.construct_yaml_map) + "tag:yaml.org,2002:map", SafeConstructor.construct_yaml_map +) + +SafeConstructor.add_constructor(None, SafeConstructor.construct_undefined) -SafeConstructor.add_constructor(None, - 
SafeConstructor.construct_undefined) class FullConstructor(SafeConstructor): # 'extend' is blacklisted because it is used by # construct_python_object_apply to add `listitems` to a newly generate # python instance def get_state_keys_blacklist(self): - return ['^extend$', '^__.*__$'] + return ["^extend$", "^__.*__$"] def get_state_keys_blacklist_regexp(self): - if not hasattr(self, 'state_keys_blacklist_regexp'): - self.state_keys_blacklist_regexp = re.compile('(' + '|'.join(self.get_state_keys_blacklist()) + ')') + if not hasattr(self, "state_keys_blacklist_regexp"): + self.state_keys_blacklist_regexp = re.compile( + "(" + "|".join(self.get_state_keys_blacklist()) + ")" + ) return self.state_keys_blacklist_regexp def construct_python_str(self, node): @@ -499,107 +563,150 @@ def construct_python_unicode(self, node): def construct_python_bytes(self, node): try: - value = self.construct_scalar(node).encode('ascii') + value = self.construct_scalar(node).encode("ascii") except UnicodeEncodeError as exc: - raise ConstructorError(None, None, - "failed to convert base64 data into ascii: %s" % exc, - node.start_mark) + raise ConstructorError( + None, + None, + "failed to convert base64 data into ascii: %s" % exc, + node.start_mark, + ) try: - if hasattr(base64, 'decodebytes'): + if hasattr(base64, "decodebytes"): return base64.decodebytes(value) else: return base64.decodestring(value) except binascii.Error as exc: - raise ConstructorError(None, None, - "failed to decode base64 data: %s" % exc, node.start_mark) + raise ConstructorError( + None, None, "failed to decode base64 data: %s" % exc, node.start_mark + ) def construct_python_long(self, node): return self.construct_yaml_int(node) def construct_python_complex(self, node): - return complex(self.construct_scalar(node)) + return complex(self.construct_scalar(node)) def construct_python_tuple(self, node): return tuple(self.construct_sequence(node)) def find_python_module(self, name, mark, unsafe=False): if not name: - raise ConstructorError("while constructing a Python module", mark, - "expected non-empty name appended to the tag", mark) + raise ConstructorError( + "while constructing a Python module", + mark, + "expected non-empty name appended to the tag", + mark, + ) if unsafe: try: __import__(name) except ImportError as exc: - raise ConstructorError("while constructing a Python module", mark, - "cannot find module %r (%s)" % (name, exc), mark) + raise ConstructorError( + "while constructing a Python module", + mark, + "cannot find module %r (%s)" % (name, exc), + mark, + ) if name not in sys.modules: - raise ConstructorError("while constructing a Python module", mark, - "module %r is not imported" % name, mark) + raise ConstructorError( + "while constructing a Python module", + mark, + "module %r is not imported" % name, + mark, + ) return sys.modules[name] def find_python_name(self, name, mark, unsafe=False): if not name: - raise ConstructorError("while constructing a Python object", mark, - "expected non-empty name appended to the tag", mark) - if '.' in name: - module_name, object_name = name.rsplit('.', 1) + raise ConstructorError( + "while constructing a Python object", + mark, + "expected non-empty name appended to the tag", + mark, + ) + if "." 
in name: + module_name, object_name = name.rsplit(".", 1) else: - module_name = 'builtins' + module_name = "builtins" object_name = name if unsafe: try: __import__(module_name) except ImportError as exc: - raise ConstructorError("while constructing a Python object", mark, - "cannot find module %r (%s)" % (module_name, exc), mark) + raise ConstructorError( + "while constructing a Python object", + mark, + "cannot find module %r (%s)" % (module_name, exc), + mark, + ) if module_name not in sys.modules: - raise ConstructorError("while constructing a Python object", mark, - "module %r is not imported" % module_name, mark) + raise ConstructorError( + "while constructing a Python object", + mark, + "module %r is not imported" % module_name, + mark, + ) module = sys.modules[module_name] if not hasattr(module, object_name): - raise ConstructorError("while constructing a Python object", mark, - "cannot find %r in the module %r" - % (object_name, module.__name__), mark) + raise ConstructorError( + "while constructing a Python object", + mark, + "cannot find %r in the module %r" % (object_name, module.__name__), + mark, + ) return getattr(module, object_name) def construct_python_name(self, suffix, node): value = self.construct_scalar(node) if value: - raise ConstructorError("while constructing a Python name", node.start_mark, - "expected the empty value, but found %r" % value, node.start_mark) + raise ConstructorError( + "while constructing a Python name", + node.start_mark, + "expected the empty value, but found %r" % value, + node.start_mark, + ) return self.find_python_name(suffix, node.start_mark) def construct_python_module(self, suffix, node): value = self.construct_scalar(node) if value: - raise ConstructorError("while constructing a Python module", node.start_mark, - "expected the empty value, but found %r" % value, node.start_mark) + raise ConstructorError( + "while constructing a Python module", + node.start_mark, + "expected the empty value, but found %r" % value, + node.start_mark, + ) return self.find_python_module(suffix, node.start_mark) - def make_python_instance(self, suffix, node, - args=None, kwds=None, newobj=False, unsafe=False): + def make_python_instance( + self, suffix, node, args=None, kwds=None, newobj=False, unsafe=False + ): if not args: args = [] if not kwds: kwds = {} cls = self.find_python_name(suffix, node.start_mark) if not (unsafe or isinstance(cls, type)): - raise ConstructorError("while constructing a Python instance", node.start_mark, - "expected a class, but found %r" % type(cls), - node.start_mark) + raise ConstructorError( + "while constructing a Python instance", + node.start_mark, + "expected a class, but found %r" % type(cls), + node.start_mark, + ) if newobj and isinstance(cls, type): return cls.__new__(cls, *args, **kwds) else: return cls(*args, **kwds) def set_python_instance_state(self, instance, state, unsafe=False): - if hasattr(instance, '__setstate__'): + if hasattr(instance, "__setstate__"): instance.__setstate__(state) else: slotstate = {} if isinstance(state, tuple) and len(state) == 2: state, slotstate = state - if hasattr(instance, '__dict__'): + if hasattr(instance, "__dict__"): if not unsafe and state: for key in state.keys(): self.check_state_key(key) @@ -616,7 +723,7 @@ def construct_python_object(self, suffix, node): # !!python/object:module.name { ... state ... 
} instance = self.make_python_instance(suffix, node, newobj=True) yield instance - deep = hasattr(instance, '__setstate__') + deep = hasattr(instance, "__setstate__") state = self.construct_mapping(node, deep=deep) self.set_python_instance_state(instance, state) @@ -640,11 +747,11 @@ def construct_python_object_apply(self, suffix, node, newobj=False): dictitems = {} else: value = self.construct_mapping(node, deep=True) - args = value.get('args', []) - kwds = value.get('kwds', {}) - state = value.get('state', {}) - listitems = value.get('listitems', []) - dictitems = value.get('dictitems', {}) + args = value.get("args", []) + kwds = value.get("kwds", {}) + state = value.get("state", {}) + listitems = value.get("listitems", []) + dictitems = value.get("dictitems", {}) instance = self.make_python_instance(suffix, node, args, kwds, newobj) if state: self.set_python_instance_state(instance, state) @@ -658,89 +765,98 @@ def construct_python_object_apply(self, suffix, node, newobj=False): def construct_python_object_new(self, suffix, node): return self.construct_python_object_apply(suffix, node, newobj=True) + FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/none', - FullConstructor.construct_yaml_null) + "tag:yaml.org,2002:python/none", FullConstructor.construct_yaml_null +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/bool', - FullConstructor.construct_yaml_bool) + "tag:yaml.org,2002:python/bool", FullConstructor.construct_yaml_bool +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/str', - FullConstructor.construct_python_str) + "tag:yaml.org,2002:python/str", FullConstructor.construct_python_str +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/unicode', - FullConstructor.construct_python_unicode) + "tag:yaml.org,2002:python/unicode", FullConstructor.construct_python_unicode +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/bytes', - FullConstructor.construct_python_bytes) + "tag:yaml.org,2002:python/bytes", FullConstructor.construct_python_bytes +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/int', - FullConstructor.construct_yaml_int) + "tag:yaml.org,2002:python/int", FullConstructor.construct_yaml_int +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/long', - FullConstructor.construct_python_long) + "tag:yaml.org,2002:python/long", FullConstructor.construct_python_long +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/float', - FullConstructor.construct_yaml_float) + "tag:yaml.org,2002:python/float", FullConstructor.construct_yaml_float +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/complex', - FullConstructor.construct_python_complex) + "tag:yaml.org,2002:python/complex", FullConstructor.construct_python_complex +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/list', - FullConstructor.construct_yaml_seq) + "tag:yaml.org,2002:python/list", FullConstructor.construct_yaml_seq +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/tuple', - FullConstructor.construct_python_tuple) + "tag:yaml.org,2002:python/tuple", FullConstructor.construct_python_tuple +) FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/dict', - FullConstructor.construct_yaml_map) + "tag:yaml.org,2002:python/dict", FullConstructor.construct_yaml_map +) FullConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/name:', - FullConstructor.construct_python_name) + "tag:yaml.org,2002:python/name:", FullConstructor.construct_python_name +) 
FullConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/module:', - FullConstructor.construct_python_module) + "tag:yaml.org,2002:python/module:", FullConstructor.construct_python_module +) FullConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/object:', - FullConstructor.construct_python_object) + "tag:yaml.org,2002:python/object:", FullConstructor.construct_python_object +) FullConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/object/new:', - FullConstructor.construct_python_object_new) + "tag:yaml.org,2002:python/object/new:", FullConstructor.construct_python_object_new +) + class UnsafeConstructor(FullConstructor): def find_python_module(self, name, mark): - return super(UnsafeConstructor, self).find_python_module(name, mark, unsafe=True) + return super(UnsafeConstructor, self).find_python_module( + name, mark, unsafe=True + ) def find_python_name(self, name, mark): return super(UnsafeConstructor, self).find_python_name(name, mark, unsafe=True) def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False): return super(UnsafeConstructor, self).make_python_instance( - suffix, node, args, kwds, newobj, unsafe=True) + suffix, node, args, kwds, newobj, unsafe=True + ) def set_python_instance_state(self, instance, state): return super(UnsafeConstructor, self).set_python_instance_state( - instance, state, unsafe=True) + instance, state, unsafe=True + ) + UnsafeConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/object/apply:', - UnsafeConstructor.construct_python_object_apply) + "tag:yaml.org,2002:python/object/apply:", + UnsafeConstructor.construct_python_object_apply, +) + # Constructor is same as UnsafeConstructor. Need to leave this in place in case # people have extended it directly. 
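The constructor.py hunks above are purely stylistic, but they walk straight through PyYAML's dispatch pattern: each add_constructor(tag, fn) call binds a 'tag:yaml.org,2002:...' tag to the method that turns a parsed node into a Python value, and the final add_constructor(None, construct_undefined) is the catch-all that raises ConstructorError for unknown tags. A minimal sketch of the same registration mechanism, assuming stock PyYAML imported as `yaml` rather than the vendored metaflow._vendor.yaml copy; the `!point` tag and `construct_point` helper are illustrative only, not part of this diff:

    import yaml

    def construct_point(loader, node):
        # construct_scalar is the same base-constructor API that the
        # vendored construct_yaml_timestamp uses to get the node text.
        x, y = loader.construct_scalar(node).split(",")
        return (float(x), float(y))

    # Same mechanism as the SafeConstructor.add_constructor calls above.
    yaml.SafeLoader.add_constructor("!point", construct_point)

    print(yaml.load("!point 1.5,2.5", Loader=yaml.SafeLoader))  # (1.5, 2.5)

Without the registration, the same document falls through to construct_undefined and raises "could not determine a constructor for the tag '!point'".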
diff --git a/metaflow/_vendor/yaml/cyaml.py b/metaflow/_vendor/yaml/cyaml.py index 1e606c74b94..3c436146606 100644 --- a/metaflow/_vendor/yaml/cyaml.py +++ b/metaflow/_vendor/yaml/cyaml.py @@ -1,7 +1,12 @@ - __all__ = [ - 'CBaseLoader', 'CSafeLoader', 'CFullLoader', 'CUnsafeLoader', 'CLoader', - 'CBaseDumper', 'CSafeDumper', 'CDumper' + "CBaseLoader", + "CSafeLoader", + "CFullLoader", + "CUnsafeLoader", + "CLoader", + "CBaseDumper", + "CSafeDumper", + "CDumper", ] from _yaml import CParser, CEmitter @@ -13,6 +18,7 @@ from .resolver import * + class CBaseLoader(CParser, BaseConstructor, BaseResolver): def __init__(self, stream): @@ -20,6 +26,7 @@ def __init__(self, stream): BaseConstructor.__init__(self) BaseResolver.__init__(self) + class CSafeLoader(CParser, SafeConstructor, Resolver): def __init__(self, stream): @@ -27,6 +34,7 @@ def __init__(self, stream): SafeConstructor.__init__(self) Resolver.__init__(self) + class CFullLoader(CParser, FullConstructor, Resolver): def __init__(self, stream): @@ -34,6 +42,7 @@ def __init__(self, stream): FullConstructor.__init__(self) Resolver.__init__(self) + class CUnsafeLoader(CParser, UnsafeConstructor, Resolver): def __init__(self, stream): @@ -41,6 +50,7 @@ def __init__(self, stream): UnsafeConstructor.__init__(self) Resolver.__init__(self) + class CLoader(CParser, Constructor, Resolver): def __init__(self, stream): @@ -48,54 +58,128 @@ def __init__(self, stream): Constructor.__init__(self) Resolver.__init__(self) + class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): - def __init__(self, stream, - default_style=None, default_flow_style=False, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None, - encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None, sort_keys=True): - CEmitter.__init__(self, stream, canonical=canonical, - indent=indent, width=width, encoding=encoding, - allow_unicode=allow_unicode, line_break=line_break, - explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) - Representer.__init__(self, default_style=default_style, - default_flow_style=default_flow_style, sort_keys=sort_keys) + def __init__( + self, + stream, + default_style=None, + default_flow_style=False, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, + sort_keys=True, + ): + CEmitter.__init__( + self, + stream, + canonical=canonical, + indent=indent, + width=width, + encoding=encoding, + allow_unicode=allow_unicode, + line_break=line_break, + explicit_start=explicit_start, + explicit_end=explicit_end, + version=version, + tags=tags, + ) + Representer.__init__( + self, + default_style=default_style, + default_flow_style=default_flow_style, + sort_keys=sort_keys, + ) Resolver.__init__(self) + class CSafeDumper(CEmitter, SafeRepresenter, Resolver): - def __init__(self, stream, - default_style=None, default_flow_style=False, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None, - encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None, sort_keys=True): - CEmitter.__init__(self, stream, canonical=canonical, - indent=indent, width=width, encoding=encoding, - allow_unicode=allow_unicode, line_break=line_break, - explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) - SafeRepresenter.__init__(self, default_style=default_style, - default_flow_style=default_flow_style, 
sort_keys=sort_keys) + def __init__( + self, + stream, + default_style=None, + default_flow_style=False, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, + sort_keys=True, + ): + CEmitter.__init__( + self, + stream, + canonical=canonical, + indent=indent, + width=width, + encoding=encoding, + allow_unicode=allow_unicode, + line_break=line_break, + explicit_start=explicit_start, + explicit_end=explicit_end, + version=version, + tags=tags, + ) + SafeRepresenter.__init__( + self, + default_style=default_style, + default_flow_style=default_flow_style, + sort_keys=sort_keys, + ) Resolver.__init__(self) + class CDumper(CEmitter, Serializer, Representer, Resolver): - def __init__(self, stream, - default_style=None, default_flow_style=False, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None, - encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None, sort_keys=True): - CEmitter.__init__(self, stream, canonical=canonical, - indent=indent, width=width, encoding=encoding, - allow_unicode=allow_unicode, line_break=line_break, - explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) - Representer.__init__(self, default_style=default_style, - default_flow_style=default_flow_style, sort_keys=sort_keys) + def __init__( + self, + stream, + default_style=None, + default_flow_style=False, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, + sort_keys=True, + ): + CEmitter.__init__( + self, + stream, + canonical=canonical, + indent=indent, + width=width, + encoding=encoding, + allow_unicode=allow_unicode, + line_break=line_break, + explicit_start=explicit_start, + explicit_end=explicit_end, + version=version, + tags=tags, + ) + Representer.__init__( + self, + default_style=default_style, + default_flow_style=default_flow_style, + sort_keys=sort_keys, + ) Resolver.__init__(self) - diff --git a/metaflow/_vendor/yaml/dumper.py b/metaflow/_vendor/yaml/dumper.py index 6aadba551f3..02e1b30f793 100644 --- a/metaflow/_vendor/yaml/dumper.py +++ b/metaflow/_vendor/yaml/dumper.py @@ -1,62 +1,141 @@ - -__all__ = ['BaseDumper', 'SafeDumper', 'Dumper'] +__all__ = ["BaseDumper", "SafeDumper", "Dumper"] from .emitter import * from .serializer import * from .representer import * from .resolver import * + class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver): - def __init__(self, stream, - default_style=None, default_flow_style=False, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None, - encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None, sort_keys=True): - Emitter.__init__(self, stream, canonical=canonical, - indent=indent, width=width, - allow_unicode=allow_unicode, line_break=line_break) - Serializer.__init__(self, encoding=encoding, - explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) - Representer.__init__(self, default_style=default_style, - default_flow_style=default_flow_style, sort_keys=sort_keys) + def __init__( + self, + stream, + default_style=None, + default_flow_style=False, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, + 
sort_keys=True, + ): + Emitter.__init__( + self, + stream, + canonical=canonical, + indent=indent, + width=width, + allow_unicode=allow_unicode, + line_break=line_break, + ) + Serializer.__init__( + self, + encoding=encoding, + explicit_start=explicit_start, + explicit_end=explicit_end, + version=version, + tags=tags, + ) + Representer.__init__( + self, + default_style=default_style, + default_flow_style=default_flow_style, + sort_keys=sort_keys, + ) Resolver.__init__(self) + class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver): - def __init__(self, stream, - default_style=None, default_flow_style=False, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None, - encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None, sort_keys=True): - Emitter.__init__(self, stream, canonical=canonical, - indent=indent, width=width, - allow_unicode=allow_unicode, line_break=line_break) - Serializer.__init__(self, encoding=encoding, - explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) - SafeRepresenter.__init__(self, default_style=default_style, - default_flow_style=default_flow_style, sort_keys=sort_keys) + def __init__( + self, + stream, + default_style=None, + default_flow_style=False, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, + sort_keys=True, + ): + Emitter.__init__( + self, + stream, + canonical=canonical, + indent=indent, + width=width, + allow_unicode=allow_unicode, + line_break=line_break, + ) + Serializer.__init__( + self, + encoding=encoding, + explicit_start=explicit_start, + explicit_end=explicit_end, + version=version, + tags=tags, + ) + SafeRepresenter.__init__( + self, + default_style=default_style, + default_flow_style=default_flow_style, + sort_keys=sort_keys, + ) Resolver.__init__(self) + class Dumper(Emitter, Serializer, Representer, Resolver): - def __init__(self, stream, - default_style=None, default_flow_style=False, - canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None, - encoding=None, explicit_start=None, explicit_end=None, - version=None, tags=None, sort_keys=True): - Emitter.__init__(self, stream, canonical=canonical, - indent=indent, width=width, - allow_unicode=allow_unicode, line_break=line_break) - Serializer.__init__(self, encoding=encoding, - explicit_start=explicit_start, explicit_end=explicit_end, - version=version, tags=tags) - Representer.__init__(self, default_style=default_style, - default_flow_style=default_flow_style, sort_keys=sort_keys) + def __init__( + self, + stream, + default_style=None, + default_flow_style=False, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, + sort_keys=True, + ): + Emitter.__init__( + self, + stream, + canonical=canonical, + indent=indent, + width=width, + allow_unicode=allow_unicode, + line_break=line_break, + ) + Serializer.__init__( + self, + encoding=encoding, + explicit_start=explicit_start, + explicit_end=explicit_end, + version=version, + tags=tags, + ) + Representer.__init__( + self, + default_style=default_style, + default_flow_style=default_flow_style, + sort_keys=sort_keys, + ) Resolver.__init__(self) - diff --git a/metaflow/_vendor/yaml/emitter.py b/metaflow/_vendor/yaml/emitter.py index a664d011162..4c4f5df7bd5 100644 --- 
a/metaflow/_vendor/yaml/emitter.py +++ b/metaflow/_vendor/yaml/emitter.py @@ -1,4 +1,3 @@ - # Emitter expects events obeying the following grammar: # stream ::= STREAM-START document* STREAM-END # document ::= DOCUMENT-START node DOCUMENT-END @@ -6,19 +5,28 @@ # sequence ::= SEQUENCE-START node* SEQUENCE-END # mapping ::= MAPPING-START (node node)* MAPPING-END -__all__ = ['Emitter', 'EmitterError'] +__all__ = ["Emitter", "EmitterError"] from .error import YAMLError from .events import * + class EmitterError(YAMLError): pass + class ScalarAnalysis: - def __init__(self, scalar, empty, multiline, - allow_flow_plain, allow_block_plain, - allow_single_quoted, allow_double_quoted, - allow_block): + def __init__( + self, + scalar, + empty, + multiline, + allow_flow_plain, + allow_block_plain, + allow_single_quoted, + allow_double_quoted, + allow_block, + ): self.scalar = scalar self.empty = empty self.multiline = multiline @@ -28,15 +36,23 @@ def __init__(self, scalar, empty, multiline, self.allow_double_quoted = allow_double_quoted self.allow_block = allow_block + class Emitter: DEFAULT_TAG_PREFIXES = { - '!' : '!', - 'tag:yaml.org,2002:' : '!!', + "!": "!", + "tag:yaml.org,2002:": "!!", } - def __init__(self, stream, canonical=None, indent=None, width=None, - allow_unicode=None, line_break=None): + def __init__( + self, + stream, + canonical=None, + indent=None, + width=None, + allow_unicode=None, + line_break=None, + ): # The stream should have the methods `write` and possibly `flush`. self.stream = stream @@ -86,10 +102,10 @@ def __init__(self, stream, canonical=None, indent=None, width=None, if indent and 1 < indent < 10: self.best_indent = indent self.best_width = 80 - if width and width > self.best_indent*2: + if width and width > self.best_indent * 2: self.best_width = width - self.best_line_break = '\n' - if line_break in ['\r', '\n', '\r\n']: + self.best_line_break = "\n" + if line_break in ["\r", "\n", "\r\n"]: self.best_line_break = line_break # Tag prefixes. 
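The Emitter hunk above is again formatting-only, but the option clamping it reflows is easy to miss: indent is honored only when 1 < indent < 10, width only when it is greater than twice the effective indent (otherwise 80), and line_break only when it is '\r', '\n', or '\r\n'. A short sketch of how these knobs reach the Emitter through the public dump() entry point, again assuming stock PyYAML rather than the vendored module; the `data` value is illustrative:

    import yaml

    data = {"letters": ["a", "b"], "text": "word " * 30}

    # In range: indent=4 and width=40 pass the checks above and take effect.
    print(yaml.dump(data, indent=4, width=40, default_flow_style=False))

    # Out of range: indent=40 fails 1 < indent < 10, and width=5 is not
    # greater than 2 * best_indent, so the Emitter silently falls back to
    # its defaults of 2 and 80.
    print(yaml.dump(data, indent=40, width=5, default_flow_style=False))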
@@ -141,7 +157,7 @@ def need_events(self, count): level = -1 if level < 0: return False - return (len(self.events) < count+1) + return len(self.events) < count + 1 def increase_indent(self, flow=False, indentless=False): self.indents.append(self.indent) @@ -159,13 +175,12 @@ def increase_indent(self, flow=False, indentless=False): def expect_stream_start(self): if isinstance(self.event, StreamStartEvent): - if self.event.encoding and not hasattr(self.stream, 'encoding'): + if self.event.encoding and not hasattr(self.stream, "encoding"): self.encoding = self.event.encoding self.write_stream_start() self.state = self.expect_first_document_start else: - raise EmitterError("expected StreamStartEvent, but got %s" - % self.event) + raise EmitterError("expected StreamStartEvent, but got %s" % self.event) def expect_nothing(self): raise EmitterError("expected nothing, but got %s" % self.event) @@ -178,7 +193,7 @@ def expect_first_document_start(self): def expect_document_start(self, first=False): if isinstance(self.event, DocumentStartEvent): if (self.event.version or self.event.tags) and self.open_ended: - self.write_indicator('...', True) + self.write_indicator("...", True) self.write_indent() if self.event.version: version_text = self.prepare_version(self.event.version) @@ -192,36 +207,39 @@ def expect_document_start(self, first=False): handle_text = self.prepare_tag_handle(handle) prefix_text = self.prepare_tag_prefix(prefix) self.write_tag_directive(handle_text, prefix_text) - implicit = (first and not self.event.explicit and not self.canonical - and not self.event.version and not self.event.tags - and not self.check_empty_document()) + implicit = ( + first + and not self.event.explicit + and not self.canonical + and not self.event.version + and not self.event.tags + and not self.check_empty_document() + ) if not implicit: self.write_indent() - self.write_indicator('---', True) + self.write_indicator("---", True) if self.canonical: self.write_indent() self.state = self.expect_document_root elif isinstance(self.event, StreamEndEvent): if self.open_ended: - self.write_indicator('...', True) + self.write_indicator("...", True) self.write_indent() self.write_stream_end() self.state = self.expect_nothing else: - raise EmitterError("expected DocumentStartEvent, but got %s" - % self.event) + raise EmitterError("expected DocumentStartEvent, but got %s" % self.event) def expect_document_end(self): if isinstance(self.event, DocumentEndEvent): self.write_indent() if self.event.explicit: - self.write_indicator('...', True) + self.write_indicator("...", True) self.write_indent() self.flush_stream() self.state = self.expect_document_start else: - raise EmitterError("expected DocumentEndEvent, but got %s" - % self.event) + raise EmitterError("expected DocumentEndEvent, but got %s" % self.event) def expect_document_root(self): self.states.append(self.expect_document_end) @@ -229,8 +247,7 @@ def expect_document_root(self): # Node handlers. 
- def expect_node(self, root=False, sequence=False, mapping=False, - simple_key=False): + def expect_node(self, root=False, sequence=False, mapping=False, simple_key=False): self.root_context = root self.sequence_context = sequence self.mapping_context = mapping @@ -238,19 +255,27 @@ def expect_node(self, root=False, sequence=False, mapping=False, if isinstance(self.event, AliasEvent): self.expect_alias() elif isinstance(self.event, (ScalarEvent, CollectionStartEvent)): - self.process_anchor('&') + self.process_anchor("&") self.process_tag() if isinstance(self.event, ScalarEvent): self.expect_scalar() elif isinstance(self.event, SequenceStartEvent): - if self.flow_level or self.canonical or self.event.flow_style \ - or self.check_empty_sequence(): + if ( + self.flow_level + or self.canonical + or self.event.flow_style + or self.check_empty_sequence() + ): self.expect_flow_sequence() else: self.expect_block_sequence() elif isinstance(self.event, MappingStartEvent): - if self.flow_level or self.canonical or self.event.flow_style \ - or self.check_empty_mapping(): + if ( + self.flow_level + or self.canonical + or self.event.flow_style + or self.check_empty_mapping() + ): self.expect_flow_mapping() else: self.expect_block_mapping() @@ -260,7 +285,7 @@ def expect_node(self, root=False, sequence=False, mapping=False, def expect_alias(self): if self.event.anchor is None: raise EmitterError("anchor is not specified for alias") - self.process_anchor('*') + self.process_anchor("*") self.state = self.states.pop() def expect_scalar(self): @@ -272,7 +297,7 @@ def expect_scalar(self): # Flow sequence handlers. def expect_flow_sequence(self): - self.write_indicator('[', True, whitespace=True) + self.write_indicator("[", True, whitespace=True) self.flow_level += 1 self.increase_indent(flow=True) self.state = self.expect_first_flow_sequence_item @@ -281,7 +306,7 @@ def expect_first_flow_sequence_item(self): if isinstance(self.event, SequenceEndEvent): self.indent = self.indents.pop() self.flow_level -= 1 - self.write_indicator(']', False) + self.write_indicator("]", False) self.state = self.states.pop() else: if self.canonical or self.column > self.best_width: @@ -294,12 +319,12 @@ def expect_flow_sequence_item(self): self.indent = self.indents.pop() self.flow_level -= 1 if self.canonical: - self.write_indicator(',', False) + self.write_indicator(",", False) self.write_indent() - self.write_indicator(']', False) + self.write_indicator("]", False) self.state = self.states.pop() else: - self.write_indicator(',', False) + self.write_indicator(",", False) if self.canonical or self.column > self.best_width: self.write_indent() self.states.append(self.expect_flow_sequence_item) @@ -308,7 +333,7 @@ def expect_flow_sequence_item(self): # Flow mapping handlers. 
def expect_flow_mapping(self): - self.write_indicator('{', True, whitespace=True) + self.write_indicator("{", True, whitespace=True) self.flow_level += 1 self.increase_indent(flow=True) self.state = self.expect_first_flow_mapping_key @@ -317,7 +342,7 @@ def expect_first_flow_mapping_key(self): if isinstance(self.event, MappingEndEvent): self.indent = self.indents.pop() self.flow_level -= 1 - self.write_indicator('}', False) + self.write_indicator("}", False) self.state = self.states.pop() else: if self.canonical or self.column > self.best_width: @@ -326,7 +351,7 @@ def expect_first_flow_mapping_key(self): self.states.append(self.expect_flow_mapping_simple_value) self.expect_node(mapping=True, simple_key=True) else: - self.write_indicator('?', True) + self.write_indicator("?", True) self.states.append(self.expect_flow_mapping_value) self.expect_node(mapping=True) @@ -335,38 +360,38 @@ def expect_flow_mapping_key(self): self.indent = self.indents.pop() self.flow_level -= 1 if self.canonical: - self.write_indicator(',', False) + self.write_indicator(",", False) self.write_indent() - self.write_indicator('}', False) + self.write_indicator("}", False) self.state = self.states.pop() else: - self.write_indicator(',', False) + self.write_indicator(",", False) if self.canonical or self.column > self.best_width: self.write_indent() if not self.canonical and self.check_simple_key(): self.states.append(self.expect_flow_mapping_simple_value) self.expect_node(mapping=True, simple_key=True) else: - self.write_indicator('?', True) + self.write_indicator("?", True) self.states.append(self.expect_flow_mapping_value) self.expect_node(mapping=True) def expect_flow_mapping_simple_value(self): - self.write_indicator(':', False) + self.write_indicator(":", False) self.states.append(self.expect_flow_mapping_key) self.expect_node(mapping=True) def expect_flow_mapping_value(self): if self.canonical or self.column > self.best_width: self.write_indent() - self.write_indicator(':', True) + self.write_indicator(":", True) self.states.append(self.expect_flow_mapping_key) self.expect_node(mapping=True) # Block sequence handlers. def expect_block_sequence(self): - indentless = (self.mapping_context and not self.indention) + indentless = self.mapping_context and not self.indention self.increase_indent(flow=False, indentless=indentless) self.state = self.expect_first_block_sequence_item @@ -379,7 +404,7 @@ def expect_block_sequence_item(self, first=False): self.state = self.states.pop() else: self.write_indent() - self.write_indicator('-', True, indention=True) + self.write_indicator("-", True, indention=True) self.states.append(self.expect_block_sequence_item) self.expect_node(sequence=True) @@ -402,37 +427,48 @@ def expect_block_mapping_key(self, first=False): self.states.append(self.expect_block_mapping_simple_value) self.expect_node(mapping=True, simple_key=True) else: - self.write_indicator('?', True, indention=True) + self.write_indicator("?", True, indention=True) self.states.append(self.expect_block_mapping_value) self.expect_node(mapping=True) def expect_block_mapping_simple_value(self): - self.write_indicator(':', False) + self.write_indicator(":", False) self.states.append(self.expect_block_mapping_key) self.expect_node(mapping=True) def expect_block_mapping_value(self): self.write_indent() - self.write_indicator(':', True, indention=True) + self.write_indicator(":", True, indention=True) self.states.append(self.expect_block_mapping_key) self.expect_node(mapping=True) # Checkers. 
def check_empty_sequence(self): - return (isinstance(self.event, SequenceStartEvent) and self.events - and isinstance(self.events[0], SequenceEndEvent)) + return ( + isinstance(self.event, SequenceStartEvent) + and self.events + and isinstance(self.events[0], SequenceEndEvent) + ) def check_empty_mapping(self): - return (isinstance(self.event, MappingStartEvent) and self.events - and isinstance(self.events[0], MappingEndEvent)) + return ( + isinstance(self.event, MappingStartEvent) + and self.events + and isinstance(self.events[0], MappingEndEvent) + ) def check_empty_document(self): if not isinstance(self.event, DocumentStartEvent) or not self.events: return False event = self.events[0] - return (isinstance(event, ScalarEvent) and event.anchor is None - and event.tag is None and event.implicit and event.value == '') + return ( + isinstance(event, ScalarEvent) + and event.anchor is None + and event.tag is None + and event.implicit + and event.value == "" + ) def check_simple_key(self): length = 0 @@ -440,8 +476,10 @@ def check_simple_key(self): if self.prepared_anchor is None: self.prepared_anchor = self.prepare_anchor(self.event.anchor) length += len(self.prepared_anchor) - if isinstance(self.event, (ScalarEvent, CollectionStartEvent)) \ - and self.event.tag is not None: + if ( + isinstance(self.event, (ScalarEvent, CollectionStartEvent)) + and self.event.tag is not None + ): if self.prepared_tag is None: self.prepared_tag = self.prepare_tag(self.event.tag) length += len(self.prepared_tag) @@ -449,10 +487,16 @@ def check_simple_key(self): if self.analysis is None: self.analysis = self.analyze_scalar(self.event.value) length += len(self.analysis.scalar) - return (length < 128 and (isinstance(self.event, AliasEvent) - or (isinstance(self.event, ScalarEvent) - and not self.analysis.empty and not self.analysis.multiline) - or self.check_empty_sequence() or self.check_empty_mapping())) + return length < 128 and ( + isinstance(self.event, AliasEvent) + or ( + isinstance(self.event, ScalarEvent) + and not self.analysis.empty + and not self.analysis.multiline + ) + or self.check_empty_sequence() + or self.check_empty_mapping() + ) # Anchor, Tag, and Scalar processors. @@ -463,7 +507,7 @@ def process_anchor(self, indicator): if self.prepared_anchor is None: self.prepared_anchor = self.prepare_anchor(self.event.anchor) if self.prepared_anchor: - self.write_indicator(indicator+self.prepared_anchor, True) + self.write_indicator(indicator + self.prepared_anchor, True) self.prepared_anchor = None def process_tag(self): @@ -471,13 +515,14 @@ def process_tag(self): if isinstance(self.event, ScalarEvent): if self.style is None: self.style = self.choose_scalar_style() - if ((not self.canonical or tag is None) and - ((self.style == '' and self.event.implicit[0]) - or (self.style != '' and self.event.implicit[1]))): + if (not self.canonical or tag is None) and ( + (self.style == "" and self.event.implicit[0]) + or (self.style != "" and self.event.implicit[1]) + ): self.prepared_tag = None return if self.event.implicit[0] and tag is None: - tag = '!' + tag = "!" 
self.prepared_tag = None else: if (not self.canonical or tag is None) and self.event.implicit: @@ -497,19 +542,27 @@ def choose_scalar_style(self): if self.event.style == '"' or self.canonical: return '"' if not self.event.style and self.event.implicit[0]: - if (not (self.simple_key_context and - (self.analysis.empty or self.analysis.multiline)) - and (self.flow_level and self.analysis.allow_flow_plain - or (not self.flow_level and self.analysis.allow_block_plain))): - return '' - if self.event.style and self.event.style in '|>': - if (not self.flow_level and not self.simple_key_context - and self.analysis.allow_block): + if not ( + self.simple_key_context + and (self.analysis.empty or self.analysis.multiline) + ) and ( + self.flow_level + and self.analysis.allow_flow_plain + or (not self.flow_level and self.analysis.allow_block_plain) + ): + return "" + if self.event.style and self.event.style in "|>": + if ( + not self.flow_level + and not self.simple_key_context + and self.analysis.allow_block + ): return self.event.style - if not self.event.style or self.event.style == '\'': - if (self.analysis.allow_single_quoted and - not (self.simple_key_context and self.analysis.multiline)): - return '\'' + if not self.event.style or self.event.style == "'": + if self.analysis.allow_single_quoted and not ( + self.simple_key_context and self.analysis.multiline + ): + return "'" return '"' def process_scalar(self): @@ -517,17 +570,17 @@ def process_scalar(self): self.analysis = self.analyze_scalar(self.event.value) if self.style is None: self.style = self.choose_scalar_style() - split = (not self.simple_key_context) - #if self.analysis.multiline and split \ + split = not self.simple_key_context + # if self.analysis.multiline and split \ # and (not self.style or self.style in '\'\"'): # self.write_indent() if self.style == '"': self.write_double_quoted(self.analysis.scalar, split) - elif self.style == '\'': + elif self.style == "'": self.write_single_quoted(self.analysis.scalar, split) - elif self.style == '>': + elif self.style == ">": self.write_folded(self.analysis.scalar) - elif self.style == '|': + elif self.style == "|": self.write_literal(self.analysis.scalar) else: self.write_plain(self.analysis.scalar, split) @@ -540,18 +593,20 @@ def prepare_version(self, version): major, minor = version if major != 1: raise EmitterError("unsupported YAML version: %d.%d" % (major, minor)) - return '%d.%d' % (major, minor) + return "%d.%d" % (major, minor) def prepare_tag_handle(self, handle): if not handle: raise EmitterError("tag handle must not be empty") - if handle[0] != '!' or handle[-1] != '!': + if handle[0] != "!" 
or handle[-1] != "!": raise EmitterError("tag handle must start and end with '!': %r" % handle) for ch in handle[1:-1]: - if not ('0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ - or ch in '-_'): - raise EmitterError("invalid character %r in the tag handle: %r" - % (ch, handle)) + if not ( + "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_" + ): + raise EmitterError( + "invalid character %r in the tag handle: %r" % (ch, handle) + ) return handle def prepare_tag_prefix(self, prefix): @@ -559,78 +614,93 @@ def prepare_tag_prefix(self, prefix): raise EmitterError("tag prefix must not be empty") chunks = [] start = end = 0 - if prefix[0] == '!': + if prefix[0] == "!": end = 1 while end < len(prefix): ch = prefix[end] - if '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ - or ch in '-;/?!:@&=+$,_.~*\'()[]': + if ( + "0" <= ch <= "9" + or "A" <= ch <= "Z" + or "a" <= ch <= "z" + or ch in "-;/?!:@&=+$,_.~*'()[]" + ): end += 1 else: if start < end: chunks.append(prefix[start:end]) - start = end = end+1 - data = ch.encode('utf-8') + start = end = end + 1 + data = ch.encode("utf-8") for ch in data: - chunks.append('%%%02X' % ord(ch)) + chunks.append("%%%02X" % ord(ch)) if start < end: chunks.append(prefix[start:end]) - return ''.join(chunks) + return "".join(chunks) def prepare_tag(self, tag): if not tag: raise EmitterError("tag must not be empty") - if tag == '!': + if tag == "!": return tag handle = None suffix = tag prefixes = sorted(self.tag_prefixes.keys()) for prefix in prefixes: - if tag.startswith(prefix) \ - and (prefix == '!' or len(prefix) < len(tag)): + if tag.startswith(prefix) and (prefix == "!" or len(prefix) < len(tag)): handle = self.tag_prefixes[prefix] - suffix = tag[len(prefix):] + suffix = tag[len(prefix) :] chunks = [] start = end = 0 while end < len(suffix): ch = suffix[end] - if '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ - or ch in '-;/?:@&=+$,_.~*\'()[]' \ - or (ch == '!' and handle != '!'): + if ( + "0" <= ch <= "9" + or "A" <= ch <= "Z" + or "a" <= ch <= "z" + or ch in "-;/?:@&=+$,_.~*'()[]" + or (ch == "!" and handle != "!") + ): end += 1 else: if start < end: chunks.append(suffix[start:end]) - start = end = end+1 - data = ch.encode('utf-8') + start = end = end + 1 + data = ch.encode("utf-8") for ch in data: - chunks.append('%%%02X' % ch) + chunks.append("%%%02X" % ch) if start < end: chunks.append(suffix[start:end]) - suffix_text = ''.join(chunks) + suffix_text = "".join(chunks) if handle: - return '%s%s' % (handle, suffix_text) + return "%s%s" % (handle, suffix_text) else: - return '!<%s>' % suffix_text + return "!<%s>" % suffix_text def prepare_anchor(self, anchor): if not anchor: raise EmitterError("anchor must not be empty") for ch in anchor: - if not ('0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ - or ch in '-_'): - raise EmitterError("invalid character %r in the anchor: %r" - % (ch, anchor)) + if not ( + "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_" + ): + raise EmitterError( + "invalid character %r in the anchor: %r" % (ch, anchor) + ) return anchor def analyze_scalar(self, scalar): # Empty scalar is a special case. 
if not scalar: - return ScalarAnalysis(scalar=scalar, empty=True, multiline=False, - allow_flow_plain=False, allow_block_plain=True, - allow_single_quoted=True, allow_double_quoted=True, - allow_block=False) + return ScalarAnalysis( + scalar=scalar, + empty=True, + multiline=False, + allow_flow_plain=False, + allow_block_plain=True, + allow_single_quoted=True, + allow_double_quoted=True, + allow_block=False, + ) # Indicators and special characters. block_indicators = False @@ -647,7 +717,7 @@ def analyze_scalar(self, scalar): space_break = False # Check document indicators. - if scalar.startswith('---') or scalar.startswith('...'): + if scalar.startswith("---") or scalar.startswith("..."): block_indicators = True flow_indicators = True @@ -655,8 +725,9 @@ def analyze_scalar(self, scalar): preceded_by_whitespace = True # Last character or followed by a whitespace. - followed_by_whitespace = (len(scalar) == 1 or - scalar[1] in '\0 \t\r\n\x85\u2028\u2029') + followed_by_whitespace = ( + len(scalar) == 1 or scalar[1] in "\0 \t\r\n\x85\u2028\u2029" + ) # The previous character is a space. previous_space = False @@ -671,35 +742,38 @@ def analyze_scalar(self, scalar): # Check for indicators. if index == 0: # Leading indicators are special characters. - if ch in '#,[]{}&*!|>\'\"%@`': + if ch in "#,[]{}&*!|>'\"%@`": flow_indicators = True block_indicators = True - if ch in '?:': + if ch in "?:": flow_indicators = True if followed_by_whitespace: block_indicators = True - if ch == '-' and followed_by_whitespace: + if ch == "-" and followed_by_whitespace: flow_indicators = True block_indicators = True else: # Some indicators cannot appear within a scalar as well. - if ch in ',?[]{}': + if ch in ",?[]{}": flow_indicators = True - if ch == ':': + if ch == ":": flow_indicators = True if followed_by_whitespace: block_indicators = True - if ch == '#' and preceded_by_whitespace: + if ch == "#" and preceded_by_whitespace: flow_indicators = True block_indicators = True # Check for line breaks, special, and unicode characters. - if ch in '\n\x85\u2028\u2029': + if ch in "\n\x85\u2028\u2029": line_breaks = True - if not (ch == '\n' or '\x20' <= ch <= '\x7E'): - if (ch == '\x85' or '\xA0' <= ch <= '\uD7FF' - or '\uE000' <= ch <= '\uFFFD' - or '\U00010000' <= ch < '\U0010ffff') and ch != '\uFEFF': + if not (ch == "\n" or "\x20" <= ch <= "\x7e"): + if ( + ch == "\x85" + or "\xa0" <= ch <= "\ud7ff" + or "\ue000" <= ch <= "\ufffd" + or "\U00010000" <= ch < "\U0010ffff" + ) and ch != "\ufeff": unicode_characters = True if not self.allow_unicode: special_characters = True @@ -707,19 +781,19 @@ def analyze_scalar(self, scalar): special_characters = True # Detect important whitespace combinations. - if ch == ' ': + if ch == " ": if index == 0: leading_space = True - if index == len(scalar)-1: + if index == len(scalar) - 1: trailing_space = True if previous_break: break_space = True previous_space = True previous_break = False - elif ch in '\n\x85\u2028\u2029': + elif ch in "\n\x85\u2028\u2029": if index == 0: leading_break = True - if index == len(scalar)-1: + if index == len(scalar) - 1: trailing_break = True if previous_space: space_break = True @@ -731,9 +805,11 @@ def analyze_scalar(self, scalar): # Prepare for the next character. 
index += 1 - preceded_by_whitespace = (ch in '\0 \t\r\n\x85\u2028\u2029') - followed_by_whitespace = (index+1 >= len(scalar) or - scalar[index+1] in '\0 \t\r\n\x85\u2028\u2029') + preceded_by_whitespace = ch in "\0 \t\r\n\x85\u2028\u2029" + followed_by_whitespace = ( + index + 1 >= len(scalar) + or scalar[index + 1] in "\0 \t\r\n\x85\u2028\u2029" + ) # Let's decide what styles are allowed. allow_flow_plain = True @@ -743,8 +819,7 @@ def analyze_scalar(self, scalar): allow_block = True # Leading and trailing whitespaces are bad for plain scalars. - if (leading_space or leading_break - or trailing_space or trailing_break): + if leading_space or leading_break or trailing_space or trailing_break: allow_flow_plain = allow_block_plain = False # We do not permit trailing spaces for block scalars. @@ -759,8 +834,9 @@ def analyze_scalar(self, scalar): # Spaces followed by breaks, as well as special character are only # allowed for double quoted scalars. if space_break or special_characters: - allow_flow_plain = allow_block_plain = \ - allow_single_quoted = allow_block = False + allow_flow_plain = allow_block_plain = allow_single_quoted = allow_block = ( + False + ) # Although the plain scalar writer supports breaks, we never emit # multiline plain scalars. @@ -775,34 +851,38 @@ def analyze_scalar(self, scalar): if block_indicators: allow_block_plain = False - return ScalarAnalysis(scalar=scalar, - empty=False, multiline=line_breaks, - allow_flow_plain=allow_flow_plain, - allow_block_plain=allow_block_plain, - allow_single_quoted=allow_single_quoted, - allow_double_quoted=allow_double_quoted, - allow_block=allow_block) + return ScalarAnalysis( + scalar=scalar, + empty=False, + multiline=line_breaks, + allow_flow_plain=allow_flow_plain, + allow_block_plain=allow_block_plain, + allow_single_quoted=allow_single_quoted, + allow_double_quoted=allow_double_quoted, + allow_block=allow_block, + ) # Writers. def flush_stream(self): - if hasattr(self.stream, 'flush'): + if hasattr(self.stream, "flush"): self.stream.flush() def write_stream_start(self): # Write BOM if needed. 
- if self.encoding and self.encoding.startswith('utf-16'): - self.stream.write('\uFEFF'.encode(self.encoding)) + if self.encoding and self.encoding.startswith("utf-16"): + self.stream.write("\ufeff".encode(self.encoding)) def write_stream_end(self): self.flush_stream() - def write_indicator(self, indicator, need_whitespace, - whitespace=False, indention=False): + def write_indicator( + self, indicator, need_whitespace, whitespace=False, indention=False + ): if self.whitespace or not need_whitespace: data = indicator else: - data = ' '+indicator + data = " " + indicator self.whitespace = whitespace self.indention = self.indention and indention self.column += len(data) @@ -813,12 +893,15 @@ def write_indicator(self, indicator, need_whitespace, def write_indent(self): indent = self.indent or 0 - if not self.indention or self.column > indent \ - or (self.column == indent and not self.whitespace): + if ( + not self.indention + or self.column > indent + or (self.column == indent and not self.whitespace) + ): self.write_line_break() if self.column < indent: self.whitespace = True - data = ' '*(indent-self.column) + data = " " * (indent - self.column) self.column = indent if self.encoding: data = data.encode(self.encoding) @@ -836,14 +919,14 @@ def write_line_break(self, data=None): self.stream.write(data) def write_version_directive(self, version_text): - data = '%%YAML %s' % version_text + data = "%%YAML %s" % version_text if self.encoding: data = data.encode(self.encoding) self.stream.write(data) self.write_line_break() def write_tag_directive(self, handle_text, prefix_text): - data = '%%TAG %s %s' % (handle_text, prefix_text) + data = "%%TAG %s %s" % (handle_text, prefix_text) if self.encoding: data = data.encode(self.encoding) self.stream.write(data) @@ -852,7 +935,7 @@ def write_tag_directive(self, handle_text, prefix_text): # Scalar streams. 
def write_single_quoted(self, text, split=True): - self.write_indicator('\'', True) + self.write_indicator("'", True) spaces = False breaks = False start = end = 0 @@ -861,9 +944,14 @@ def write_single_quoted(self, text, split=True): if end < len(text): ch = text[end] if spaces: - if ch is None or ch != ' ': - if start+1 == end and self.column > self.best_width and split \ - and start != 0 and end != len(text): + if ch is None or ch != " ": + if ( + start + 1 == end + and self.column > self.best_width + and split + and start != 0 + and end != len(text) + ): self.write_indent() else: data = text[start:end] @@ -873,18 +961,18 @@ def write_single_quoted(self, text, split=True): self.stream.write(data) start = end elif breaks: - if ch is None or ch not in '\n\x85\u2028\u2029': - if text[start] == '\n': + if ch is None or ch not in "\n\x85\u2028\u2029": + if text[start] == "\n": self.write_line_break() for br in text[start:end]: - if br == '\n': + if br == "\n": self.write_line_break() else: self.write_line_break(br) self.write_indent() start = end else: - if ch is None or ch in ' \n\x85\u2028\u2029' or ch == '\'': + if ch is None or ch in " \n\x85\u2028\u2029" or ch == "'": if start < end: data = text[start:end] self.column += len(data) @@ -892,35 +980,35 @@ def write_single_quoted(self, text, split=True): data = data.encode(self.encoding) self.stream.write(data) start = end - if ch == '\'': - data = '\'\'' + if ch == "'": + data = "''" self.column += 2 if self.encoding: data = data.encode(self.encoding) self.stream.write(data) start = end + 1 if ch is not None: - spaces = (ch == ' ') - breaks = (ch in '\n\x85\u2028\u2029') + spaces = ch == " " + breaks = ch in "\n\x85\u2028\u2029" end += 1 - self.write_indicator('\'', False) + self.write_indicator("'", False) ESCAPE_REPLACEMENTS = { - '\0': '0', - '\x07': 'a', - '\x08': 'b', - '\x09': 't', - '\x0A': 'n', - '\x0B': 'v', - '\x0C': 'f', - '\x0D': 'r', - '\x1B': 'e', - '\"': '\"', - '\\': '\\', - '\x85': 'N', - '\xA0': '_', - '\u2028': 'L', - '\u2029': 'P', + "\0": "0", + "\x07": "a", + "\x08": "b", + "\x09": "t", + "\x0a": "n", + "\x0b": "v", + "\x0c": "f", + "\x0d": "r", + "\x1b": "e", + '"': '"', + "\\": "\\", + "\x85": "N", + "\xa0": "_", + "\u2028": "L", + "\u2029": "P", } def write_double_quoted(self, text, split=True): @@ -930,11 +1018,17 @@ def write_double_quoted(self, text, split=True): ch = None if end < len(text): ch = text[end] - if ch is None or ch in '"\\\x85\u2028\u2029\uFEFF' \ - or not ('\x20' <= ch <= '\x7E' - or (self.allow_unicode - and ('\xA0' <= ch <= '\uD7FF' - or '\uE000' <= ch <= '\uFFFD'))): + if ( + ch is None + or ch in '"\\\x85\u2028\u2029\ufeff' + or not ( + "\x20" <= ch <= "\x7e" + or ( + self.allow_unicode + and ("\xa0" <= ch <= "\ud7ff" or "\ue000" <= ch <= "\ufffd") + ) + ) + ): if start < end: data = text[start:end] self.column += len(data) @@ -944,21 +1038,25 @@ def write_double_quoted(self, text, split=True): start = end if ch is not None: if ch in self.ESCAPE_REPLACEMENTS: - data = '\\'+self.ESCAPE_REPLACEMENTS[ch] - elif ch <= '\xFF': - data = '\\x%02X' % ord(ch) - elif ch <= '\uFFFF': - data = '\\u%04X' % ord(ch) + data = "\\" + self.ESCAPE_REPLACEMENTS[ch] + elif ch <= "\xff": + data = "\\x%02X" % ord(ch) + elif ch <= "\uffff": + data = "\\u%04X" % ord(ch) else: - data = '\\U%08X' % ord(ch) + data = "\\U%08X" % ord(ch) self.column += len(data) if self.encoding: data = data.encode(self.encoding) self.stream.write(data) - start = end+1 - if 0 < end < len(text)-1 and (ch == ' ' or start >= end) \ - and 
self.column+(end-start) > self.best_width and split: - data = text[start:end]+'\\' + start = end + 1 + if ( + 0 < end < len(text) - 1 + and (ch == " " or start >= end) + and self.column + (end - start) > self.best_width + and split + ): + data = text[start:end] + "\\" if start < end: start = end self.column += len(data) @@ -968,8 +1066,8 @@ def write_double_quoted(self, text, split=True): self.write_indent() self.whitespace = False self.indention = False - if text[start] == ' ': - data = '\\' + if text[start] == " ": + data = "\\" self.column += len(data) if self.encoding: data = data.encode(self.encoding) @@ -978,20 +1076,20 @@ def write_double_quoted(self, text, split=True): self.write_indicator('"', False) def determine_block_hints(self, text): - hints = '' + hints = "" if text: - if text[0] in ' \n\x85\u2028\u2029': + if text[0] in " \n\x85\u2028\u2029": hints += str(self.best_indent) - if text[-1] not in '\n\x85\u2028\u2029': - hints += '-' - elif len(text) == 1 or text[-2] in '\n\x85\u2028\u2029': - hints += '+' + if text[-1] not in "\n\x85\u2028\u2029": + hints += "-" + elif len(text) == 1 or text[-2] in "\n\x85\u2028\u2029": + hints += "+" return hints def write_folded(self, text): hints = self.determine_block_hints(text) - self.write_indicator('>'+hints, True) - if hints[-1:] == '+': + self.write_indicator(">" + hints, True) + if hints[-1:] == "+": self.open_ended = True self.write_line_break() leading_space = True @@ -1003,13 +1101,17 @@ def write_folded(self, text): if end < len(text): ch = text[end] if breaks: - if ch is None or ch not in '\n\x85\u2028\u2029': - if not leading_space and ch is not None and ch != ' ' \ - and text[start] == '\n': + if ch is None or ch not in "\n\x85\u2028\u2029": + if ( + not leading_space + and ch is not None + and ch != " " + and text[start] == "\n" + ): self.write_line_break() - leading_space = (ch == ' ') + leading_space = ch == " " for br in text[start:end]: - if br == '\n': + if br == "\n": self.write_line_break() else: self.write_line_break(br) @@ -1017,8 +1119,8 @@ def write_folded(self, text): self.write_indent() start = end elif spaces: - if ch != ' ': - if start+1 == end and self.column > self.best_width: + if ch != " ": + if start + 1 == end and self.column > self.best_width: self.write_indent() else: data = text[start:end] @@ -1028,7 +1130,7 @@ def write_folded(self, text): self.stream.write(data) start = end else: - if ch is None or ch in ' \n\x85\u2028\u2029': + if ch is None or ch in " \n\x85\u2028\u2029": data = text[start:end] self.column += len(data) if self.encoding: @@ -1038,14 +1140,14 @@ def write_folded(self, text): self.write_line_break() start = end if ch is not None: - breaks = (ch in '\n\x85\u2028\u2029') - spaces = (ch == ' ') + breaks = ch in "\n\x85\u2028\u2029" + spaces = ch == " " end += 1 def write_literal(self, text): hints = self.determine_block_hints(text) - self.write_indicator('|'+hints, True) - if hints[-1:] == '+': + self.write_indicator("|" + hints, True) + if hints[-1:] == "+": self.open_ended = True self.write_line_break() breaks = True @@ -1055,9 +1157,9 @@ def write_literal(self, text): if end < len(text): ch = text[end] if breaks: - if ch is None or ch not in '\n\x85\u2028\u2029': + if ch is None or ch not in "\n\x85\u2028\u2029": for br in text[start:end]: - if br == '\n': + if br == "\n": self.write_line_break() else: self.write_line_break(br) @@ -1065,7 +1167,7 @@ def write_literal(self, text): self.write_indent() start = end else: - if ch is None or ch in '\n\x85\u2028\u2029': + if ch is None or 
ch in "\n\x85\u2028\u2029": data = text[start:end] if self.encoding: data = data.encode(self.encoding) @@ -1074,7 +1176,7 @@ def write_literal(self, text): self.write_line_break() start = end if ch is not None: - breaks = (ch in '\n\x85\u2028\u2029') + breaks = ch in "\n\x85\u2028\u2029" end += 1 def write_plain(self, text, split=True): @@ -1083,7 +1185,7 @@ def write_plain(self, text, split=True): if not text: return if not self.whitespace: - data = ' ' + data = " " self.column += len(data) if self.encoding: data = data.encode(self.encoding) @@ -1098,8 +1200,8 @@ def write_plain(self, text, split=True): if end < len(text): ch = text[end] if spaces: - if ch != ' ': - if start+1 == end and self.column > self.best_width and split: + if ch != " ": + if start + 1 == end and self.column > self.best_width and split: self.write_indent() self.whitespace = False self.indention = False @@ -1111,11 +1213,11 @@ def write_plain(self, text, split=True): self.stream.write(data) start = end elif breaks: - if ch not in '\n\x85\u2028\u2029': - if text[start] == '\n': + if ch not in "\n\x85\u2028\u2029": + if text[start] == "\n": self.write_line_break() for br in text[start:end]: - if br == '\n': + if br == "\n": self.write_line_break() else: self.write_line_break(br) @@ -1124,7 +1226,7 @@ def write_plain(self, text, split=True): self.indention = False start = end else: - if ch is None or ch in ' \n\x85\u2028\u2029': + if ch is None or ch in " \n\x85\u2028\u2029": data = text[start:end] self.column += len(data) if self.encoding: @@ -1132,6 +1234,6 @@ def write_plain(self, text, split=True): self.stream.write(data) start = end if ch is not None: - spaces = (ch == ' ') - breaks = (ch in '\n\x85\u2028\u2029') + spaces = ch == " " + breaks = ch in "\n\x85\u2028\u2029" end += 1 diff --git a/metaflow/_vendor/yaml/error.py b/metaflow/_vendor/yaml/error.py index b796b4dc519..2b84f4c76aa 100644 --- a/metaflow/_vendor/yaml/error.py +++ b/metaflow/_vendor/yaml/error.py @@ -1,5 +1,5 @@ +__all__ = ["Mark", "YAMLError", "MarkedYAMLError"] -__all__ = ['Mark', 'YAMLError', 'MarkedYAMLError'] class Mark: @@ -14,41 +14,61 @@ def __init__(self, name, index, line, column, buffer, pointer): def get_snippet(self, indent=4, max_length=75): if self.buffer is None: return None - head = '' + head = "" start = self.pointer - while start > 0 and self.buffer[start-1] not in '\0\r\n\x85\u2028\u2029': + while start > 0 and self.buffer[start - 1] not in "\0\r\n\x85\u2028\u2029": start -= 1 - if self.pointer-start > max_length/2-1: - head = ' ... ' + if self.pointer - start > max_length / 2 - 1: + head = " ... " start += 5 break - tail = '' + tail = "" end = self.pointer - while end < len(self.buffer) and self.buffer[end] not in '\0\r\n\x85\u2028\u2029': + while ( + end < len(self.buffer) and self.buffer[end] not in "\0\r\n\x85\u2028\u2029" + ): end += 1 - if end-self.pointer > max_length/2-1: - tail = ' ... ' + if end - self.pointer > max_length / 2 - 1: + tail = " ... 
" end -= 5 break snippet = self.buffer[start:end] - return ' '*indent + head + snippet + tail + '\n' \ - + ' '*(indent+self.pointer-start+len(head)) + '^' + return ( + " " * indent + + head + + snippet + + tail + + "\n" + + " " * (indent + self.pointer - start + len(head)) + + "^" + ) def __str__(self): snippet = self.get_snippet() - where = " in \"%s\", line %d, column %d" \ - % (self.name, self.line+1, self.column+1) + where = ' in "%s", line %d, column %d' % ( + self.name, + self.line + 1, + self.column + 1, + ) if snippet is not None: - where += ":\n"+snippet + where += ":\n" + snippet return where + class YAMLError(Exception): pass + class MarkedYAMLError(YAMLError): - def __init__(self, context=None, context_mark=None, - problem=None, problem_mark=None, note=None): + def __init__( + self, + context=None, + context_mark=None, + problem=None, + problem_mark=None, + note=None, + ): self.context = context self.context_mark = context_mark self.problem = problem @@ -59,11 +79,13 @@ def __str__(self): lines = [] if self.context is not None: lines.append(self.context) - if self.context_mark is not None \ - and (self.problem is None or self.problem_mark is None - or self.context_mark.name != self.problem_mark.name - or self.context_mark.line != self.problem_mark.line - or self.context_mark.column != self.problem_mark.column): + if self.context_mark is not None and ( + self.problem is None + or self.problem_mark is None + or self.context_mark.name != self.problem_mark.name + or self.context_mark.line != self.problem_mark.line + or self.context_mark.column != self.problem_mark.column + ): lines.append(str(self.context_mark)) if self.problem is not None: lines.append(self.problem) @@ -71,5 +93,4 @@ def __str__(self): lines.append(str(self.problem_mark)) if self.note is not None: lines.append(self.note) - return '\n'.join(lines) - + return "\n".join(lines) diff --git a/metaflow/_vendor/yaml/events.py b/metaflow/_vendor/yaml/events.py index f79ad389cb6..b2e31472e53 100644 --- a/metaflow/_vendor/yaml/events.py +++ b/metaflow/_vendor/yaml/events.py @@ -1,16 +1,20 @@ - # Abstract classes. + class Event(object): def __init__(self, start_mark=None, end_mark=None): self.start_mark = start_mark self.end_mark = end_mark + def __repr__(self): - attributes = [key for key in ['anchor', 'tag', 'implicit', 'value'] - if hasattr(self, key)] - arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) - for key in attributes]) - return '%s(%s)' % (self.__class__.__name__, arguments) + attributes = [ + key for key in ["anchor", "tag", "implicit", "value"] if hasattr(self, key) + ] + arguments = ", ".join( + ["%s=%r" % (key, getattr(self, key)) for key in attributes] + ) + return "%s(%s)" % (self.__class__.__name__, arguments) + class NodeEvent(Event): def __init__(self, anchor, start_mark=None, end_mark=None): @@ -18,9 +22,11 @@ def __init__(self, anchor, start_mark=None, end_mark=None): self.start_mark = start_mark self.end_mark = end_mark + class CollectionStartEvent(NodeEvent): - def __init__(self, anchor, tag, implicit, start_mark=None, end_mark=None, - flow_style=None): + def __init__( + self, anchor, tag, implicit, start_mark=None, end_mark=None, flow_style=None + ): self.anchor = anchor self.tag = tag self.implicit = implicit @@ -28,42 +34,51 @@ def __init__(self, anchor, tag, implicit, start_mark=None, end_mark=None, self.end_mark = end_mark self.flow_style = flow_style + class CollectionEndEvent(Event): pass + # Implementations. 
+ class StreamStartEvent(Event): def __init__(self, start_mark=None, end_mark=None, encoding=None): self.start_mark = start_mark self.end_mark = end_mark self.encoding = encoding + class StreamEndEvent(Event): pass + class DocumentStartEvent(Event): - def __init__(self, start_mark=None, end_mark=None, - explicit=None, version=None, tags=None): + def __init__( + self, start_mark=None, end_mark=None, explicit=None, version=None, tags=None + ): self.start_mark = start_mark self.end_mark = end_mark self.explicit = explicit self.version = version self.tags = tags + class DocumentEndEvent(Event): - def __init__(self, start_mark=None, end_mark=None, - explicit=None): + def __init__(self, start_mark=None, end_mark=None, explicit=None): self.start_mark = start_mark self.end_mark = end_mark self.explicit = explicit + class AliasEvent(NodeEvent): pass + class ScalarEvent(NodeEvent): - def __init__(self, anchor, tag, implicit, value, - start_mark=None, end_mark=None, style=None): + def __init__( + self, anchor, tag, implicit, value, start_mark=None, end_mark=None, style=None + ): self.anchor = anchor self.tag = tag self.implicit = implicit @@ -72,15 +87,18 @@ def __init__(self, anchor, tag, implicit, value, self.end_mark = end_mark self.style = style + class SequenceStartEvent(CollectionStartEvent): pass + class SequenceEndEvent(CollectionEndEvent): pass + class MappingStartEvent(CollectionStartEvent): pass + class MappingEndEvent(CollectionEndEvent): pass - diff --git a/metaflow/_vendor/yaml/loader.py b/metaflow/_vendor/yaml/loader.py index e90c11224c3..7200fcbc1ae 100644 --- a/metaflow/_vendor/yaml/loader.py +++ b/metaflow/_vendor/yaml/loader.py @@ -1,5 +1,4 @@ - -__all__ = ['BaseLoader', 'FullLoader', 'SafeLoader', 'Loader', 'UnsafeLoader'] +__all__ = ["BaseLoader", "FullLoader", "SafeLoader", "Loader", "UnsafeLoader"] from .reader import * from .scanner import * @@ -8,6 +7,7 @@ from .constructor import * from .resolver import * + class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, BaseResolver): def __init__(self, stream): @@ -18,6 +18,7 @@ def __init__(self, stream): BaseConstructor.__init__(self) BaseResolver.__init__(self) + class FullLoader(Reader, Scanner, Parser, Composer, FullConstructor, Resolver): def __init__(self, stream): @@ -28,6 +29,7 @@ def __init__(self, stream): FullConstructor.__init__(self) Resolver.__init__(self) + class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, Resolver): def __init__(self, stream): @@ -38,6 +40,7 @@ def __init__(self, stream): SafeConstructor.__init__(self) Resolver.__init__(self) + class Loader(Reader, Scanner, Parser, Composer, Constructor, Resolver): def __init__(self, stream): @@ -48,6 +51,7 @@ def __init__(self, stream): Constructor.__init__(self) Resolver.__init__(self) + # UnsafeLoader is the same as Loader (which is and was always unsafe on # untrusted input). Use of either Loader or UnsafeLoader should be rare, since # FullLoad should be able to load almost all YAML safely. 
Loader is left intact diff --git a/metaflow/_vendor/yaml/nodes.py b/metaflow/_vendor/yaml/nodes.py index c4f070c41e1..ad8a4bb9b74 100644 --- a/metaflow/_vendor/yaml/nodes.py +++ b/metaflow/_vendor/yaml/nodes.py @@ -1,49 +1,51 @@ - class Node(object): def __init__(self, tag, value, start_mark, end_mark): self.tag = tag self.value = value self.start_mark = start_mark self.end_mark = end_mark + def __repr__(self): value = self.value - #if isinstance(value, list): + # if isinstance(value, list): # if len(value) == 0: # value = '' # elif len(value) == 1: # value = '<1 item>' # else: # value = '<%d items>' % len(value) - #else: + # else: # if len(value) > 75: # value = repr(value[:70]+u' ... ') # else: # value = repr(value) value = repr(value) - return '%s(tag=%r, value=%s)' % (self.__class__.__name__, self.tag, value) + return "%s(tag=%r, value=%s)" % (self.__class__.__name__, self.tag, value) + class ScalarNode(Node): - id = 'scalar' - def __init__(self, tag, value, - start_mark=None, end_mark=None, style=None): + id = "scalar" + + def __init__(self, tag, value, start_mark=None, end_mark=None, style=None): self.tag = tag self.value = value self.start_mark = start_mark self.end_mark = end_mark self.style = style + class CollectionNode(Node): - def __init__(self, tag, value, - start_mark=None, end_mark=None, flow_style=None): + def __init__(self, tag, value, start_mark=None, end_mark=None, flow_style=None): self.tag = tag self.value = value self.start_mark = start_mark self.end_mark = end_mark self.flow_style = flow_style + class SequenceNode(CollectionNode): - id = 'sequence' + id = "sequence" -class MappingNode(CollectionNode): - id = 'mapping' +class MappingNode(CollectionNode): + id = "mapping" diff --git a/metaflow/_vendor/yaml/parser.py b/metaflow/_vendor/yaml/parser.py index 13a5995d292..9850645cb55 100644 --- a/metaflow/_vendor/yaml/parser.py +++ b/metaflow/_vendor/yaml/parser.py @@ -1,4 +1,3 @@ - # The following YAML grammar is LL(1) and is parsed by a recursive descent # parser. # @@ -59,23 +58,25 @@ # flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY } # flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY } -__all__ = ['Parser', 'ParserError'] +__all__ = ["Parser", "ParserError"] from .error import MarkedYAMLError from .tokens import * from .events import * from .scanner import * + class ParserError(MarkedYAMLError): pass + class Parser: # Since writing a recursive-descendant parser is a straightforward task, we # do not give many comments here. DEFAULT_TAGS = { - '!': '!', - '!!': 'tag:yaml.org,2002:', + "!": "!", + "!!": "tag:yaml.org,2002:", } def __init__(self): @@ -128,8 +129,9 @@ def parse_stream_start(self): # Parse the stream start. token = self.get_token() - event = StreamStartEvent(token.start_mark, token.end_mark, - encoding=token.encoding) + event = StreamStartEvent( + token.start_mark, token.end_mark, encoding=token.encoding + ) # Prepare the next state. self.state = self.parse_implicit_document_start @@ -139,13 +141,11 @@ def parse_stream_start(self): def parse_implicit_document_start(self): # Parse an implicit document. 
- if not self.check_token(DirectiveToken, DocumentStartToken, - StreamEndToken): + if not self.check_token(DirectiveToken, DocumentStartToken, StreamEndToken): self.tag_handles = self.DEFAULT_TAGS token = self.peek_token() start_mark = end_mark = token.start_mark - event = DocumentStartEvent(start_mark, end_mark, - explicit=False) + event = DocumentStartEvent(start_mark, end_mark, explicit=False) # Prepare the next state. self.states.append(self.parse_document_end) @@ -168,14 +168,17 @@ def parse_document_start(self): start_mark = token.start_mark version, tags = self.process_directives() if not self.check_token(DocumentStartToken): - raise ParserError(None, None, - "expected '<document start>', but found %r" - % self.peek_token().id, - self.peek_token().start_mark) + raise ParserError( + None, + None, + "expected '<document start>', but found %r" % self.peek_token().id, + self.peek_token().start_mark, + ) token = self.get_token() end_mark = token.end_mark - event = DocumentStartEvent(start_mark, end_mark, - explicit=True, version=version, tags=tags) + event = DocumentStartEvent( + start_mark, end_mark, explicit=True, version=version, tags=tags + ) self.states.append(self.parse_document_end) self.state = self.parse_document_content else: @@ -197,8 +200,7 @@ def parse_document_end(self): token = self.get_token() end_mark = token.end_mark explicit = True - event = DocumentEndEvent(start_mark, end_mark, - explicit=explicit) + event = DocumentEndEvent(start_mark, end_mark, explicit=explicit) # Prepare the next state. self.state = self.parse_document_start @@ -206,8 +208,9 @@ return event def parse_document_content(self): - if self.check_token(DirectiveToken, - DocumentStartToken, DocumentEndToken, StreamEndToken): + if self.check_token( + DirectiveToken, DocumentStartToken, DocumentEndToken, StreamEndToken + ): event = self.process_empty_scalar(self.peek_token().start_mark) self.state = self.states.pop() return event @@ -219,22 +222,26 @@ def process_directives(self): self.tag_handles = {} while self.check_token(DirectiveToken): token = self.get_token() - if token.name == 'YAML': + if token.name == "YAML": if self.yaml_version is not None: - raise ParserError(None, None, - "found duplicate YAML directive", token.start_mark) + raise ParserError( + None, None, "found duplicate YAML directive", token.start_mark + ) major, minor = token.value if major != 1: - raise ParserError(None, None, - "found incompatible YAML document (version 1.* is required)", - token.start_mark) + raise ParserError( + None, + None, + "found incompatible YAML document (version 1.* is required)", + token.start_mark, + ) self.yaml_version = token.value - elif token.name == 'TAG': + elif token.name == "TAG": handle, prefix = token.value if handle in self.tag_handles: - raise ParserError(None, None, - "duplicate tag handle %r" % handle, - token.start_mark) + raise ParserError( + None, None, "duplicate tag handle %r" % handle, token.start_mark + ) self.tag_handles[handle] = prefix if self.tag_handles: value = self.yaml_version, self.tag_handles.copy() @@ -302,73 +309,90 @@ def parse_node(self, block=False, indentless_sequence=False): handle, suffix = tag if handle is not None: if handle not in self.tag_handles: - raise ParserError("while parsing a node", start_mark, - "found undefined tag handle %r" % handle, - tag_mark) - tag = self.tag_handles[handle]+suffix + raise ParserError( + "while parsing a node", + start_mark, + "found undefined tag handle %r" % handle, + tag_mark, + ) + tag = self.tag_handles[handle] + suffix else: tag
= suffix - #if tag == '!': + # if tag == '!': # raise ParserError("while parsing a node", start_mark, # "found non-specific tag '!'", tag_mark, # "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag' and share your opinion.") if start_mark is None: start_mark = end_mark = self.peek_token().start_mark event = None - implicit = (tag is None or tag == '!') + implicit = tag is None or tag == "!" if indentless_sequence and self.check_token(BlockEntryToken): end_mark = self.peek_token().end_mark - event = SequenceStartEvent(anchor, tag, implicit, - start_mark, end_mark) + event = SequenceStartEvent(anchor, tag, implicit, start_mark, end_mark) self.state = self.parse_indentless_sequence_entry else: if self.check_token(ScalarToken): token = self.get_token() end_mark = token.end_mark - if (token.plain and tag is None) or tag == '!': + if (token.plain and tag is None) or tag == "!": implicit = (True, False) elif tag is None: implicit = (False, True) else: implicit = (False, False) - event = ScalarEvent(anchor, tag, implicit, token.value, - start_mark, end_mark, style=token.style) + event = ScalarEvent( + anchor, + tag, + implicit, + token.value, + start_mark, + end_mark, + style=token.style, + ) self.state = self.states.pop() elif self.check_token(FlowSequenceStartToken): end_mark = self.peek_token().end_mark - event = SequenceStartEvent(anchor, tag, implicit, - start_mark, end_mark, flow_style=True) + event = SequenceStartEvent( + anchor, tag, implicit, start_mark, end_mark, flow_style=True + ) self.state = self.parse_flow_sequence_first_entry elif self.check_token(FlowMappingStartToken): end_mark = self.peek_token().end_mark - event = MappingStartEvent(anchor, tag, implicit, - start_mark, end_mark, flow_style=True) + event = MappingStartEvent( + anchor, tag, implicit, start_mark, end_mark, flow_style=True + ) self.state = self.parse_flow_mapping_first_key elif block and self.check_token(BlockSequenceStartToken): end_mark = self.peek_token().start_mark - event = SequenceStartEvent(anchor, tag, implicit, - start_mark, end_mark, flow_style=False) + event = SequenceStartEvent( + anchor, tag, implicit, start_mark, end_mark, flow_style=False + ) self.state = self.parse_block_sequence_first_entry elif block and self.check_token(BlockMappingStartToken): end_mark = self.peek_token().start_mark - event = MappingStartEvent(anchor, tag, implicit, - start_mark, end_mark, flow_style=False) + event = MappingStartEvent( + anchor, tag, implicit, start_mark, end_mark, flow_style=False + ) self.state = self.parse_block_mapping_first_key elif anchor is not None or tag is not None: # Empty scalars are allowed even if a tag or an anchor is # specified. 
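# For instance, each of these hits this empty-scalar branch (a sketch; the
# vendored copy and stock PyYAML behave identically here):
#
#     from metaflow._vendor import yaml
#     yaml.safe_load("key: !!str")  # -> {'key': ''}   explicit tag, no value
#     yaml.safe_load("key: &a")     # -> {'key': None} anchor, no value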
- event = ScalarEvent(anchor, tag, (implicit, False), '', - start_mark, end_mark) + event = ScalarEvent( + anchor, tag, (implicit, False), "", start_mark, end_mark + ) self.state = self.states.pop() else: if block: - node = 'block' + node = "block" else: - node = 'flow' + node = "flow" token = self.peek_token() - raise ParserError("while parsing a %s node" % node, start_mark, - "expected the node content, but found %r" % token.id, - token.start_mark) + raise ParserError( + "while parsing a %s node" % node, + start_mark, + "expected the node content, but found %r" % token.id, + token.start_mark, + ) return event # block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END @@ -389,8 +413,12 @@ def parse_block_sequence_entry(self): return self.process_empty_scalar(token.end_mark) if not self.check_token(BlockEndToken): token = self.peek_token() - raise ParserError("while parsing a block collection", self.marks[-1], - "expected <block end>, but found %r" % token.id, token.start_mark) + raise ParserError( + "while parsing a block collection", + self.marks[-1], + "expected <block end>, but found %r" % token.id, + token.start_mark, + ) token = self.get_token() event = SequenceEndEvent(token.start_mark, token.end_mark) self.state = self.states.pop() @@ -402,8 +430,9 @@ def parse_indentless_sequence_entry(self): if self.check_token(BlockEntryToken): token = self.get_token() - if not self.check_token(BlockEntryToken, - KeyToken, ValueToken, BlockEndToken): + if not self.check_token( + BlockEntryToken, KeyToken, ValueToken, BlockEndToken + ): self.states.append(self.parse_indentless_sequence_entry) return self.parse_block_node() else: @@ -435,8 +464,12 @@ def parse_block_mapping_key(self): return self.process_empty_scalar(token.end_mark) if not self.check_token(BlockEndToken): token = self.peek_token() - raise ParserError("while parsing a block mapping", self.marks[-1], - "expected <block end>, but found %r" % token.id, token.start_mark) + raise ParserError( + "while parsing a block mapping", + self.marks[-1], + "expected <block end>, but found %r" % token.id, + token.start_mark, + ) token = self.get_token() event = MappingEndEvent(token.start_mark, token.end_mark) self.state = self.states.pop() @@ -480,14 +513,18 @@ def parse_flow_sequence_entry(self, first=False): self.get_token() else: token = self.peek_token() - raise ParserError("while parsing a flow sequence", self.marks[-1], - "expected ',' or ']', but got %r" % token.id, token.start_mark) - + raise ParserError( + "while parsing a flow sequence", + self.marks[-1], + "expected ',' or ']', but got %r" % token.id, + token.start_mark, + ) + if self.check_token(KeyToken): token = self.peek_token() - event = MappingStartEvent(None, None, True, - token.start_mark, token.end_mark, - flow_style=True) + event = MappingStartEvent( + None, None, True, token.start_mark, token.end_mark, flow_style=True + ) self.state = self.parse_flow_sequence_entry_mapping_key return event elif not self.check_token(FlowSequenceEndToken): @@ -501,8 +538,7 @@ def parse_flow_sequence_entry_mapping_key(self): token = self.get_token() - if not self.check_token(ValueToken, - FlowEntryToken, FlowSequenceEndToken): + if not self.check_token(ValueToken, FlowEntryToken, FlowSequenceEndToken): self.states.append(self.parse_flow_sequence_entry_mapping_value) return self.parse_flow_node() else: @@ -546,12 +582,17 @@ def parse_flow_mapping_key(self, first=False): self.get_token() else: token = self.peek_token() - raise
ParserError("while parsing a flow mapping", self.marks[-1], - "expected ',' or '}', but got %r" % token.id, token.start_mark) + raise ParserError( + "while parsing a flow mapping", + self.marks[-1], + "expected ',' or '}', but got %r" % token.id, + token.start_mark, + ) if self.check_token(KeyToken): token = self.get_token() - if not self.check_token(ValueToken, - FlowEntryToken, FlowMappingEndToken): + if not self.check_token( + ValueToken, FlowEntryToken, FlowMappingEndToken + ): self.states.append(self.parse_flow_mapping_value) return self.parse_flow_node() else: @@ -585,5 +626,4 @@ def parse_flow_mapping_empty_value(self): return self.process_empty_scalar(self.peek_token().start_mark) def process_empty_scalar(self, mark): - return ScalarEvent(None, None, (True, False), '', mark, mark) - + return ScalarEvent(None, None, (True, False), "", mark, mark) diff --git a/metaflow/_vendor/yaml/reader.py b/metaflow/_vendor/yaml/reader.py index 774b0219b59..006065f07ff 100644 --- a/metaflow/_vendor/yaml/reader.py +++ b/metaflow/_vendor/yaml/reader.py @@ -15,12 +15,13 @@ # reader.index - the number of the current character. # reader.line, stream.column - the line and the column of the current character. -__all__ = ['Reader', 'ReaderError'] +__all__ = ["Reader", "ReaderError"] from .error import YAMLError, Mark import codecs, re + class ReaderError(YAMLError): def __init__(self, name, position, character, encoding, reason): @@ -32,15 +33,25 @@ def __init__(self, name, position, character, encoding, reason): def __str__(self): if isinstance(self.character, bytes): - return "'%s' codec can't decode byte #x%02x: %s\n" \ - " in \"%s\", position %d" \ - % (self.encoding, ord(self.character), self.reason, - self.name, self.position) + return ( + "'%s' codec can't decode byte #x%02x: %s\n" + ' in "%s", position %d' + % ( + self.encoding, + ord(self.character), + self.reason, + self.name, + self.position, + ) + ) else: - return "unacceptable character #x%04x: %s\n" \ - " in \"%s\", position %d" \ - % (self.character, self.reason, - self.name, self.position) + return "unacceptable character #x%04x: %s\n" ' in "%s", position %d' % ( + self.character, + self.reason, + self.name, + self.position, + ) + class Reader(object): # Reader: @@ -61,7 +72,7 @@ def __init__(self, stream): self.stream = None self.stream_pointer = 0 self.eof = True - self.buffer = '' + self.buffer = "" self.pointer = 0 self.raw_buffer = None self.raw_decode = None @@ -72,52 +83,53 @@ def __init__(self, stream): if isinstance(stream, str): self.name = "" self.check_printable(stream) - self.buffer = stream+'\0' + self.buffer = stream + "\0" elif isinstance(stream, bytes): self.name = "" self.raw_buffer = stream self.determine_encoding() else: self.stream = stream - self.name = getattr(stream, 'name', "") + self.name = getattr(stream, "name", "") self.eof = False self.raw_buffer = None self.determine_encoding() def peek(self, index=0): try: - return self.buffer[self.pointer+index] + return self.buffer[self.pointer + index] except IndexError: - self.update(index+1) - return self.buffer[self.pointer+index] + self.update(index + 1) + return self.buffer[self.pointer + index] def prefix(self, length=1): - if self.pointer+length >= len(self.buffer): + if self.pointer + length >= len(self.buffer): self.update(length) - return self.buffer[self.pointer:self.pointer+length] + return self.buffer[self.pointer : self.pointer + length] def forward(self, length=1): - if self.pointer+length+1 >= len(self.buffer): - self.update(length+1) + if self.pointer + 
length + 1 >= len(self.buffer): + self.update(length + 1) while length: ch = self.buffer[self.pointer] self.pointer += 1 self.index += 1 - if ch in '\n\x85\u2028\u2029' \ - or (ch == '\r' and self.buffer[self.pointer] != '\n'): + if ch in "\n\x85\u2028\u2029" or ( + ch == "\r" and self.buffer[self.pointer] != "\n" + ): self.line += 1 self.column = 0 - elif ch != '\uFEFF': + elif ch != "\ufeff": self.column += 1 length -= 1 def get_mark(self): if self.stream is None: - return Mark(self.name, self.index, self.line, self.column, - self.buffer, self.pointer) + return Mark( + self.name, self.index, self.line, self.column, self.buffer, self.pointer + ) else: - return Mark(self.name, self.index, self.line, self.column, - None, None) + return Mark(self.name, self.index, self.line, self.column, None, None) def determine_encoding(self): while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2): @@ -125,44 +137,56 @@ def determine_encoding(self): if isinstance(self.raw_buffer, bytes): if self.raw_buffer.startswith(codecs.BOM_UTF16_LE): self.raw_decode = codecs.utf_16_le_decode - self.encoding = 'utf-16-le' + self.encoding = "utf-16-le" elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE): self.raw_decode = codecs.utf_16_be_decode - self.encoding = 'utf-16-be' + self.encoding = "utf-16-be" else: self.raw_decode = codecs.utf_8_decode - self.encoding = 'utf-8' + self.encoding = "utf-8" self.update(1) - NON_PRINTABLE = re.compile('[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]') + NON_PRINTABLE = re.compile( + "[^\x09\x0a\x0d\x20-\x7e\x85\xa0-\ud7ff\ue000-\ufffd\U00010000-\U0010ffff]" + ) + def check_printable(self, data): match = self.NON_PRINTABLE.search(data) if match: character = match.group() - position = self.index+(len(self.buffer)-self.pointer)+match.start() - raise ReaderError(self.name, position, ord(character), - 'unicode', "special characters are not allowed") + position = self.index + (len(self.buffer) - self.pointer) + match.start() + raise ReaderError( + self.name, + position, + ord(character), + "unicode", + "special characters are not allowed", + ) def update(self, length): if self.raw_buffer is None: return - self.buffer = self.buffer[self.pointer:] + self.buffer = self.buffer[self.pointer :] self.pointer = 0 while len(self.buffer) < length: if not self.eof: self.update_raw() if self.raw_decode is not None: try: - data, converted = self.raw_decode(self.raw_buffer, - 'strict', self.eof) + data, converted = self.raw_decode( + self.raw_buffer, "strict", self.eof + ) except UnicodeDecodeError as exc: character = self.raw_buffer[exc.start] if self.stream is not None: - position = self.stream_pointer-len(self.raw_buffer)+exc.start + position = ( + self.stream_pointer - len(self.raw_buffer) + exc.start + ) else: position = exc.start - raise ReaderError(self.name, position, character, - exc.encoding, exc.reason) + raise ReaderError( + self.name, position, character, exc.encoding, exc.reason + ) else: data = self.raw_buffer converted = len(data) @@ -170,7 +194,7 @@ def update(self, length): self.buffer += data self.raw_buffer = self.raw_buffer[converted:] if self.eof: - self.buffer += '\0' + self.buffer += "\0" self.raw_buffer = None break diff --git a/metaflow/_vendor/yaml/representer.py b/metaflow/_vendor/yaml/representer.py index 3b0b192ef32..86c6c7a9b9c 100644 --- a/metaflow/_vendor/yaml/representer.py +++ b/metaflow/_vendor/yaml/representer.py @@ -1,15 +1,15 @@ - -__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer', -
'RepresenterError'] +__all__ = ["BaseRepresenter", "SafeRepresenter", "Representer", "RepresenterError"] from .error import * from .nodes import * import datetime, copyreg, types, base64, collections + class RepresenterError(YAMLError): pass + class BaseRepresenter: yaml_representers = {} @@ -38,10 +38,10 @@ def represent_data(self, data): if self.alias_key is not None: if self.alias_key in self.represented_objects: node = self.represented_objects[self.alias_key] - #if node is None: + # if node is None: # raise RepresenterError("recursive objects are not allowed: %r" % data) return node - #self.represented_objects[alias_key] = None + # self.represented_objects[alias_key] = None self.object_keeper.append(data) data_types = type(data).__mro__ if data_types[0] in self.yaml_representers: @@ -58,19 +58,19 @@ def represent_data(self, data): node = self.yaml_representers[None](self, data) else: node = ScalarNode(None, str(data)) - #if alias_key is not None: + # if alias_key is not None: # self.represented_objects[alias_key] = node return node @classmethod def add_representer(cls, data_type, representer): - if not 'yaml_representers' in cls.__dict__: + if not "yaml_representers" in cls.__dict__: cls.yaml_representers = cls.yaml_representers.copy() cls.yaml_representers[data_type] = representer @classmethod def add_multi_representer(cls, data_type, representer): - if not 'yaml_multi_representers' in cls.__dict__: + if not "yaml_multi_representers" in cls.__dict__: cls.yaml_multi_representers = cls.yaml_multi_representers.copy() cls.yaml_multi_representers[data_type] = representer @@ -106,7 +106,7 @@ def represent_mapping(self, tag, mapping, flow_style=None): if self.alias_key is not None: self.represented_objects[self.alias_key] = node best_style = True - if hasattr(mapping, 'items'): + if hasattr(mapping, "items"): mapping = list(mapping.items()) if self.sort_keys: try: @@ -131,6 +131,7 @@ def represent_mapping(self, tag, mapping, flow_style=None): def ignore_aliases(self, data): return False + class SafeRepresenter(BaseRepresenter): def ignore_aliases(self, data): @@ -142,39 +143,39 @@ def ignore_aliases(self, data): return True def represent_none(self, data): - return self.represent_scalar('tag:yaml.org,2002:null', 'null') + return self.represent_scalar("tag:yaml.org,2002:null", "null") def represent_str(self, data): - return self.represent_scalar('tag:yaml.org,2002:str', data) + return self.represent_scalar("tag:yaml.org,2002:str", data) def represent_binary(self, data): - if hasattr(base64, 'encodebytes'): - data = base64.encodebytes(data).decode('ascii') + if hasattr(base64, "encodebytes"): + data = base64.encodebytes(data).decode("ascii") else: - data = base64.encodestring(data).decode('ascii') - return self.represent_scalar('tag:yaml.org,2002:binary', data, style='|') + data = base64.encodestring(data).decode("ascii") + return self.represent_scalar("tag:yaml.org,2002:binary", data, style="|") def represent_bool(self, data): if data: - value = 'true' + value = "true" else: - value = 'false' - return self.represent_scalar('tag:yaml.org,2002:bool', value) + value = "false" + return self.represent_scalar("tag:yaml.org,2002:bool", value) def represent_int(self, data): - return self.represent_scalar('tag:yaml.org,2002:int', str(data)) + return self.represent_scalar("tag:yaml.org,2002:int", str(data)) inf_value = 1e300 - while repr(inf_value) != repr(inf_value*inf_value): + while repr(inf_value) != repr(inf_value * inf_value): inf_value *= inf_value def represent_float(self, data): if data != data 
or (data == 0.0 and data == 1.0): - value = '.nan' + value = ".nan" elif data == self.inf_value: - value = '.inf' + value = ".inf" elif data == -self.inf_value: - value = '-.inf' + value = "-.inf" else: value = repr(data).lower() # Note that in some cases `repr(data)` represents a float number @@ -184,44 +185,45 @@ def represent_float(self, data): # Unfortunately, this is not a valid float representation according # to the definition of the `!!float` tag. We fix this by adding # '.0' before the 'e' symbol. - if '.' not in value and 'e' in value: - value = value.replace('e', '.0e', 1) - return self.represent_scalar('tag:yaml.org,2002:float', value) + if "." not in value and "e" in value: + value = value.replace("e", ".0e", 1) + return self.represent_scalar("tag:yaml.org,2002:float", value) def represent_list(self, data): - #pairs = (len(data) > 0 and isinstance(data, list)) - #if pairs: + # pairs = (len(data) > 0 and isinstance(data, list)) + # if pairs: # for item in data: # if not isinstance(item, tuple) or len(item) != 2: # pairs = False # break - #if not pairs: - return self.represent_sequence('tag:yaml.org,2002:seq', data) - #value = [] - #for item_key, item_value in data: - # value.append(self.represent_mapping(u'tag:yaml.org,2002:map', - # [(item_key, item_value)])) - #return SequenceNode(u'tag:yaml.org,2002:pairs', value) + # if not pairs: + return self.represent_sequence("tag:yaml.org,2002:seq", data) + + # value = [] + # for item_key, item_value in data: + # value.append(self.represent_mapping(u'tag:yaml.org,2002:map', + # [(item_key, item_value)])) + # return SequenceNode(u'tag:yaml.org,2002:pairs', value) def represent_dict(self, data): - return self.represent_mapping('tag:yaml.org,2002:map', data) + return self.represent_mapping("tag:yaml.org,2002:map", data) def represent_set(self, data): value = {} for key in data: value[key] = None - return self.represent_mapping('tag:yaml.org,2002:set', value) + return self.represent_mapping("tag:yaml.org,2002:set", value) def represent_date(self, data): value = data.isoformat() - return self.represent_scalar('tag:yaml.org,2002:timestamp', value) + return self.represent_scalar("tag:yaml.org,2002:timestamp", value) def represent_datetime(self, data): - value = data.isoformat(' ') - return self.represent_scalar('tag:yaml.org,2002:timestamp', value) + value = data.isoformat(" ") + return self.represent_scalar("tag:yaml.org,2002:timestamp", value) def represent_yaml_object(self, tag, data, cls, flow_style=None): - if hasattr(data, '__getstate__'): + if hasattr(data, "__getstate__"): state = data.__getstate__() else: state = data.__dict__.copy() @@ -230,68 +232,58 @@ def represent_yaml_object(self, tag, data, cls, flow_style=None): def represent_undefined(self, data): raise RepresenterError("cannot represent an object", data) -SafeRepresenter.add_representer(type(None), - SafeRepresenter.represent_none) -SafeRepresenter.add_representer(str, - SafeRepresenter.represent_str) +SafeRepresenter.add_representer(type(None), SafeRepresenter.represent_none) + +SafeRepresenter.add_representer(str, SafeRepresenter.represent_str) + +SafeRepresenter.add_representer(bytes, SafeRepresenter.represent_binary) -SafeRepresenter.add_representer(bytes, - SafeRepresenter.represent_binary) +SafeRepresenter.add_representer(bool, SafeRepresenter.represent_bool) -SafeRepresenter.add_representer(bool, - SafeRepresenter.represent_bool) +SafeRepresenter.add_representer(int, SafeRepresenter.represent_int) -SafeRepresenter.add_representer(int, - SafeRepresenter.represent_int) 
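# A quick sketch of represent_float above (assuming the vendored module):
# `data != data` is the standard NaN test, and the '.0e' rewrite keeps the
# output a valid !!float:
#
#     from metaflow._vendor import yaml
#     yaml.safe_dump(float("nan"))  # -> '.nan\n...\n'
#     yaml.safe_dump(1e17)          # -> '1.0e+17\n...\n' (repr gives '1e+17')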
+SafeRepresenter.add_representer(float, SafeRepresenter.represent_float) -SafeRepresenter.add_representer(float, - SafeRepresenter.represent_float) +SafeRepresenter.add_representer(list, SafeRepresenter.represent_list) -SafeRepresenter.add_representer(list, - SafeRepresenter.represent_list) +SafeRepresenter.add_representer(tuple, SafeRepresenter.represent_list) -SafeRepresenter.add_representer(tuple, - SafeRepresenter.represent_list) +SafeRepresenter.add_representer(dict, SafeRepresenter.represent_dict) -SafeRepresenter.add_representer(dict, - SafeRepresenter.represent_dict) +SafeRepresenter.add_representer(set, SafeRepresenter.represent_set) -SafeRepresenter.add_representer(set, - SafeRepresenter.represent_set) +SafeRepresenter.add_representer(datetime.date, SafeRepresenter.represent_date) -SafeRepresenter.add_representer(datetime.date, - SafeRepresenter.represent_date) +SafeRepresenter.add_representer(datetime.datetime, SafeRepresenter.represent_datetime) -SafeRepresenter.add_representer(datetime.datetime, - SafeRepresenter.represent_datetime) +SafeRepresenter.add_representer(None, SafeRepresenter.represent_undefined) -SafeRepresenter.add_representer(None, - SafeRepresenter.represent_undefined) class Representer(SafeRepresenter): def represent_complex(self, data): if data.imag == 0.0: - data = '%r' % data.real + data = "%r" % data.real elif data.real == 0.0: - data = '%rj' % data.imag + data = "%rj" % data.imag elif data.imag > 0: - data = '%r+%rj' % (data.real, data.imag) + data = "%r+%rj" % (data.real, data.imag) else: - data = '%r%rj' % (data.real, data.imag) - return self.represent_scalar('tag:yaml.org,2002:python/complex', data) + data = "%r%rj" % (data.real, data.imag) + return self.represent_scalar("tag:yaml.org,2002:python/complex", data) def represent_tuple(self, data): - return self.represent_sequence('tag:yaml.org,2002:python/tuple', data) + return self.represent_sequence("tag:yaml.org,2002:python/tuple", data) def represent_name(self, data): - name = '%s.%s' % (data.__module__, data.__name__) - return self.represent_scalar('tag:yaml.org,2002:python/name:'+name, '') + name = "%s.%s" % (data.__module__, data.__name__) + return self.represent_scalar("tag:yaml.org,2002:python/name:" + name, "") def represent_module(self, data): return self.represent_scalar( - 'tag:yaml.org,2002:python/module:'+data.__name__, '') + "tag:yaml.org,2002:python/module:" + data.__name__, "" + ) def represent_object(self, data): # We use __reduce__ API to save the data. 
data.__reduce__ returns @@ -313,13 +305,13 @@ def represent_object(self, data): cls = type(data) if cls in copyreg.dispatch_table: reduce = copyreg.dispatch_table[cls](data) - elif hasattr(data, '__reduce_ex__'): + elif hasattr(data, "__reduce_ex__"): reduce = data.__reduce_ex__(2) - elif hasattr(data, '__reduce__'): + elif hasattr(data, "__reduce__"): reduce = data.__reduce__() else: raise RepresenterError("cannot represent an object", data) - reduce = (list(reduce)+[None]*5)[:5] + reduce = (list(reduce) + [None] * 5)[:5] function, args, state, listitems, dictitems = reduce args = list(args) if state is None: @@ -328,62 +320,61 @@ def represent_object(self, data): listitems = list(listitems) if dictitems is not None: dictitems = dict(dictitems) - if function.__name__ == '__newobj__': + if function.__name__ == "__newobj__": function = args[0] args = args[1:] - tag = 'tag:yaml.org,2002:python/object/new:' + tag = "tag:yaml.org,2002:python/object/new:" newobj = True else: - tag = 'tag:yaml.org,2002:python/object/apply:' + tag = "tag:yaml.org,2002:python/object/apply:" newobj = False - function_name = '%s.%s' % (function.__module__, function.__name__) - if not args and not listitems and not dictitems \ - and isinstance(state, dict) and newobj: + function_name = "%s.%s" % (function.__module__, function.__name__) + if ( + not args + and not listitems + and not dictitems + and isinstance(state, dict) + and newobj + ): return self.represent_mapping( - 'tag:yaml.org,2002:python/object:'+function_name, state) - if not listitems and not dictitems \ - and isinstance(state, dict) and not state: - return self.represent_sequence(tag+function_name, args) + "tag:yaml.org,2002:python/object:" + function_name, state + ) + if not listitems and not dictitems and isinstance(state, dict) and not state: + return self.represent_sequence(tag + function_name, args) value = {} if args: - value['args'] = args + value["args"] = args if state or not isinstance(state, dict): - value['state'] = state + value["state"] = state if listitems: - value['listitems'] = listitems + value["listitems"] = listitems if dictitems: - value['dictitems'] = dictitems - return self.represent_mapping(tag+function_name, value) + value["dictitems"] = dictitems + return self.represent_mapping(tag + function_name, value) def represent_ordered_dict(self, data): # Provide uniform representation across different Python versions. 
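# Roughly, the sequence-of-pairs form produced here looks like this hedged
# sketch (the non-safe Dumper is required because of the python/object tag):
#
#     from collections import OrderedDict
#     from metaflow._vendor import yaml
#     yaml.dump(OrderedDict([("a", 1)]))
#     # starts with '!!python/object/apply:collections.OrderedDict'
#     # followed by the [[key, value], ...] rows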
data_type = type(data) - tag = 'tag:yaml.org,2002:python/object/apply:%s.%s' \ - % (data_type.__module__, data_type.__name__) + tag = "tag:yaml.org,2002:python/object/apply:%s.%s" % ( + data_type.__module__, + data_type.__name__, + ) items = [[key, value] for key, value in data.items()] return self.represent_sequence(tag, [items]) -Representer.add_representer(complex, - Representer.represent_complex) -Representer.add_representer(tuple, - Representer.represent_tuple) +Representer.add_representer(complex, Representer.represent_complex) -Representer.add_representer(type, - Representer.represent_name) +Representer.add_representer(tuple, Representer.represent_tuple) -Representer.add_representer(collections.OrderedDict, - Representer.represent_ordered_dict) +Representer.add_representer(type, Representer.represent_name) -Representer.add_representer(types.FunctionType, - Representer.represent_name) +Representer.add_representer(collections.OrderedDict, Representer.represent_ordered_dict) -Representer.add_representer(types.BuiltinFunctionType, - Representer.represent_name) +Representer.add_representer(types.FunctionType, Representer.represent_name) -Representer.add_representer(types.ModuleType, - Representer.represent_module) +Representer.add_representer(types.BuiltinFunctionType, Representer.represent_name) -Representer.add_multi_representer(object, - Representer.represent_object) +Representer.add_representer(types.ModuleType, Representer.represent_module) +Representer.add_multi_representer(object, Representer.represent_object) diff --git a/metaflow/_vendor/yaml/resolver.py b/metaflow/_vendor/yaml/resolver.py index 02b82e73eec..1b1c81b262a 100644 --- a/metaflow/_vendor/yaml/resolver.py +++ b/metaflow/_vendor/yaml/resolver.py @@ -1,19 +1,20 @@ - -__all__ = ['BaseResolver', 'Resolver'] +__all__ = ["BaseResolver", "Resolver"] from .error import * from .nodes import * import re + class ResolverError(YAMLError): pass + class BaseResolver: - DEFAULT_SCALAR_TAG = 'tag:yaml.org,2002:str' - DEFAULT_SEQUENCE_TAG = 'tag:yaml.org,2002:seq' - DEFAULT_MAPPING_TAG = 'tag:yaml.org,2002:map' + DEFAULT_SCALAR_TAG = "tag:yaml.org,2002:str" + DEFAULT_SEQUENCE_TAG = "tag:yaml.org,2002:seq" + DEFAULT_MAPPING_TAG = "tag:yaml.org,2002:map" yaml_implicit_resolvers = {} yaml_path_resolvers = {} @@ -24,7 +25,7 @@ def __init__(self): @classmethod def add_implicit_resolver(cls, tag, regexp, first): - if not 'yaml_implicit_resolvers' in cls.__dict__: + if not "yaml_implicit_resolvers" in cls.__dict__: implicit_resolvers = {} for key in cls.yaml_implicit_resolvers: implicit_resolvers[key] = cls.yaml_implicit_resolvers[key][:] @@ -48,7 +49,7 @@ def add_path_resolver(cls, tag, path, kind=None): # a mapping value that corresponds to a scalar key which content is # equal to the `index_check` value. An integer `index_check` matches # against a sequence value with the index equal to `index_check`. 
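# (Path resolvers are rarely registered directly; the commoner scalar-level
# hook on this same Resolver class is add_implicit_resolver. A sketch with a
# hypothetical "!degrees" tag:)
#
#     import re
#     from metaflow._vendor import yaml
#     yaml.SafeLoader.add_implicit_resolver(
#         "!degrees", re.compile(r"^\d+C$"), list("0123456789")
#     )
#     yaml.SafeLoader.add_constructor(
#         "!degrees", lambda loader, node: int(node.value[:-1])
#     )
#     yaml.safe_load("temp: 21C")  # -> {'temp': 21}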
- if not 'yaml_path_resolvers' in cls.__dict__: + if not "yaml_path_resolvers" in cls.__dict__: cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy() new_path = [] for element in path: @@ -69,12 +70,13 @@ def add_path_resolver(cls, tag, path, kind=None): node_check = SequenceNode elif node_check is dict: node_check = MappingNode - elif node_check not in [ScalarNode, SequenceNode, MappingNode] \ - and not isinstance(node_check, str) \ - and node_check is not None: + elif ( + node_check not in [ScalarNode, SequenceNode, MappingNode] + and not isinstance(node_check, str) + and node_check is not None + ): raise ResolverError("Invalid node checker: %s" % node_check) - if not isinstance(index_check, (str, int)) \ - and index_check is not None: + if not isinstance(index_check, (str, int)) and index_check is not None: raise ResolverError("Invalid index checker: %s" % index_check) new_path.append((node_check, index_check)) if kind is str: @@ -83,8 +85,7 @@ def add_path_resolver(cls, tag, path, kind=None): kind = SequenceNode elif kind is dict: kind = MappingNode - elif kind not in [ScalarNode, SequenceNode, MappingNode] \ - and kind is not None: + elif kind not in [ScalarNode, SequenceNode, MappingNode] and kind is not None: raise ResolverError("Invalid node kind: %s" % kind) cls.yaml_path_resolvers[tuple(new_path), kind] = tag @@ -96,8 +97,9 @@ def descend_resolver(self, current_node, current_index): if current_node: depth = len(self.resolver_prefix_paths) for path, kind in self.resolver_prefix_paths[-1]: - if self.check_resolver_prefix(depth, path, kind, - current_node, current_index): + if self.check_resolver_prefix( + depth, path, kind, current_node, current_index + ): if len(path) > depth: prefix_paths.append((path, kind)) else: @@ -117,9 +119,8 @@ def ascend_resolver(self): self.resolver_exact_paths.pop() self.resolver_prefix_paths.pop() - def check_resolver_prefix(self, depth, path, kind, - current_node, current_index): - node_check, index_check = path[depth-1] + def check_resolver_prefix(self, depth, path, kind, current_node, current_index): + node_check, index_check = path[depth - 1] if isinstance(node_check, str): if current_node.tag != node_check: return @@ -128,12 +129,13 @@ def check_resolver_prefix(self, depth, path, kind, return if index_check is True and current_index is not None: return - if (index_check is False or index_check is None) \ - and current_index is None: + if (index_check is False or index_check is None) and current_index is None: return if isinstance(index_check, str): - if not (isinstance(current_index, ScalarNode) - and index_check == current_index.value): + if not ( + isinstance(current_index, ScalarNode) + and index_check == current_index.value + ): return elif isinstance(index_check, int) and not isinstance(index_check, bool): if index_check != current_index: @@ -142,8 +144,8 @@ def check_resolver_prefix(self, depth, path, kind, def resolve(self, kind, value, implicit): if kind is ScalarNode and implicit[0]: - if value == '': - resolvers = self.yaml_implicit_resolvers.get('', []) + if value == "": + resolvers = self.yaml_implicit_resolvers.get("", []) else: resolvers = self.yaml_implicit_resolvers.get(value[0], []) resolvers += self.yaml_implicit_resolvers.get(None, []) @@ -164,64 +166,80 @@ def resolve(self, kind, value, implicit): elif kind is MappingNode: return self.DEFAULT_MAPPING_TAG + class Resolver(BaseResolver): pass + Resolver.add_implicit_resolver( - 'tag:yaml.org,2002:bool', - re.compile(r'''^(?:yes|Yes|YES|no|No|NO + "tag:yaml.org,2002:bool", + 
re.compile( + r"""^(?:yes|Yes|YES|no|No|NO |true|True|TRUE|false|False|FALSE - |on|On|ON|off|Off|OFF)$''', re.X), - list('yYnNtTfFoO')) + |on|On|ON|off|Off|OFF)$""", + re.X, + ), + list("yYnNtTfFoO"), +) Resolver.add_implicit_resolver( - 'tag:yaml.org,2002:float', - re.compile(r'''^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+][0-9]+)? + "tag:yaml.org,2002:float", + re.compile( + r"""^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+][0-9]+)? |\.[0-9_]+(?:[eE][-+][0-9]+)? |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]* |[-+]?\.(?:inf|Inf|INF) - |\.(?:nan|NaN|NAN))$''', re.X), - list('-+0123456789.')) + |\.(?:nan|NaN|NAN))$""", + re.X, + ), + list("-+0123456789."), +) Resolver.add_implicit_resolver( - 'tag:yaml.org,2002:int', - re.compile(r'''^(?:[-+]?0b[0-1_]+ + "tag:yaml.org,2002:int", + re.compile( + r"""^(?:[-+]?0b[0-1_]+ |[-+]?0[0-7_]+ |[-+]?(?:0|[1-9][0-9_]*) |[-+]?0x[0-9a-fA-F_]+ - |[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X), - list('-+0123456789')) + |[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$""", + re.X, + ), + list("-+0123456789"), +) Resolver.add_implicit_resolver( - 'tag:yaml.org,2002:merge', - re.compile(r'^(?:<<)$'), - ['<']) + "tag:yaml.org,2002:merge", re.compile(r"^(?:<<)$"), ["<"] +) Resolver.add_implicit_resolver( - 'tag:yaml.org,2002:null', - re.compile(r'''^(?: ~ + "tag:yaml.org,2002:null", + re.compile( + r"""^(?: ~ |null|Null|NULL - | )$''', re.X), - ['~', 'n', 'N', '']) + | )$""", + re.X, + ), + ["~", "n", "N", ""], +) Resolver.add_implicit_resolver( - 'tag:yaml.org,2002:timestamp', - re.compile(r'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9] + "tag:yaml.org,2002:timestamp", + re.compile( + r"""^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9] |[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]? (?:[Tt]|[ \t]+)[0-9][0-9]? :[0-9][0-9] :[0-9][0-9] (?:\.[0-9]*)? - (?:[ \t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X), - list('0123456789')) + (?:[ \t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$""", + re.X, + ), + list("0123456789"), +) -Resolver.add_implicit_resolver( - 'tag:yaml.org,2002:value', - re.compile(r'^(?:=)$'), - ['=']) +Resolver.add_implicit_resolver("tag:yaml.org,2002:value", re.compile(r"^(?:=)$"), ["="]) # The following resolver is only for documentation purposes. It cannot work # because plain scalars cannot start with '!', '&', or '*'. Resolver.add_implicit_resolver( - 'tag:yaml.org,2002:yaml', - re.compile(r'^(?:!|&|\*)$'), - list('!&*')) - + "tag:yaml.org,2002:yaml", re.compile(r"^(?:!|&|\*)$"), list("!&*") +) diff --git a/metaflow/_vendor/yaml/scanner.py b/metaflow/_vendor/yaml/scanner.py index 7437ede1c60..31711dc8945 100644 --- a/metaflow/_vendor/yaml/scanner.py +++ b/metaflow/_vendor/yaml/scanner.py @@ -1,4 +1,3 @@ - # Scanner produces tokens of the following types: # STREAM-START # STREAM-END @@ -24,14 +23,16 @@ # Read comments in the Scanner code for more details. # -__all__ = ['Scanner', 'ScannerError'] +__all__ = ["Scanner", "ScannerError"] from .error import MarkedYAMLError from .tokens import * + class ScannerError(MarkedYAMLError): pass + class SimpleKey: # See below simple keys treatment. @@ -43,6 +44,7 @@ def __init__(self, token_number, required, index, line, column, mark): self.column = column self.mark = mark + class Scanner: def __init__(self): @@ -169,85 +171,85 @@ def fetch_more_tokens(self): ch = self.peek() # Is it the end of stream? - if ch == '\0': + if ch == "\0": return self.fetch_stream_end() # Is it a directive? 
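# (For orientation before the per-character dispatch continues: the token
# stream this scanner produces for a small input, assuming the vendored
# module:)
#
#     from metaflow._vendor import yaml
#     [type(t).__name__ for t in yaml.scan("- a")]
#     # ['StreamStartToken', 'BlockSequenceStartToken', 'BlockEntryToken',
#     #  'ScalarToken', 'BlockEndToken', 'StreamEndToken']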
- if ch == '%' and self.check_directive(): + if ch == "%" and self.check_directive(): return self.fetch_directive() # Is it the document start? - if ch == '-' and self.check_document_start(): + if ch == "-" and self.check_document_start(): return self.fetch_document_start() # Is it the document end? - if ch == '.' and self.check_document_end(): + if ch == "." and self.check_document_end(): return self.fetch_document_end() # TODO: support for BOM within a stream. - #if ch == '\uFEFF': + # if ch == '\uFEFF': # return self.fetch_bom() <-- issue BOMToken # Note: the order of the following checks is NOT significant. # Is it the flow sequence start indicator? - if ch == '[': + if ch == "[": return self.fetch_flow_sequence_start() # Is it the flow mapping start indicator? - if ch == '{': + if ch == "{": return self.fetch_flow_mapping_start() # Is it the flow sequence end indicator? - if ch == ']': + if ch == "]": return self.fetch_flow_sequence_end() # Is it the flow mapping end indicator? - if ch == '}': + if ch == "}": return self.fetch_flow_mapping_end() # Is it the flow entry indicator? - if ch == ',': + if ch == ",": return self.fetch_flow_entry() # Is it the block entry indicator? - if ch == '-' and self.check_block_entry(): + if ch == "-" and self.check_block_entry(): return self.fetch_block_entry() # Is it the key indicator? - if ch == '?' and self.check_key(): + if ch == "?" and self.check_key(): return self.fetch_key() # Is it the value indicator? - if ch == ':' and self.check_value(): + if ch == ":" and self.check_value(): return self.fetch_value() # Is it an alias? - if ch == '*': + if ch == "*": return self.fetch_alias() # Is it an anchor? - if ch == '&': + if ch == "&": return self.fetch_anchor() # Is it a tag? - if ch == '!': + if ch == "!": return self.fetch_tag() # Is it a literal scalar? - if ch == '|' and not self.flow_level: + if ch == "|" and not self.flow_level: return self.fetch_literal() # Is it a folded scalar? - if ch == '>' and not self.flow_level: + if ch == ">" and not self.flow_level: return self.fetch_folded() # Is it a single quoted scalar? - if ch == '\'': + if ch == "'": return self.fetch_single() # Is it a double quoted scalar? - if ch == '\"': + if ch == '"': return self.fetch_double() # It must be a plain scalar then. @@ -255,9 +257,12 @@ def fetch_more_tokens(self): return self.fetch_plain() # No? It's an error. Let's produce a nice error message. - raise ScannerError("while scanning for the next token", None, - "found character %r that cannot start any token" % ch, - self.get_mark()) + raise ScannerError( + "while scanning for the next token", + None, + "found character %r that cannot start any token" % ch, + self.get_mark(), + ) # Simple keys treatment. @@ -285,11 +290,14 @@ def stale_possible_simple_keys(self): # height (may cause problems if indentation is broken though). for level in list(self.possible_simple_keys): key = self.possible_simple_keys[level] - if key.line != self.line \ - or self.index-key.index > 1024: + if key.line != self.line or self.index - key.index > 1024: if key.required: - raise ScannerError("while scanning a simple key", key.mark, - "could not find expected ':'", self.get_mark()) + raise ScannerError( + "while scanning a simple key", + key.mark, + "could not find expected ':'", + self.get_mark(), + ) del self.possible_simple_keys[level] def save_possible_simple_key(self): @@ -304,19 +312,29 @@ def save_possible_simple_key(self): # position. 
if self.allow_simple_key: self.remove_possible_simple_key() - token_number = self.tokens_taken+len(self.tokens) - key = SimpleKey(token_number, required, - self.index, self.line, self.column, self.get_mark()) + token_number = self.tokens_taken + len(self.tokens) + key = SimpleKey( + token_number, + required, + self.index, + self.line, + self.column, + self.get_mark(), + ) self.possible_simple_keys[self.flow_level] = key def remove_possible_simple_key(self): # Remove the saved possible key position at the current flow level. if self.flow_level in self.possible_simple_keys: key = self.possible_simple_keys[self.flow_level] - + if key.required: - raise ScannerError("while scanning a simple key", key.mark, - "could not find expected ':'", self.get_mark()) + raise ScannerError( + "while scanning a simple key", + key.mark, + "could not find expected ':'", + self.get_mark(), + ) del self.possible_simple_keys[self.flow_level] @@ -330,7 +348,7 @@ def unwind_indent(self, column): ## constructions such as ## key : { ## } - #if self.flow_level and self.indent > column: + # if self.flow_level and self.indent > column: # raise ScannerError(None, None, # "invalid indentation or unclosed '[' or '{'", # self.get_mark()) @@ -362,11 +380,9 @@ def fetch_stream_start(self): # Read the token. mark = self.get_mark() - + # Add STREAM-START. - self.tokens.append(StreamStartToken(mark, mark, - encoding=self.encoding)) - + self.tokens.append(StreamStartToken(mark, mark, encoding=self.encoding)) def fetch_stream_end(self): @@ -380,7 +396,7 @@ def fetch_stream_end(self): # Read the token. mark = self.get_mark() - + # Add STREAM-END. self.tokens.append(StreamEndToken(mark, mark)) @@ -388,7 +404,7 @@ def fetch_stream_end(self): self.done = True def fetch_directive(self): - + # Set the current indentation to -1. self.unwind_indent(-1) @@ -488,9 +504,9 @@ def fetch_block_entry(self): # Are we allowed to start a new entry? if not self.allow_simple_key: - raise ScannerError(None, None, - "sequence entries are not allowed here", - self.get_mark()) + raise ScannerError( + None, None, "sequence entries are not allowed here", self.get_mark() + ) # We may need to add BLOCK-SEQUENCE-START. if self.add_indent(self.column): @@ -515,15 +531,15 @@ def fetch_block_entry(self): self.tokens.append(BlockEntryToken(start_mark, end_mark)) def fetch_key(self): - + # Block context needs additional checks. if not self.flow_level: # Are we allowed to start a key (not necessary a simple)? if not self.allow_simple_key: - raise ScannerError(None, None, - "mapping keys are not allowed here", - self.get_mark()) + raise ScannerError( + None, None, "mapping keys are not allowed here", self.get_mark() + ) # We may need to add BLOCK-MAPPING-START. if self.add_indent(self.column): @@ -550,22 +566,25 @@ def fetch_value(self): # Add KEY. key = self.possible_simple_keys[self.flow_level] del self.possible_simple_keys[self.flow_level] - self.tokens.insert(key.token_number-self.tokens_taken, - KeyToken(key.mark, key.mark)) + self.tokens.insert( + key.token_number - self.tokens_taken, KeyToken(key.mark, key.mark) + ) # If this key starts a new block mapping, we need to add # BLOCK-MAPPING-START. if not self.flow_level: if self.add_indent(key.column): - self.tokens.insert(key.token_number-self.tokens_taken, - BlockMappingStartToken(key.mark, key.mark)) + self.tokens.insert( + key.token_number - self.tokens_taken, + BlockMappingStartToken(key.mark, key.mark), + ) # There cannot be two simple keys one after another. 
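# For example, "a: b: c" trips this rule: once "a:" has produced its value,
# the scalar "b" may not open a second simple key on the same line. A sketch:
#
#     from metaflow._vendor import yaml
#     yaml.safe_load("a: b: c")
#     # yaml.scanner.ScannerError: mapping values are not allowed here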
self.allow_simple_key = False # It must be a part of a complex key. else: - + # Block context needs additional checks. # (Do we really need them? They will be caught by the parser # anyway.) @@ -574,9 +593,12 @@ def fetch_value(self): # We are allowed to start a complex value if and only if # we can start a simple key. if not self.allow_simple_key: - raise ScannerError(None, None, - "mapping values are not allowed here", - self.get_mark()) + raise ScannerError( + None, + None, + "mapping values are not allowed here", + self.get_mark(), + ) # If this value starts a new block mapping, we need to add # BLOCK-MAPPING-START. It will be detected as an error later by @@ -632,10 +654,10 @@ def fetch_tag(self): self.tokens.append(self.scan_tag()) def fetch_literal(self): - self.fetch_block_scalar(style='|') + self.fetch_block_scalar(style="|") def fetch_folded(self): - self.fetch_block_scalar(style='>') + self.fetch_block_scalar(style=">") def fetch_block_scalar(self, style): @@ -649,7 +671,7 @@ def fetch_block_scalar(self, style): self.tokens.append(self.scan_block_scalar(style)) def fetch_single(self): - self.fetch_flow_scalar(style='\'') + self.fetch_flow_scalar(style="'") def fetch_double(self): self.fetch_flow_scalar(style='"') @@ -691,22 +713,20 @@ def check_document_start(self): # DOCUMENT-START: ^ '---' (' '|'\n') if self.column == 0: - if self.prefix(3) == '---' \ - and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + if self.prefix(3) == "---" and self.peek(3) in "\0 \t\r\n\x85\u2028\u2029": return True def check_document_end(self): # DOCUMENT-END: ^ '...' (' '|'\n') if self.column == 0: - if self.prefix(3) == '...' \ - and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + if self.prefix(3) == "..." and self.peek(3) in "\0 \t\r\n\x85\u2028\u2029": return True def check_block_entry(self): # BLOCK-ENTRY: '-' (' '|'\n') - return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + return self.peek(1) in "\0 \t\r\n\x85\u2028\u2029" def check_key(self): @@ -716,7 +736,7 @@ def check_key(self): # KEY(block context): '?' (' '|'\n') else: - return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + return self.peek(1) in "\0 \t\r\n\x85\u2028\u2029" def check_value(self): @@ -726,7 +746,7 @@ def check_value(self): # VALUE(block context): ':' (' '|'\n') else: - return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + return self.peek(1) in "\0 \t\r\n\x85\u2028\u2029" def check_plain(self): @@ -743,9 +763,10 @@ def check_plain(self): # '-' character) because we want the flow context to be space # independent. ch = self.peek() - return ch not in '\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>\'\"%@`' \ - or (self.peek(1) not in '\0 \t\r\n\x85\u2028\u2029' - and (ch == '-' or (not self.flow_level and ch in '?:'))) + return ch not in "\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>'\"%@`" or ( + self.peek(1) not in "\0 \t\r\n\x85\u2028\u2029" + and (ch == "-" or (not self.flow_level and ch in "?:")) + ) # Scanners. @@ -769,14 +790,14 @@ def scan_to_next_token(self): # `unwind_indent` before issuing BLOCK-END. # Scanners for block, flow, and plain scalars need to be modified. 
- if self.index == 0 and self.peek() == '\uFEFF': + if self.index == 0 and self.peek() == "\ufeff": self.forward() found = False while not found: - while self.peek() == ' ': + while self.peek() == " ": self.forward() - if self.peek() == '#': - while self.peek() not in '\0\r\n\x85\u2028\u2029': + if self.peek() == "#": + while self.peek() not in "\0\r\n\x85\u2028\u2029": self.forward() if self.scan_line_break(): if not self.flow_level: @@ -790,15 +811,15 @@ def scan_directive(self): self.forward() name = self.scan_directive_name(start_mark) value = None - if name == 'YAML': + if name == "YAML": value = self.scan_yaml_directive_value(start_mark) end_mark = self.get_mark() - elif name == 'TAG': + elif name == "TAG": value = self.scan_tag_directive_value(start_mark) end_mark = self.get_mark() else: end_mark = self.get_mark() - while self.peek() not in '\0\r\n\x85\u2028\u2029': + while self.peek() not in "\0\r\n\x85\u2028\u2029": self.forward() self.scan_directive_ignored_line(start_mark) return DirectiveToken(name, value, start_mark, end_mark) @@ -807,48 +828,63 @@ def scan_directive_name(self, start_mark): # See the specification for details. length = 0 ch = self.peek(length) - while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ - or ch in '-_': + while "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_": length += 1 ch = self.peek(length) if not length: - raise ScannerError("while scanning a directive", start_mark, - "expected alphabetic or numeric character, but found %r" - % ch, self.get_mark()) + raise ScannerError( + "while scanning a directive", + start_mark, + "expected alphabetic or numeric character, but found %r" % ch, + self.get_mark(), + ) value = self.prefix(length) self.forward(length) ch = self.peek() - if ch not in '\0 \r\n\x85\u2028\u2029': - raise ScannerError("while scanning a directive", start_mark, - "expected alphabetic or numeric character, but found %r" - % ch, self.get_mark()) + if ch not in "\0 \r\n\x85\u2028\u2029": + raise ScannerError( + "while scanning a directive", + start_mark, + "expected alphabetic or numeric character, but found %r" % ch, + self.get_mark(), + ) return value def scan_yaml_directive_value(self, start_mark): # See the specification for details. - while self.peek() == ' ': + while self.peek() == " ": self.forward() major = self.scan_yaml_directive_number(start_mark) - if self.peek() != '.': - raise ScannerError("while scanning a directive", start_mark, - "expected a digit or '.', but found %r" % self.peek(), - self.get_mark()) + if self.peek() != ".": + raise ScannerError( + "while scanning a directive", + start_mark, + "expected a digit or '.', but found %r" % self.peek(), + self.get_mark(), + ) self.forward() minor = self.scan_yaml_directive_number(start_mark) - if self.peek() not in '\0 \r\n\x85\u2028\u2029': - raise ScannerError("while scanning a directive", start_mark, - "expected a digit or ' ', but found %r" % self.peek(), - self.get_mark()) + if self.peek() not in "\0 \r\n\x85\u2028\u2029": + raise ScannerError( + "while scanning a directive", + start_mark, + "expected a digit or ' ', but found %r" % self.peek(), + self.get_mark(), + ) return (major, minor) def scan_yaml_directive_number(self, start_mark): # See the specification for details. 
ch = self.peek() - if not ('0' <= ch <= '9'): - raise ScannerError("while scanning a directive", start_mark, - "expected a digit, but found %r" % ch, self.get_mark()) + if not ("0" <= ch <= "9"): + raise ScannerError( + "while scanning a directive", + start_mark, + "expected a digit, but found %r" % ch, + self.get_mark(), + ) length = 0 - while '0' <= self.peek(length) <= '9': + while "0" <= self.peek(length) <= "9": length += 1 value = int(self.prefix(length)) self.forward(length) @@ -856,44 +892,55 @@ def scan_yaml_directive_number(self, start_mark): def scan_tag_directive_value(self, start_mark): # See the specification for details. - while self.peek() == ' ': + while self.peek() == " ": self.forward() handle = self.scan_tag_directive_handle(start_mark) - while self.peek() == ' ': + while self.peek() == " ": self.forward() prefix = self.scan_tag_directive_prefix(start_mark) return (handle, prefix) def scan_tag_directive_handle(self, start_mark): # See the specification for details. - value = self.scan_tag_handle('directive', start_mark) + value = self.scan_tag_handle("directive", start_mark) ch = self.peek() - if ch != ' ': - raise ScannerError("while scanning a directive", start_mark, - "expected ' ', but found %r" % ch, self.get_mark()) + if ch != " ": + raise ScannerError( + "while scanning a directive", + start_mark, + "expected ' ', but found %r" % ch, + self.get_mark(), + ) return value def scan_tag_directive_prefix(self, start_mark): # See the specification for details. - value = self.scan_tag_uri('directive', start_mark) + value = self.scan_tag_uri("directive", start_mark) ch = self.peek() - if ch not in '\0 \r\n\x85\u2028\u2029': - raise ScannerError("while scanning a directive", start_mark, - "expected ' ', but found %r" % ch, self.get_mark()) + if ch not in "\0 \r\n\x85\u2028\u2029": + raise ScannerError( + "while scanning a directive", + start_mark, + "expected ' ', but found %r" % ch, + self.get_mark(), + ) return value def scan_directive_ignored_line(self, start_mark): # See the specification for details. - while self.peek() == ' ': + while self.peek() == " ": self.forward() - if self.peek() == '#': - while self.peek() not in '\0\r\n\x85\u2028\u2029': + if self.peek() == "#": + while self.peek() not in "\0\r\n\x85\u2028\u2029": self.forward() ch = self.peek() - if ch not in '\0\r\n\x85\u2028\u2029': - raise ScannerError("while scanning a directive", start_mark, - "expected a comment or a line break, but found %r" - % ch, self.get_mark()) + if ch not in "\0\r\n\x85\u2028\u2029": + raise ScannerError( + "while scanning a directive", + start_mark, + "expected a comment or a line break, but found %r" % ch, + self.get_mark(), + ) self.scan_line_break() def scan_anchor(self, TokenClass): @@ -907,28 +954,33 @@ def scan_anchor(self, TokenClass): # Therefore we restrict aliases to numbers and ASCII letters. 
start_mark = self.get_mark() indicator = self.peek() - if indicator == '*': - name = 'alias' + if indicator == "*": + name = "alias" else: - name = 'anchor' + name = "anchor" self.forward() length = 0 ch = self.peek(length) - while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ - or ch in '-_': + while "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_": length += 1 ch = self.peek(length) if not length: - raise ScannerError("while scanning an %s" % name, start_mark, - "expected alphabetic or numeric character, but found %r" - % ch, self.get_mark()) + raise ScannerError( + "while scanning an %s" % name, + start_mark, + "expected alphabetic or numeric character, but found %r" % ch, + self.get_mark(), + ) value = self.prefix(length) self.forward(length) ch = self.peek() - if ch not in '\0 \t\r\n\x85\u2028\u2029?:,]}%@`': - raise ScannerError("while scanning an %s" % name, start_mark, - "expected alphabetic or numeric character, but found %r" - % ch, self.get_mark()) + if ch not in "\0 \t\r\n\x85\u2028\u2029?:,]}%@`": + raise ScannerError( + "while scanning an %s" % name, + start_mark, + "expected alphabetic or numeric character, but found %r" % ch, + self.get_mark(), + ) end_mark = self.get_mark() return TokenClass(value, start_mark, end_mark) @@ -936,39 +988,46 @@ def scan_tag(self): # See the specification for details. start_mark = self.get_mark() ch = self.peek(1) - if ch == '<': + if ch == "<": handle = None self.forward(2) - suffix = self.scan_tag_uri('tag', start_mark) - if self.peek() != '>': - raise ScannerError("while parsing a tag", start_mark, - "expected '>', but found %r" % self.peek(), - self.get_mark()) + suffix = self.scan_tag_uri("tag", start_mark) + if self.peek() != ">": + raise ScannerError( + "while parsing a tag", + start_mark, + "expected '>', but found %r" % self.peek(), + self.get_mark(), + ) self.forward() - elif ch in '\0 \t\r\n\x85\u2028\u2029': + elif ch in "\0 \t\r\n\x85\u2028\u2029": handle = None - suffix = '!' + suffix = "!" self.forward() else: length = 1 use_handle = False - while ch not in '\0 \r\n\x85\u2028\u2029': - if ch == '!': + while ch not in "\0 \r\n\x85\u2028\u2029": + if ch == "!": use_handle = True break length += 1 ch = self.peek(length) - handle = '!' + handle = "!" if use_handle: - handle = self.scan_tag_handle('tag', start_mark) + handle = self.scan_tag_handle("tag", start_mark) else: - handle = '!' + handle = "!" self.forward() - suffix = self.scan_tag_uri('tag', start_mark) + suffix = self.scan_tag_uri("tag", start_mark) ch = self.peek() - if ch not in '\0 \r\n\x85\u2028\u2029': - raise ScannerError("while scanning a tag", start_mark, - "expected ' ', but found %r" % ch, self.get_mark()) + if ch not in "\0 \r\n\x85\u2028\u2029": + raise ScannerError( + "while scanning a tag", + start_mark, + "expected ' ', but found %r" % ch, + self.get_mark(), + ) value = (handle, suffix) end_mark = self.get_mark() return TagToken(value, start_mark, end_mark) @@ -976,7 +1035,7 @@ def scan_tag(self): def scan_block_scalar(self, style): # See the specification for details. - if style == '>': + if style == ">": folded = True else: folded = False @@ -990,51 +1049,55 @@ def scan_block_scalar(self, style): self.scan_block_scalar_ignored_line(start_mark) # Determine the indentation level and go to the first non-empty line. 
- min_indent = self.indent+1 + min_indent = self.indent + 1 if min_indent < 1: min_indent = 1 if increment is None: breaks, max_indent, end_mark = self.scan_block_scalar_indentation() indent = max(min_indent, max_indent) else: - indent = min_indent+increment-1 + indent = min_indent + increment - 1 breaks, end_mark = self.scan_block_scalar_breaks(indent) - line_break = '' + line_break = "" # Scan the inner part of the block scalar. - while self.column == indent and self.peek() != '\0': + while self.column == indent and self.peek() != "\0": chunks.extend(breaks) - leading_non_space = self.peek() not in ' \t' + leading_non_space = self.peek() not in " \t" length = 0 - while self.peek(length) not in '\0\r\n\x85\u2028\u2029': + while self.peek(length) not in "\0\r\n\x85\u2028\u2029": length += 1 chunks.append(self.prefix(length)) self.forward(length) line_break = self.scan_line_break() breaks, end_mark = self.scan_block_scalar_breaks(indent) - if self.column == indent and self.peek() != '\0': + if self.column == indent and self.peek() != "\0": # Unfortunately, folding rules are ambiguous. # # This is the folding according to the specification: - - if folded and line_break == '\n' \ - and leading_non_space and self.peek() not in ' \t': + + if ( + folded + and line_break == "\n" + and leading_non_space + and self.peek() not in " \t" + ): if not breaks: - chunks.append(' ') + chunks.append(" ") else: chunks.append(line_break) - + # This is Clark Evans's interpretation (also in the spec # examples): # - #if folded and line_break == '\n': + # if folded and line_break == '\n': # if not breaks: # if self.peek() not in ' \t': # chunks.append(' ') # else: # chunks.append(line_break) - #else: + # else: # chunks.append(line_break) else: break @@ -1046,61 +1109,72 @@ def scan_block_scalar(self, style): chunks.extend(breaks) # We are done. - return ScalarToken(''.join(chunks), False, start_mark, end_mark, - style) + return ScalarToken("".join(chunks), False, start_mark, end_mark, style) def scan_block_scalar_indicators(self, start_mark): # See the specification for details. 
chomping = None increment = None ch = self.peek() - if ch in '+-': - if ch == '+': + if ch in "+-": + if ch == "+": chomping = True else: chomping = False self.forward() ch = self.peek() - if ch in '0123456789': + if ch in "0123456789": increment = int(ch) if increment == 0: - raise ScannerError("while scanning a block scalar", start_mark, - "expected indentation indicator in the range 1-9, but found 0", - self.get_mark()) + raise ScannerError( + "while scanning a block scalar", + start_mark, + "expected indentation indicator in the range 1-9, but found 0", + self.get_mark(), + ) self.forward() - elif ch in '0123456789': + elif ch in "0123456789": increment = int(ch) if increment == 0: - raise ScannerError("while scanning a block scalar", start_mark, - "expected indentation indicator in the range 1-9, but found 0", - self.get_mark()) + raise ScannerError( + "while scanning a block scalar", + start_mark, + "expected indentation indicator in the range 1-9, but found 0", + self.get_mark(), + ) self.forward() ch = self.peek() - if ch in '+-': - if ch == '+': + if ch in "+-": + if ch == "+": chomping = True else: chomping = False self.forward() ch = self.peek() - if ch not in '\0 \r\n\x85\u2028\u2029': - raise ScannerError("while scanning a block scalar", start_mark, - "expected chomping or indentation indicators, but found %r" - % ch, self.get_mark()) + if ch not in "\0 \r\n\x85\u2028\u2029": + raise ScannerError( + "while scanning a block scalar", + start_mark, + "expected chomping or indentation indicators, but found %r" % ch, + self.get_mark(), + ) return chomping, increment def scan_block_scalar_ignored_line(self, start_mark): # See the specification for details. - while self.peek() == ' ': + while self.peek() == " ": self.forward() - if self.peek() == '#': - while self.peek() not in '\0\r\n\x85\u2028\u2029': + if self.peek() == "#": + while self.peek() not in "\0\r\n\x85\u2028\u2029": self.forward() ch = self.peek() - if ch not in '\0\r\n\x85\u2028\u2029': - raise ScannerError("while scanning a block scalar", start_mark, - "expected a comment or a line break, but found %r" % ch, - self.get_mark()) + if ch not in "\0\r\n\x85\u2028\u2029": + raise ScannerError( + "while scanning a block scalar", + start_mark, + "expected a comment or a line break, but found %r" % ch, + self.get_mark(), + ) self.scan_line_break() def scan_block_scalar_indentation(self): @@ -1108,8 +1182,8 @@ def scan_block_scalar_indentation(self): chunks = [] max_indent = 0 end_mark = self.get_mark() - while self.peek() in ' \r\n\x85\u2028\u2029': - if self.peek() != ' ': + while self.peek() in " \r\n\x85\u2028\u2029": + if self.peek() != " ": chunks.append(self.scan_line_break()) end_mark = self.get_mark() else: @@ -1122,12 +1196,12 @@ def scan_block_scalar_breaks(self, indent): # See the specification for details. 
chunks = [] end_mark = self.get_mark() - while self.column < indent and self.peek() == ' ': + while self.column < indent and self.peek() == " ": self.forward() - while self.peek() in '\r\n\x85\u2028\u2029': + while self.peek() in "\r\n\x85\u2028\u2029": chunks.append(self.scan_line_break()) end_mark = self.get_mark() - while self.column < indent and self.peek() == ' ': + while self.column < indent and self.peek() == " ": self.forward() return chunks, end_mark @@ -1152,34 +1226,33 @@ def scan_flow_scalar(self, style): chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) self.forward() end_mark = self.get_mark() - return ScalarToken(''.join(chunks), False, start_mark, end_mark, - style) + return ScalarToken("".join(chunks), False, start_mark, end_mark, style) ESCAPE_REPLACEMENTS = { - '0': '\0', - 'a': '\x07', - 'b': '\x08', - 't': '\x09', - '\t': '\x09', - 'n': '\x0A', - 'v': '\x0B', - 'f': '\x0C', - 'r': '\x0D', - 'e': '\x1B', - ' ': '\x20', - '\"': '\"', - '\\': '\\', - '/': '/', - 'N': '\x85', - '_': '\xA0', - 'L': '\u2028', - 'P': '\u2029', + "0": "\0", + "a": "\x07", + "b": "\x08", + "t": "\x09", + "\t": "\x09", + "n": "\x0a", + "v": "\x0b", + "f": "\x0c", + "r": "\x0d", + "e": "\x1b", + " ": "\x20", + '"': '"', + "\\": "\\", + "/": "/", + "N": "\x85", + "_": "\xa0", + "L": "\u2028", + "P": "\u2029", } ESCAPE_CODES = { - 'x': 2, - 'u': 4, - 'U': 8, + "x": 2, + "u": 4, + "U": 8, } def scan_flow_scalar_non_spaces(self, double, start_mark): @@ -1187,19 +1260,19 @@ def scan_flow_scalar_non_spaces(self, double, start_mark): chunks = [] while True: length = 0 - while self.peek(length) not in '\'\"\\\0 \t\r\n\x85\u2028\u2029': + while self.peek(length) not in "'\"\\\0 \t\r\n\x85\u2028\u2029": length += 1 if length: chunks.append(self.prefix(length)) self.forward(length) ch = self.peek() - if not double and ch == '\'' and self.peek(1) == '\'': - chunks.append('\'') + if not double and ch == "'" and self.peek(1) == "'": + chunks.append("'") self.forward(2) - elif (double and ch == '\'') or (not double and ch in '\"\\'): + elif (double and ch == "'") or (not double and ch in '"\\'): chunks.append(ch) self.forward() - elif double and ch == '\\': + elif double and ch == "\\": self.forward() ch = self.peek() if ch in self.ESCAPE_REPLACEMENTS: @@ -1209,19 +1282,27 @@ def scan_flow_scalar_non_spaces(self, double, start_mark): length = self.ESCAPE_CODES[ch] self.forward() for k in range(length): - if self.peek(k) not in '0123456789ABCDEFabcdef': - raise ScannerError("while scanning a double-quoted scalar", start_mark, - "expected escape sequence of %d hexdecimal numbers, but found %r" % - (length, self.peek(k)), self.get_mark()) + if self.peek(k) not in "0123456789ABCDEFabcdef": + raise ScannerError( + "while scanning a double-quoted scalar", + start_mark, + "expected escape sequence of %d hexdecimal numbers, but found %r" + % (length, self.peek(k)), + self.get_mark(), + ) code = int(self.prefix(length), 16) chunks.append(chr(code)) self.forward(length) - elif ch in '\r\n\x85\u2028\u2029': + elif ch in "\r\n\x85\u2028\u2029": self.scan_line_break() chunks.extend(self.scan_flow_scalar_breaks(double, start_mark)) else: - raise ScannerError("while scanning a double-quoted scalar", start_mark, - "found unknown escape character %r" % ch, self.get_mark()) + raise ScannerError( + "while scanning a double-quoted scalar", + start_mark, + "found unknown escape character %r" % ch, + self.get_mark(), + ) else: return chunks @@ -1229,21 +1310,25 @@ def scan_flow_scalar_spaces(self, double, 
start_mark): # See the specification for details. chunks = [] length = 0 - while self.peek(length) in ' \t': + while self.peek(length) in " \t": length += 1 whitespaces = self.prefix(length) self.forward(length) ch = self.peek() - if ch == '\0': - raise ScannerError("while scanning a quoted scalar", start_mark, - "found unexpected end of stream", self.get_mark()) - elif ch in '\r\n\x85\u2028\u2029': + if ch == "\0": + raise ScannerError( + "while scanning a quoted scalar", + start_mark, + "found unexpected end of stream", + self.get_mark(), + ) + elif ch in "\r\n\x85\u2028\u2029": line_break = self.scan_line_break() breaks = self.scan_flow_scalar_breaks(double, start_mark) - if line_break != '\n': + if line_break != "\n": chunks.append(line_break) elif not breaks: - chunks.append(' ') + chunks.append(" ") chunks.extend(breaks) else: chunks.append(whitespaces) @@ -1256,13 +1341,18 @@ def scan_flow_scalar_breaks(self, double, start_mark): # Instead of checking indentation, we check for document # separators. prefix = self.prefix(3) - if (prefix == '---' or prefix == '...') \ - and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': - raise ScannerError("while scanning a quoted scalar", start_mark, - "found unexpected document separator", self.get_mark()) - while self.peek() in ' \t': + if (prefix == "---" or prefix == "...") and self.peek( + 3 + ) in "\0 \t\r\n\x85\u2028\u2029": + raise ScannerError( + "while scanning a quoted scalar", + start_mark, + "found unexpected document separator", + self.get_mark(), + ) + while self.peek() in " \t": self.forward() - if self.peek() in '\r\n\x85\u2028\u2029': + if self.peek() in "\r\n\x85\u2028\u2029": chunks.append(self.scan_line_break()) else: return chunks @@ -1276,23 +1366,28 @@ def scan_plain(self): chunks = [] start_mark = self.get_mark() end_mark = start_mark - indent = self.indent+1 + indent = self.indent + 1 # We allow zero indentation for scalars, but then we need to check for # document separators at the beginning of the line. - #if indent == 0: + # if indent == 0: # indent = 1 spaces = [] while True: length = 0 - if self.peek() == '#': + if self.peek() == "#": break while True: ch = self.peek(length) - if ch in '\0 \t\r\n\x85\u2028\u2029' \ - or (ch == ':' and - self.peek(length+1) in '\0 \t\r\n\x85\u2028\u2029' - + (u',[]{}' if self.flow_level else u''))\ - or (self.flow_level and ch in ',?[]{}'): + if ( + ch in "\0 \t\r\n\x85\u2028\u2029" + or ( + ch == ":" + and self.peek(length + 1) + in "\0 \t\r\n\x85\u2028\u2029" + + (",[]{}" if self.flow_level else "") + ) + or (self.flow_level and ch in ",?[]{}") + ): break length += 1 if length == 0: @@ -1303,10 +1398,13 @@ def scan_plain(self): self.forward(length) end_mark = self.get_mark() spaces = self.scan_plain_spaces(indent, start_mark) - if not spaces or self.peek() == '#' \ - or (not self.flow_level and self.column < indent): + if ( + not spaces + or self.peek() == "#" + or (not self.flow_level and self.column < indent) + ): break - return ScalarToken(''.join(chunks), True, start_mark, end_mark) + return ScalarToken("".join(chunks), True, start_mark, end_mark) def scan_plain_spaces(self, indent, start_mark): # See the specification for details. @@ -1314,32 +1412,34 @@ def scan_plain_spaces(self, indent, start_mark): # We just forbid them completely. Do not use tabs in YAML! 
chunks = [] length = 0 - while self.peek(length) in ' ': + while self.peek(length) in " ": length += 1 whitespaces = self.prefix(length) self.forward(length) ch = self.peek() - if ch in '\r\n\x85\u2028\u2029': + if ch in "\r\n\x85\u2028\u2029": line_break = self.scan_line_break() self.allow_simple_key = True prefix = self.prefix(3) - if (prefix == '---' or prefix == '...') \ - and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + if (prefix == "---" or prefix == "...") and self.peek( + 3 + ) in "\0 \t\r\n\x85\u2028\u2029": return breaks = [] - while self.peek() in ' \r\n\x85\u2028\u2029': - if self.peek() == ' ': + while self.peek() in " \r\n\x85\u2028\u2029": + if self.peek() == " ": self.forward() else: breaks.append(self.scan_line_break()) prefix = self.prefix(3) - if (prefix == '---' or prefix == '...') \ - and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + if (prefix == "---" or prefix == "...") and self.peek( + 3 + ) in "\0 \t\r\n\x85\u2028\u2029": return - if line_break != '\n': + if line_break != "\n": chunks.append(line_break) elif not breaks: - chunks.append(' ') + chunks.append(" ") chunks.extend(breaks) elif whitespaces: chunks.append(whitespaces) @@ -1350,20 +1450,29 @@ def scan_tag_handle(self, name, start_mark): # For some strange reasons, the specification does not allow '_' in # tag handles. I have allowed it anyway. ch = self.peek() - if ch != '!': - raise ScannerError("while scanning a %s" % name, start_mark, - "expected '!', but found %r" % ch, self.get_mark()) + if ch != "!": + raise ScannerError( + "while scanning a %s" % name, + start_mark, + "expected '!', but found %r" % ch, + self.get_mark(), + ) length = 1 ch = self.peek(length) - if ch != ' ': - while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ - or ch in '-_': + if ch != " ": + while ( + "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_" + ): length += 1 ch = self.peek(length) - if ch != '!': + if ch != "!": self.forward(length) - raise ScannerError("while scanning a %s" % name, start_mark, - "expected '!', but found %r" % ch, self.get_mark()) + raise ScannerError( + "while scanning a %s" % name, + start_mark, + "expected '!', but found %r" % ch, + self.get_mark(), + ) length += 1 value = self.prefix(length) self.forward(length) @@ -1375,9 +1484,13 @@ def scan_tag_uri(self, name, start_mark): chunks = [] length = 0 ch = self.peek(length) - while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ - or ch in '-;/?:@&=+$,_.!~*\'()[]%': - if ch == '%': + while ( + "0" <= ch <= "9" + or "A" <= ch <= "Z" + or "a" <= ch <= "z" + or ch in "-;/?:@&=+$,_.!~*'()[]%" + ): + if ch == "%": chunks.append(self.prefix(length)) self.forward(length) length = 0 @@ -1390,25 +1503,33 @@ def scan_tag_uri(self, name, start_mark): self.forward(length) length = 0 if not chunks: - raise ScannerError("while parsing a %s" % name, start_mark, - "expected URI, but found %r" % ch, self.get_mark()) - return ''.join(chunks) + raise ScannerError( + "while parsing a %s" % name, + start_mark, + "expected URI, but found %r" % ch, + self.get_mark(), + ) + return "".join(chunks) def scan_uri_escapes(self, name, start_mark): # See the specification for details. 
codes = [] mark = self.get_mark() - while self.peek() == '%': + while self.peek() == "%": self.forward() for k in range(2): - if self.peek(k) not in '0123456789ABCDEFabcdef': - raise ScannerError("while scanning a %s" % name, start_mark, - "expected URI escape sequence of 2 hexdecimal numbers, but found %r" - % self.peek(k), self.get_mark()) + if self.peek(k) not in "0123456789ABCDEFabcdef": + raise ScannerError( + "while scanning a %s" % name, + start_mark, + "expected URI escape sequence of 2 hexdecimal numbers, but found %r" + % self.peek(k), + self.get_mark(), + ) codes.append(int(self.prefix(2), 16)) self.forward(2) try: - value = bytes(codes).decode('utf-8') + value = bytes(codes).decode("utf-8") except UnicodeDecodeError as exc: raise ScannerError("while scanning a %s" % name, start_mark, str(exc), mark) return value @@ -1423,13 +1544,13 @@ def scan_line_break(self): # '\u2029 : '\u2029' # default : '' ch = self.peek() - if ch in '\r\n\x85': - if self.prefix(2) == '\r\n': + if ch in "\r\n\x85": + if self.prefix(2) == "\r\n": self.forward(2) else: self.forward() - return '\n' - elif ch in '\u2028\u2029': + return "\n" + elif ch in "\u2028\u2029": self.forward() return ch - return '' + return "" diff --git a/metaflow/_vendor/yaml/serializer.py b/metaflow/_vendor/yaml/serializer.py index fe911e67ae7..92f9221f807 100644 --- a/metaflow/_vendor/yaml/serializer.py +++ b/metaflow/_vendor/yaml/serializer.py @@ -1,19 +1,26 @@ - -__all__ = ['Serializer', 'SerializerError'] +__all__ = ["Serializer", "SerializerError"] from .error import YAMLError from .events import * from .nodes import * + class SerializerError(YAMLError): pass + class Serializer: - ANCHOR_TEMPLATE = 'id%03d' + ANCHOR_TEMPLATE = "id%03d" - def __init__(self, encoding=None, - explicit_start=None, explicit_end=None, version=None, tags=None): + def __init__( + self, + encoding=None, + explicit_start=None, + explicit_end=None, + version=None, + tags=None, + ): self.use_encoding = encoding self.use_explicit_start = explicit_start self.use_explicit_end = explicit_end @@ -40,7 +47,7 @@ def close(self): self.emit(StreamEndEvent()) self.closed = True - #def __del__(self): + # def __del__(self): # self.close() def serialize(self, node): @@ -48,8 +55,13 @@ def serialize(self, node): raise SerializerError("serializer is not opened") elif self.closed: raise SerializerError("serializer is closed") - self.emit(DocumentStartEvent(explicit=self.use_explicit_start, - version=self.use_version, tags=self.use_tags)) + self.emit( + DocumentStartEvent( + explicit=self.use_explicit_start, + version=self.use_version, + tags=self.use_tags, + ) + ) self.anchor_node(node) self.serialize_node(node, None, None) self.emit(DocumentEndEvent(explicit=self.use_explicit_end)) @@ -86,26 +98,30 @@ def serialize_node(self, node, parent, index): detected_tag = self.resolve(ScalarNode, node.value, (True, False)) default_tag = self.resolve(ScalarNode, node.value, (False, True)) implicit = (node.tag == detected_tag), (node.tag == default_tag) - self.emit(ScalarEvent(alias, node.tag, implicit, node.value, - style=node.style)) + self.emit( + ScalarEvent(alias, node.tag, implicit, node.value, style=node.style) + ) elif isinstance(node, SequenceNode): - implicit = (node.tag - == self.resolve(SequenceNode, node.value, True)) - self.emit(SequenceStartEvent(alias, node.tag, implicit, - flow_style=node.flow_style)) + implicit = node.tag == self.resolve(SequenceNode, node.value, True) + self.emit( + SequenceStartEvent( + alias, node.tag, implicit, flow_style=node.flow_style + ) 
+ ) index = 0 for item in node.value: self.serialize_node(item, node, index) index += 1 self.emit(SequenceEndEvent()) elif isinstance(node, MappingNode): - implicit = (node.tag - == self.resolve(MappingNode, node.value, True)) - self.emit(MappingStartEvent(alias, node.tag, implicit, - flow_style=node.flow_style)) + implicit = node.tag == self.resolve(MappingNode, node.value, True) + self.emit( + MappingStartEvent( + alias, node.tag, implicit, flow_style=node.flow_style + ) + ) for key, value in node.value: self.serialize_node(key, node, None) self.serialize_node(value, node, key) self.emit(MappingEndEvent()) self.ascend_resolver() - diff --git a/metaflow/_vendor/yaml/tokens.py b/metaflow/_vendor/yaml/tokens.py index 4d0b48a394a..235ab49d66c 100644 --- a/metaflow/_vendor/yaml/tokens.py +++ b/metaflow/_vendor/yaml/tokens.py @@ -1,104 +1,129 @@ - class Token(object): def __init__(self, start_mark, end_mark): self.start_mark = start_mark self.end_mark = end_mark + def __repr__(self): - attributes = [key for key in self.__dict__ - if not key.endswith('_mark')] + attributes = [key for key in self.__dict__ if not key.endswith("_mark")] attributes.sort() - arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) - for key in attributes]) - return '%s(%s)' % (self.__class__.__name__, arguments) + arguments = ", ".join( + ["%s=%r" % (key, getattr(self, key)) for key in attributes] + ) + return "%s(%s)" % (self.__class__.__name__, arguments) -#class BOMToken(Token): + +# class BOMToken(Token): # id = '' + class DirectiveToken(Token): - id = '' + id = "" + def __init__(self, name, value, start_mark, end_mark): self.name = name self.value = value self.start_mark = start_mark self.end_mark = end_mark + class DocumentStartToken(Token): - id = '' + id = "" + class DocumentEndToken(Token): - id = '' + id = "" + class StreamStartToken(Token): - id = '' - def __init__(self, start_mark=None, end_mark=None, - encoding=None): + id = "" + + def __init__(self, start_mark=None, end_mark=None, encoding=None): self.start_mark = start_mark self.end_mark = end_mark self.encoding = encoding + class StreamEndToken(Token): - id = '' + id = "" + class BlockSequenceStartToken(Token): - id = '' + id = "" + class BlockMappingStartToken(Token): - id = '' + id = "" + class BlockEndToken(Token): - id = '' + id = "" + class FlowSequenceStartToken(Token): - id = '[' + id = "[" + class FlowMappingStartToken(Token): - id = '{' + id = "{" + class FlowSequenceEndToken(Token): - id = ']' + id = "]" + class FlowMappingEndToken(Token): - id = '}' + id = "}" + class KeyToken(Token): - id = '?' + id = "?" 
+ class ValueToken(Token): - id = ':' + id = ":" + class BlockEntryToken(Token): - id = '-' + id = "-" + class FlowEntryToken(Token): - id = ',' + id = "," + class AliasToken(Token): - id = '' + id = "" + def __init__(self, value, start_mark, end_mark): self.value = value self.start_mark = start_mark self.end_mark = end_mark + class AnchorToken(Token): - id = '' + id = "" + def __init__(self, value, start_mark, end_mark): self.value = value self.start_mark = start_mark self.end_mark = end_mark + class TagToken(Token): - id = '' + id = "" + def __init__(self, value, start_mark, end_mark): self.value = value self.start_mark = start_mark self.end_mark = end_mark + class ScalarToken(Token): - id = '' + id = "" + def __init__(self, value, plain, start_mark, end_mark, style=None): self.value = value self.plain = plain self.start_mark = start_mark self.end_mark = end_mark self.style = style - diff --git a/metaflow/_vendor/zipp.py b/metaflow/_vendor/zipp.py index 26b723c1fd3..72632b0b773 100644 --- a/metaflow/_vendor/zipp.py +++ b/metaflow/_vendor/zipp.py @@ -12,7 +12,7 @@ OrderedDict = dict -__all__ = ['Path'] +__all__ = ["Path"] def _parents(path): @@ -93,7 +93,7 @@ def resolve_dir(self, name): as a directory (with the trailing slash). """ names = self._name_set() - dirname = name + '/' + dirname = name + "/" dir_match = name not in names and dirname in names return dirname if dir_match else name @@ -110,7 +110,7 @@ def make(cls, source): return cls(_pathlib_compat(source)) # Only allow for FastLookup when supplied zipfile is read-only - if 'r' not in source.mode: + if "r" not in source.mode: cls = CompleteDirs source.__class__ = cls @@ -240,7 +240,7 @@ def __init__(self, root, at=""): self.root = FastLookup.make(root) self.at = at - def open(self, mode='r', *args, pwd=None, **kwargs): + def open(self, mode="r", *args, pwd=None, **kwargs): """ Open this entry as text or binary following the semantics of ``pathlib.Path.open()`` by passing arguments through @@ -249,10 +249,10 @@ def open(self, mode='r', *args, pwd=None, **kwargs): if self.is_dir(): raise IsADirectoryError(self) zip_mode = mode[0] - if not self.exists() and zip_mode == 'r': + if not self.exists() and zip_mode == "r": raise FileNotFoundError(self) stream = self.root.open(self.at, zip_mode, pwd=pwd) - if 'b' in mode: + if "b" in mode: if args or kwargs: raise ValueError("encoding args invalid for binary operation") return stream @@ -279,11 +279,11 @@ def filename(self): return pathlib.Path(self.root.filename).joinpath(self.at) def read_text(self, *args, **kwargs): - with self.open('r', *args, **kwargs) as strm: + with self.open("r", *args, **kwargs) as strm: return strm.read() def read_bytes(self): - with self.open('rb') as strm: + with self.open("rb") as strm: return strm.read() def _is_child(self, path): @@ -323,7 +323,7 @@ def joinpath(self, *other): def parent(self): if not self.at: return self.filename.parent - parent_at = posixpath.dirname(self.at.rstrip('/')) + parent_at = posixpath.dirname(self.at.rstrip("/")) if parent_at: - parent_at += '/' + parent_at += "/" return self._next(parent_at) diff --git a/metaflow/cli.py b/metaflow/cli.py index cb9a0bc1ac9..9f292a44fcb 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -25,6 +25,7 @@ DEFAULT_METADATA, DEFAULT_MONITOR, DEFAULT_PACKAGE_SUFFIXES, + DISABLE_LOGGING, ) from .metaflow_current import current from .metaflow_environment import MetaflowEnvironment @@ -56,6 +57,8 @@ def echo_dev_null(*args, **kwargs): def echo_always(line, **kwargs): + if DISABLE_LOGGING: + return if 
kwargs.pop("wrap", False): import textwrap @@ -105,6 +108,8 @@ def echo_always(line, **kwargs): def logger(body="", system_msg=False, head="", bad=False, timestamp=True, nl=True): + if DISABLE_LOGGING: + return if timestamp: if timestamp is True: dt = datetime.now() diff --git a/metaflow/cmd/util.py b/metaflow/cmd/util.py index ceff9869c25..f95796fa8aa 100644 --- a/metaflow/cmd/util.py +++ b/metaflow/cmd/util.py @@ -20,4 +20,8 @@ def echo_dev_null(*args, **kwargs): def echo_always(line, **kwargs): + from metaflow.metaflow_config import DISABLE_LOGGING + + if DISABLE_LOGGING: + return click.secho(line, **kwargs) diff --git a/metaflow/constants.py b/metaflow/constants.py new file mode 100644 index 00000000000..54b762596cb --- /dev/null +++ b/metaflow/constants.py @@ -0,0 +1,4 @@ +# Constants used across Metaflow for configuration and local storage + +DATASTORE_LOCAL_DIR = ".metaflow" +LOCAL_CONFIG_FILE = "config.json" diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 8f042887bd8..f2ec5265f25 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -1,10 +1,14 @@ import os import sys import types +from .constants import DATASTORE_LOCAL_DIR, LOCAL_CONFIG_FILE from metaflow.exception import MetaflowException from metaflow.metaflow_config_funcs import from_conf, get_validate_choice_fn +# Option to disable all Metaflow internal logging/status output +DISABLE_LOGGING = from_conf("DISABLE_LOGGING", False) + # Disable multithreading security on MacOS if sys.platform == "darwin": os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES" diff --git a/metaflow/metaflow_config_funcs.py b/metaflow/metaflow_config_funcs.py index 863d6f1bc6d..f0f1e5ea63a 100644 --- a/metaflow/metaflow_config_funcs.py +++ b/metaflow/metaflow_config_funcs.py @@ -40,7 +40,7 @@ def init_local_config(): # check in DATASTORE_SYSROOT_LOCAL but only up the current getcwd() path. This also # prevents nasty circular dependencies :) - from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, LOCAL_CONFIG_FILE + from .constants import DATASTORE_LOCAL_DIR, LOCAL_CONFIG_FILE current_path = os.getcwd() check_dir = os.path.join(current_path, DATASTORE_LOCAL_DIR) diff --git a/metaflow/metaflow_git.py b/metaflow/metaflow_git.py index b90a2edb9b6..1ee1880a4a6 100644 --- a/metaflow/metaflow_git.py +++ b/metaflow/metaflow_git.py @@ -1,7 +1,7 @@ #!/usr/bin/env python """Get git repository information for the package -Functions to retrieve git repository details like URL, branch name, +Functions to retrieve git repository details like URL, branch name, and commit SHA for Metaflow code provenance tracking. 
""" diff --git a/metaflow/plugins/cards/card_datastore.py b/metaflow/plugins/cards/card_datastore.py index f70f608c372..79a6a31c381 100644 --- a/metaflow/plugins/cards/card_datastore.py +++ b/metaflow/plugins/cards/card_datastore.py @@ -1,6 +1,4 @@ -""" - -""" +""" """ from collections import namedtuple from io import BytesIO diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index da0f0fc3130..31ea40f2065 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -533,7 +533,7 @@ def environment_variable_from_selector(self, name, label_value): return self self._kwargs["environment_variables_from_selectors"] = dict( self._kwargs.get("environment_variables_from_selectors", {}), - **{name: label_value} + **{name: label_value}, ) return self diff --git a/metaflow/plugins/logs_cli.py b/metaflow/plugins/logs_cli.py index b5314bf40b6..06e68b4bb85 100644 --- a/metaflow/plugins/logs_cli.py +++ b/metaflow/plugins/logs_cli.py @@ -173,7 +173,14 @@ def show( if ds_list: + try: + from metaflow.metaflow_config import DISABLE_LOGGING + except Exception: + DISABLE_LOGGING = False + def echo_unicode(line, **kwargs): + if DISABLE_LOGGING: + return click.secho(line.decode("UTF-8", errors="replace"), **kwargs) # old style logs are non mflog-style logs diff --git a/metaflow/system/system_utils.py b/metaflow/system/system_utils.py index 0b6acd58b57..825c0b514bf 100644 --- a/metaflow/system/system_utils.py +++ b/metaflow/system/system_utils.py @@ -8,7 +8,7 @@ def __init__(self, name="not_a_real_flow"): # This function is used to initialize the environment outside a flow. def init_environment_outside_flow( - flow: Union["metaflow.flowspec.FlowSpec", "metaflow.sidecar.DummyFlow"] + flow: Union["metaflow.flowspec.FlowSpec", "metaflow.sidecar.DummyFlow"], ) -> "metaflow.metaflow_environment.MetaflowEnvironment": from metaflow.plugins import ENVIRONMENTS from metaflow.metaflow_config import DEFAULT_ENVIRONMENT