From 6d4df112ae547a3d3bdcf85b0ffa88065b051df3 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sat, 23 Feb 2019 17:31:37 +0100 Subject: [PATCH 01/65] Makefile: Allow setting PREFIX from the command line --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index e88a025421..ae8737aea2 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -PREFIX = ~/.micropython/lib +PREFIX ?= ~/.micropython/lib all: From 90105faf786fd4744bee27d07bd9b5d29f8e979b Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sat, 23 Feb 2019 17:32:18 +0100 Subject: [PATCH 02/65] Implement platform.python_implementation() --- platform/metadata.txt | 3 ++- platform/platform.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/platform/metadata.txt b/platform/metadata.txt index fda992a9c0..c1d68e096a 100644 --- a/platform/metadata.txt +++ b/platform/metadata.txt @@ -1,3 +1,4 @@ srctype=dummy type=module -version = 0.0.2 +version = 0.1.0 +desc = Dummy platform module for MicroPython diff --git a/platform/platform.py b/platform/platform.py index e69de29bb2..b74d6378f5 100644 --- a/platform/platform.py +++ b/platform/platform.py @@ -0,0 +1,3 @@ +# dummy +def python_implementation(): + return "micro" From a891c99da479549e75b9c97ece4c0e15232b1228 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sat, 23 Feb 2019 17:45:58 +0100 Subject: [PATCH 03/65] New module: collections.abc --- collections.abc/collections/abc.py | 5 +++++ collections.abc/metadata.txt | 3 +++ 2 files changed, 8 insertions(+) create mode 100644 collections.abc/collections/abc.py create mode 100644 collections.abc/metadata.txt diff --git a/collections.abc/collections/abc.py b/collections.abc/collections/abc.py new file mode 100644 index 0000000000..9ca6cacf93 --- /dev/null +++ b/collections.abc/collections/abc.py @@ -0,0 +1,5 @@ + +# this is so not-right it's not even wrong +Mapping = dict +MutableMapping = dict +Sequence = (tuple, list) # only useful for simple isinstance tests diff --git a/collections.abc/metadata.txt b/collections.abc/metadata.txt new file mode 100644 index 0000000000..a8ff4f11f6 --- /dev/null +++ b/collections.abc/metadata.txt @@ -0,0 +1,3 @@ +srctype = micropython-lib +type = package +version = 0.1.0 From 7bd5d3d17d87dbd31f8a9ad6cd96cba9212005b8 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 19:15:13 +0100 Subject: [PATCH 04/65] Basic "attrs" module Limitation: no positional arguments --- attrs/attr.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++ attrs/metadata.txt | 3 +++ 2 files changed, 69 insertions(+) create mode 100644 attrs/attr.py create mode 100644 attrs/metadata.txt diff --git a/attrs/attr.py b/attrs/attr.py new file mode 100644 index 0000000000..34d845b0c5 --- /dev/null +++ b/attrs/attr.py @@ -0,0 +1,66 @@ +# micropython doesn't have metaclasses or introspection +# so we need to do things annoyingly differently + +class attrib: + # single attribute + def __init__(self, **kw): + for k,v in kw.items(): + setattr(self,k,v) + +class validators: + @staticmethod + def instance_of(exc): + pass + +attr = attrib +ib = attrib + +class Factory: + def __init__(self, f): + self.f = f + +def _init(self, **kw): + seen = set() + for k,v in kw.items(): + uk = '_'+k + if k[0] != '_' and hasattr(self.__class__, uk): + seen.add(uk) + setattr(self,uk,v) + else: + setattr(self,k,v) + + for k in dir(self.__class__): + if len(k)>1 and k[0] == '_' and k[1] == '_': + continue + if k in kw or k in seen: + continue + v = getattr(self.__class__,k) + + 
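+        # The branch below turns a class-level attrib() declaration into an
+        # instance value: an attrib(factory=...) is called, otherwise its
+        # plain default is used, and a Factory wrapper is invoked lazily.
+        # A minimal usage sketch (``Point`` is an invented example, not part
+        # of this patch):
+        #
+        #   import attr
+        #
+        #   @attr.s
+        #   class Point:
+        #       x = attr.ib(default=0)
+        #       y = attr.ib(default=attr.Factory(list))
+        #
+        #   p = Point(x=3)   # p.x == 3, p.y == [] (a fresh list per instance)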
+        if isinstance(v,attrib):
+            if hasattr(v,'factory'):
+                v = v.factory()
+            else:
+                v = v.default
+            if isinstance(v,Factory):
+                v = v.f()
+            setattr(self,k,v)
+
+def attrs(cls=None, **kw):
+    def _attrs(cls):
+        f = []
+        for k in dir(cls):
+            v = getattr(cls,k)
+            if isinstance(v,attrib):
+                f.append((k,v))
+
+        cls._attrs = f
+        if kw.get('init', True):
+            cls.__init__ = _init
+        return cls
+
+    if cls is not None:
+        return _attrs(cls)
+    return _attrs
+
+attributes = attrs
+s = attrs
diff --git a/attrs/metadata.txt b/attrs/metadata.txt
new file mode 100644
index 0000000000..e26221407a
--- /dev/null
+++ b/attrs/metadata.txt
@@ -0,0 +1,3 @@
+srctype = micropython-lib
+type = module
+version = 0.1.0

From 641fd9fa172924c7c7129bba50608932592f56e3 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 24 Feb 2019 19:17:25 +0100
Subject: [PATCH 05/65] add outcome

---
 outcome/metadata.txt        |   3 +
 outcome/outcome/__init__.py |  20 ++++++
 outcome/outcome/_async.py   |  65 +++++++++++++++
 outcome/outcome/_sync.py    | 122 ++++++++++++++++++++++++++++++++++++
 outcome/outcome/_util.py    |  27 ++++++++
 outcome/outcome/_version.py |   4 ++
 6 files changed, 241 insertions(+)
 create mode 100644 outcome/metadata.txt
 create mode 100644 outcome/outcome/__init__.py
 create mode 100644 outcome/outcome/_async.py
 create mode 100644 outcome/outcome/_sync.py
 create mode 100644 outcome/outcome/_util.py
 create mode 100644 outcome/outcome/_version.py

diff --git a/outcome/metadata.txt b/outcome/metadata.txt
new file mode 100644
index 0000000000..a8ff4f11f6
--- /dev/null
+++ b/outcome/metadata.txt
@@ -0,0 +1,3 @@
+srctype = micropython-lib
+type = package
+version = 0.1.0
diff --git a/outcome/outcome/__init__.py b/outcome/outcome/__init__.py
new file mode 100644
index 0000000000..6f93754757
--- /dev/null
+++ b/outcome/outcome/__init__.py
@@ -0,0 +1,20 @@
+# coding: utf-8
+"""Top-level package for outcome."""
+from __future__ import absolute_import, division, print_function
+
+import sys
+
+from ._util import AlreadyUsedError, fixup_module_metadata
+from ._version import __version__
+
+if sys.version_info >= (3, 5):
+    from ._async import Error, Outcome, Value, acapture, capture
+    __all__ = (
+        'Error', 'Outcome', 'Value', 'acapture', 'capture', 'AlreadyUsedError'
+    )
+else:
+    from ._sync import Error, Outcome, Value, capture
+    __all__ = ('Error', 'Outcome', 'Value', 'capture', 'AlreadyUsedError')
+
+fixup_module_metadata(__name__, globals())
+del fixup_module_metadata
diff --git a/outcome/outcome/_async.py b/outcome/outcome/_async.py
new file mode 100644
index 0000000000..74e15586a5
--- /dev/null
+++ b/outcome/outcome/_async.py
@@ -0,0 +1,65 @@
+import abc
+
+from ._sync import Error as ErrorBase
+from ._sync import Outcome as OutcomeBase
+from ._sync import Value as ValueBase
+
+__all__ = ['Error', 'Outcome', 'Value', 'acapture', 'capture']
+
+
+def capture(sync_fn, *args, **kwargs):
+    """Run ``sync_fn(*args, **kwargs)`` and capture the result.
+
+    Returns:
+      Either a :class:`Value` or :class:`Error` as appropriate.
+
+    """
+    # Same logic as _sync.capture, but wrapping the result in this module's
+    # Value/Error subclasses (which also support asend()).
+    try:
+        return Value(sync_fn(*args, **kwargs))
+    except BaseException as exc:
+        return Error(exc)
+
+
+async def acapture(async_fn, *args, **kwargs):
+    """Run ``await async_fn(*args, **kwargs)`` and capture the result.
+
+    Returns:
+      Either a :class:`Value` or :class:`Error` as appropriate.
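+
+    For example (an illustrative sketch; ``fail`` is an invented name)::
+
+        async def fail():
+            raise ValueError('boom')
+
+        result = await acapture(fail)   # -> Error(ValueError('boom'))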
+
+    """
+    try:
+        return Value(await async_fn(*args, **kwargs))
+    except BaseException as exc:
+        return Error(exc)
+
+
+class Outcome(OutcomeBase):
+    @abc.abstractmethod
+    async def asend(self, agen):
+        """Send or throw the contained value or exception into the given async
+        generator object.
+
+        Args:
+          agen: An async generator object supporting ``.asend()`` and
+              ``.athrow()`` methods.
+
+        """
+
+
+class Value(Outcome, ValueBase):
+    async def asend(self, agen):
+        self._set_unwrapped()
+        return await agen.asend(self.value)
+
+
+class Error(Outcome, ErrorBase):
+    async def asend(self, agen):
+        self._set_unwrapped()
+        return await agen.athrow(self.error)
+
+
+# We don't need this for Sphinx, but do it anyway for IPython, IDEs, etc
+Outcome.__doc__ = OutcomeBase.__doc__
+Value.__doc__ = ValueBase.__doc__
+Error.__doc__ = ErrorBase.__doc__
diff --git a/outcome/outcome/_sync.py b/outcome/outcome/_sync.py
new file mode 100644
index 0000000000..f69c4d0eaa
--- /dev/null
+++ b/outcome/outcome/_sync.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+from __future__ import absolute_import, division, print_function
+
+import abc
+
+import attr
+
+from ._util import AlreadyUsedError
+
+__all__ = ['Error', 'Outcome', 'Value', 'capture']
+
+
+def capture(sync_fn, *args, **kwargs):
+    """Run ``sync_fn(*args, **kwargs)`` and capture the result.
+
+    Returns:
+      Either a :class:`Value` or :class:`Error` as appropriate.
+
+    """
+    try:
+        return Value(sync_fn(*args, **kwargs))
+    except BaseException as exc:
+        return Error(exc)
+
+
+@attr.s(repr=False, slots=True, init=False)
+class Outcome:
+    """An abstract class representing the result of a Python computation.
+
+    This class has two concrete subclasses: :class:`Value` representing a
+    value, and :class:`Error` representing an exception.
+
+    In addition to the methods described below, comparison operators on
+    :class:`Value` and :class:`Error` objects (``==``, ``<``, etc.) check that
+    the other object is also a :class:`Value` or :class:`Error` object
+    respectively, and then compare the contained objects.
+
+    :class:`Outcome` objects are hashable if the contained objects are
+    hashable.
+
+    """
+    _unwrapped = None
+
+    def __init__(self):
+        self._unwrapped = False
+
+    def _set_unwrapped(self):
+        if self._unwrapped:
+            raise AlreadyUsedError
+        self._unwrapped = True
+
+    @abc.abstractmethod
+    def unwrap(self):
+        """Return or raise the contained value or exception.
+
+        These two lines of code are equivalent::
+
+           x = fn(*args)
+           x = outcome.capture(fn, *args).unwrap()
+
+        """
+
+    @abc.abstractmethod
+    def send(self, gen):
+        """Send or throw the contained value or exception into the given
+        generator object.
+
+        Args:
+          gen: A generator object supporting ``.send()`` and ``.throw()``
+              methods.
+
+        """
+
+
+@attr.s(frozen=True, repr=False, slots=True, init=False)
+class Value(Outcome):
+    """Concrete :class:`Outcome` subclass representing a regular value.
+
+    """
+
+    value = None
+
+    def __init__(self, value):
+        super().__init__()
+        self.value = value
+
+    def __repr__(self):
+        return 'Value({!r})'.format(self.value)
+
+    def unwrap(self):
+        self._set_unwrapped()
+        return self.value
+
+    def send(self, gen):
+        self._set_unwrapped()
+        return gen.send(self.value)
+
+
+@attr.s(frozen=True, repr=False, slots=True, init=False)
+class Error(Outcome):
+    """Concrete :class:`Outcome` subclass representing a raised exception.
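+
+    For example (illustrative)::
+
+        err = capture(int, 'not a number')   # -> Error(ValueError(...))
+        err.unwrap()                         # re-raises that ValueError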
+ + """ + + error = None + def __init__(self, error): + super().__init__() + self.error = error + + def __repr__(self): + return 'Error({!r})'.format(self.error) + + def unwrap(self): + self._set_unwrapped() + # Tracebacks show the 'raise' line below out of context, so let's give + # this variable a name that makes sense out of context. + captured_error = self.error + raise captured_error + + def send(self, it): + self._set_unwrapped() + return it.throw(self.error) diff --git a/outcome/outcome/_util.py b/outcome/outcome/_util.py new file mode 100644 index 0000000000..91b1d6c091 --- /dev/null +++ b/outcome/outcome/_util.py @@ -0,0 +1,27 @@ +# coding: utf-8 +from __future__ import absolute_import, division, print_function + +import abc +import sys + + +class AlreadyUsedError(RuntimeError): + """An Outcome can only be unwrapped once.""" + pass + + +def fixup_module_metadata(module_name, namespace): + def fix_one(obj): + mod = getattr(obj, "__module__", None) + if mod is not None and mod.startswith("outcome."): + obj.__module__ = module_name + if isinstance(obj, type): + for k in dir(obj): + attr_value = getattr(obj, k) + fix_one(attr_value) + + for objname in namespace["__all__"]: + obj = namespace[objname] + fix_one(obj) + + diff --git a/outcome/outcome/_version.py b/outcome/outcome/_version.py new file mode 100644 index 0000000000..1121db00bf --- /dev/null +++ b/outcome/outcome/_version.py @@ -0,0 +1,4 @@ +# coding: utf-8 +# This file is imported from __init__.py and exec'd from setup.py + +__version__ = "1.0.0" From 14a9f9005787658556023d6df19e1ba249e827b7 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 19:20:03 +0100 Subject: [PATCH 06/65] async_generator --- async_generator/async_generator.py | 481 +++++++++++++++++++++++++++++ async_generator/metadata.txt | 3 + 2 files changed, 484 insertions(+) create mode 100644 async_generator/async_generator.py create mode 100644 async_generator/metadata.txt diff --git a/async_generator/async_generator.py b/async_generator/async_generator.py new file mode 100644 index 0000000000..2c7a344eb0 --- /dev/null +++ b/async_generator/async_generator.py @@ -0,0 +1,481 @@ + +import sys +from functools import wraps + + +class aclosing: + def __init__(self, aiter): + self._aiter = aiter + + async def __aenter__(self): + return self._aiter + + async def __aexit__(self, *args): + await self._aiter.aclose() + + +import inspect +import collections.abc + + +class YieldWrapper: + def __init__(self, payload): + self.payload = payload + + +def _wrap(value): + return YieldWrapper(value) + + +def _is_wrapped(box): + return isinstance(box, YieldWrapper) + + +def _unwrap(box): + return box.payload + + +# This is the magic code that lets you use yield_ and yield_from_ with native +# generators. +# +# The old version worked great on Linux and MacOS, but not on Windows, because +# it depended on _PyAsyncGenValueWrapperNew. The new version segfaults +# everywhere, and I'm not sure why -- probably my lack of understanding +# of ctypes and refcounts. +# +# There are also some commented out tests that should be re-enabled if this is +# fixed: +# +# if sys.version_info >= (3, 6): +# # Use the same box type that the interpreter uses internally. This allows +# # yield_ and (more importantly!) yield_from_ to work in built-in +# # generators. +# import ctypes # mua ha ha. 
+# +# # We used to call _PyAsyncGenValueWrapperNew to create and set up new +# # wrapper objects, but that symbol isn't available on Windows: +# # +# # https://github.com/python-trio/async_generator/issues/5 +# # +# # Fortunately, the type object is available, but it means we have to do +# # this the hard way. +# +# # We don't actually need to access this, but we need to make a ctypes +# # structure so we can call addressof. +# class _ctypes_PyTypeObject(ctypes.Structure): +# pass +# _PyAsyncGenWrappedValue_Type_ptr = ctypes.addressof( +# _ctypes_PyTypeObject.in_dll( +# ctypes.pythonapi, "_PyAsyncGenWrappedValue_Type")) +# _PyObject_GC_New = ctypes.pythonapi._PyObject_GC_New +# _PyObject_GC_New.restype = ctypes.py_object +# _PyObject_GC_New.argtypes = (ctypes.c_void_p,) +# +# _Py_IncRef = ctypes.pythonapi.Py_IncRef +# _Py_IncRef.restype = None +# _Py_IncRef.argtypes = (ctypes.py_object,) +# +# class _ctypes_PyAsyncGenWrappedValue(ctypes.Structure): +# _fields_ = [ +# ('PyObject_HEAD', ctypes.c_byte * object().__sizeof__()), +# ('agw_val', ctypes.py_object), +# ] +# def _wrap(value): +# box = _PyObject_GC_New(_PyAsyncGenWrappedValue_Type_ptr) +# raw = ctypes.cast(ctypes.c_void_p(id(box)), +# ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue)) +# raw.contents.agw_val = value +# _Py_IncRef(value) +# return box +# +# def _unwrap(box): +# assert _is_wrapped(box) +# raw = ctypes.cast(ctypes.c_void_p(id(box)), +# ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue)) +# value = raw.contents.agw_val +# _Py_IncRef(value) +# return value +# +# _PyAsyncGenWrappedValue_Type = type(_wrap(1)) +# def _is_wrapped(box): +# return isinstance(box, _PyAsyncGenWrappedValue_Type) + + +# The magic @coroutine decorator is how you write the bottom level of +# coroutine stacks -- 'async def' can only use 'await' = yield from; but +# eventually we must bottom out in a @coroutine that calls plain 'yield'. +#@coroutine +def _yield_(value): + return (yield _wrap(value)) + + +# But we wrap the bare @coroutine version in an async def, because async def +# has the magic feature that users can get warnings messages if they forget to +# use 'await'. +async def yield_(value=None): + return await _yield_(value) + + +async def yield_from_(delegate): + # Transcribed with adaptations from: + # + # https://www.python.org/dev/peps/pep-0380/#formal-semantics + # + # This takes advantage of a sneaky trick: if an @async_generator-wrapped + # function calls another async function (like yield_from_), and that + # second async function calls yield_, then because of the hack we use to + # implement yield_, the yield_ will actually propagate through yield_from_ + # back to the @async_generator wrapper. So even though we're a regular + # function, we can directly yield values out of the calling async + # generator. 
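+    #
+    # An illustrative sketch of delegation (both functions are invented
+    # here, using the @async_generator decorator defined below):
+    #
+    #   @async_generator
+    #   async def inner():
+    #       await yield_(1)
+    #       await yield_(2)
+    #
+    #   @async_generator
+    #   async def outer():
+    #       await yield_(0)
+    #       await yield_from_(inner())   # outer() yields 0, 1, 2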
+    def unpack_StopAsyncIteration(e):
+        if e.args:
+            return e.args[0]
+        else:
+            return None
+
+    _i = type(delegate).__aiter__(delegate)
+    if hasattr(_i, "__await__"):
+        _i = await _i
+    try:
+        _y = await type(_i).__anext__(_i)
+    except StopAsyncIteration as _e:
+        _r = unpack_StopAsyncIteration(_e)
+    else:
+        while 1:
+            try:
+                _s = await yield_(_y)
+            except GeneratorExit as _e:
+                try:
+                    _m = _i.aclose
+                except AttributeError:
+                    pass
+                else:
+                    await _m()
+                raise _e
+            except BaseException as _e:
+                _x = sys.exc_info()
+                try:
+                    _m = _i.athrow
+                except AttributeError:
+                    raise _e
+                else:
+                    try:
+                        _y = await _m(*_x)
+                    except StopAsyncIteration as _e:
+                        _r = unpack_StopAsyncIteration(_e)
+                        break
+            else:
+                try:
+                    if _s is None:
+                        _y = await type(_i).__anext__(_i)
+                    else:
+                        _y = await _i.asend(_s)
+                except StopAsyncIteration as _e:
+                    _r = unpack_StopAsyncIteration(_e)
+                    break
+    return _r
+
+
+# This is the awaitable / iterator that implements asynciter.__anext__() and
+# friends.
+#
+# Note: we can be sloppy about the distinction between
+#
+#   type(self._it).__next__(self._it)
+#
+# and
+#
+#   self._it.__next__()
+#
+# because we happen to know that self._it is not a general iterator object,
+# but specifically a coroutine iterator object where these are equivalent.
+class ANextIter:
+    def __init__(self, it, first_fn, *first_args):
+        self._it = it
+        self._first_fn = first_fn
+        self._first_args = first_args
+
+    def __await__(self):
+        return self
+
+    def __next__(self):
+        if self._first_fn is not None:
+            first_fn = self._first_fn
+            first_args = self._first_args
+            self._first_fn = self._first_args = None
+            return self._invoke(first_fn, *first_args)
+        else:
+            return self._invoke(self._it.__next__)
+
+    def send(self, value):
+        return self._invoke(self._it.send, value)
+
+    def throw(self, type, value=None, traceback=None):
+        return self._invoke(self._it.throw, type, value, traceback)
+
+    def _invoke(self, fn, *args):
+        try:
+            result = fn(*args)
+        except StopIteration as e:
+            # The underlying generator returned, so we should signal the end
+            # of iteration.
+            raise StopAsyncIteration(e.value)
+        except StopAsyncIteration as e:
+            # PEP 479 says: if a generator raises Stop(Async)Iteration, then
+            # it should be wrapped into a RuntimeError. Python automatically
+            # enforces this for StopIteration; for StopAsyncIteration we need
+            # to do it ourselves.
+            raise RuntimeError(
+                "async_generator raised StopAsyncIteration"
+            ) from e
+        if _is_wrapped(result):
+            raise StopIteration(_unwrap(result))
+        else:
+            return result
+
+
+# Minimal stand-ins for the asyncgen-hooks machinery referenced by
+# AsyncGenerator._do_it() and listed in __all__ at the bottom of this file.
+# MicroPython has no sys.get_asyncgen_hooks(), so we keep a module-level
+# (firstiter, finalizer) pair ourselves.  Note this is simpler than CPython's
+# API: hooks not passed to set_asyncgen_hooks() are reset to None.
+_hooks = (None, None)
+
+
+def get_asyncgen_hooks():
+    return _hooks
+
+
+def set_asyncgen_hooks(firstiter=None, finalizer=None):
+    global _hooks
+    _hooks = (firstiter, finalizer)
+
+
+class AsyncGenerator:
+    def __init__(self, coroutine):
+        self._coroutine = coroutine
+        self._it = coroutine.__await__()
+        self.ag_running = False
+        self._finalizer = None
+        self._closed = False
+        self._hooks_inited = False
+
+    # Yecchh.
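+    # Python 3.5.0/3.5.1 shipped an early draft of PEP 492 in which
+    # __aiter__ was itself awaited, so it had to be an async function; from
+    # 3.5.2 on it must return the iterator directly.  Providing both
+    # spellings keeps ``async for value in agen: ...`` working on either
+    # interpreter.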
+    if sys.version_info < (3, 5, 2):
+
+        async def __aiter__(self):
+            return self
+    else:
+
+        def __aiter__(self):
+            return self
+
+    ################################################################
+    # Introspection attributes
+    ################################################################
+
+    @property
+    def ag_code(self):
+        return self._coroutine.cr_code
+
+    @property
+    def ag_frame(self):
+        return self._coroutine.cr_frame
+
+    ################################################################
+    # Core functionality
+    ################################################################
+
+    # These need to return awaitables, rather than being async functions,
+    # to match the native behavior where the firstiter hook is called
+    # immediately on asend()/etc, even if the coroutine that asend()
+    # produces isn't awaited for a bit.
+
+    def __anext__(self):
+        return self._do_it(self._it.__next__)
+
+    def asend(self, value):
+        return self._do_it(self._it.send, value)
+
+    def athrow(self, type, value=None, traceback=None):
+        return self._do_it(self._it.throw, type, value, traceback)
+
+    def _do_it(self, start_fn, *args):
+        if not self._hooks_inited:
+            self._hooks_inited = True
+            (firstiter, self._finalizer) = get_asyncgen_hooks()
+            if firstiter is not None:
+                firstiter(self)
+
+        async def step():
+            if self.ag_running:
+                raise ValueError("async generator already executing")
+            try:
+                self.ag_running = True
+                return await ANextIter(self._it, start_fn, *args)
+            except StopAsyncIteration:
+                raise
+            finally:
+                self.ag_running = False
+
+        return step()
+
+    ################################################################
+    # Cleanup
+    ################################################################
+
+    async def aclose(self):
+        # Make sure that even if we raise "async_generator ignored
+        # GeneratorExit", and thus fail to exhaust the coroutine,
+        # __del__ doesn't complain again.
+        if self._closed:
+            return
+        self._closed = True
+        # CPython's async_generator special-cases a never-started coroutine
+        # here (via inspect.getcoroutinestate), but MicroPython has no such
+        # introspection.  Throwing GeneratorExit into an unstarted coroutine
+        # closes it and raises GeneratorExit straight back, so the generic
+        # path below handles that case too.
+        try:
+            await self.athrow(GeneratorExit)
+        except (GeneratorExit, StopAsyncIteration):
+            pass
+        else:
+            raise RuntimeError("async_generator ignored GeneratorExit")
+
+    def __del__(self):
+        # Nothing to clean up if aclose() already ran.
+        if not self._closed:
+            if self._finalizer is not None:
+                self._finalizer(self)
+            else:
+                # Mimic the behavior of native generators on GC with no finalizer:
+                # throw in GeneratorExit, run for one turn, and complain if it didn't
+                # finish.
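+                # With the hook machinery defined above, an event loop would
+                # normally register a finalizer so cleanup runs there rather
+                # than in __del__; an illustrative sketch (``loop`` is an
+                # assumed event-loop object, not defined in this module):
+                #
+                #   def _finalize(agen):
+                #       loop.create_task(agen.aclose())
+                #   set_asyncgen_hooks(finalizer=_finalize)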
+ thrower = self.athrow(GeneratorExit) + try: + thrower.send(None) + except (GeneratorExit, StopAsyncIteration): + pass + except StopIteration: + raise RuntimeError("async_generator ignored GeneratorExit") + else: + raise RuntimeError( + "async_generator {!r} awaited during finalization; install " + "a finalization hook to support this, or wrap it in " + "'async with aclosing(...):'" + .format(self.ag_code.co_name) + ) + finally: + thrower.close() + + + +def async_generator(coroutine_maker): + @wraps(coroutine_maker) + def async_generator_maker(*args, **kwargs): + return AsyncGenerator(coroutine_maker(*args, **kwargs)) + + async_generator_maker._async_gen_function = id(async_generator_maker) + return async_generator_maker + + +def isasyncgen(obj): + if hasattr(inspect, "isasyncgen"): + if inspect.isasyncgen(obj): + return True + return isinstance(obj, AsyncGenerator) + + +def isasyncgenfunction(obj): + if hasattr(inspect, "isasyncgenfunction"): + if inspect.isasyncgenfunction(obj): + return True + return getattr(obj, "_async_gen_function", -1) == id(obj) +# Very much derived from the one in contextlib, by copy/pasting and then +# asyncifying everything. (Also I dropped the obscure support for using +# context managers as function decorators. It could be re-added; I just +# couldn't be bothered.) +# So this is a derivative work licensed under the PSF License, which requires +# the following notice: +# +# Copyright © 2001-2017 Python Software Foundation; All Rights Reserved +class _AsyncGeneratorContextManager: + def __init__(self, func, args, kwds): + self._func_name = func.__name__ + self._agen = func(*args, **kwds).__aiter__() + + async def __aenter__(self): + if sys.version_info < (3, 5, 2): + self._agen = await self._agen + try: + return await self._agen.asend(None) + except StopAsyncIteration: + raise RuntimeError("async generator didn't yield") from None + + async def __aexit__(self, type, value, traceback): + async with aclosing(self._agen): + if type is None: + try: + await self._agen.asend(None) + except StopAsyncIteration: + return False + else: + raise RuntimeError("async generator didn't stop") + else: + # It used to be possible to have type != None, value == None: + # https://bugs.python.org/issue1705170 + # but AFAICT this can't happen anymore. + assert value is not None + try: + await self._agen.athrow(type, value, traceback) + raise RuntimeError( + "async generator didn't stop after athrow()" + ) + except StopAsyncIteration as exc: + # Suppress StopIteration *unless* it's the same exception + # that was passed to throw(). This prevents a + # StopIteration raised inside the "with" statement from + # being suppressed. + return (exc is not value) + except RuntimeError as exc: + # Don't re-raise the passed in exception. (issue27112) + if exc is value: + return False + # Likewise, avoid suppressing if a StopIteration exception + # was passed to throw() and later wrapped into a + # RuntimeError (see PEP 479). + if (isinstance(value, (StopIteration, StopAsyncIteration)) + and exc.__cause__ is value): + return False + raise + except: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise an + # exception unless __exit__() itself failed. But throw() + # has to raise the exception to signal propagation, so + # this fixes the impedance mismatch between the throw() + # protocol and the __exit__() protocol. 
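+                # (Illustrative: if the body of the ``async with`` raises and
+                # the generator simply lets the exception propagate, the very
+                # same exception object arrives here and must not be
+                # swallowed.)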
+ # + if sys.exc_info()[1] is value: + return False + raise + + def __enter__(self): + raise RuntimeError( + "use 'async with {func_name}(...)', not 'with {func_name}(...)'". + format(func_name=self._func_name) + ) + + def __exit__(self): # pragma: no cover + assert False, """Never called, but should be defined""" + + +def asynccontextmanager(func): + """Like @contextmanager, but async.""" + if not isasyncgenfunction(func): + raise TypeError( + "must be an async generator (native or from async_generator; " + "if using @async_generator then @acontextmanager must be on top." + ) + + @wraps(func) + def helper(*args, **kwds): + return _AsyncGeneratorContextManager(func, args, kwds) + + # A hint for sphinxcontrib-trio: + helper.__returns_acontextmanager__ = True + return helper + +__all__ = [ + "async_generator", + "yield_", + "yield_from_", + "aclosing", + "isasyncgen", + "isasyncgenfunction", + "asynccontextmanager", + "get_asyncgen_hooks", + "set_asyncgen_hooks", +] diff --git a/async_generator/metadata.txt b/async_generator/metadata.txt new file mode 100644 index 0000000000..e26221407a --- /dev/null +++ b/async_generator/metadata.txt @@ -0,0 +1,3 @@ +srctype = micropython-lib +type = module +version = 0.1.0 From 61d0f9a68febc819d660c019805be15c086ed971 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 19:21:39 +0100 Subject: [PATCH 07/65] sortedcontainers port --- sortedcontainers/metadata.txt | 3 + sortedcontainers/sortedcontainers/__init__.py | 52 + .../sortedcontainers/sorteddict.py | 720 +++++ .../sortedcontainers/sortedlist.py | 2481 +++++++++++++++++ .../sortedcontainers/sortedset.py | 323 +++ 5 files changed, 3579 insertions(+) create mode 100644 sortedcontainers/metadata.txt create mode 100644 sortedcontainers/sortedcontainers/__init__.py create mode 100644 sortedcontainers/sortedcontainers/sorteddict.py create mode 100644 sortedcontainers/sortedcontainers/sortedlist.py create mode 100644 sortedcontainers/sortedcontainers/sortedset.py diff --git a/sortedcontainers/metadata.txt b/sortedcontainers/metadata.txt new file mode 100644 index 0000000000..a8ff4f11f6 --- /dev/null +++ b/sortedcontainers/metadata.txt @@ -0,0 +1,3 @@ +srctype = micropython-lib +type = package +version = 0.1.0 diff --git a/sortedcontainers/sortedcontainers/__init__.py b/sortedcontainers/sortedcontainers/__init__.py new file mode 100644 index 0000000000..54b2bf67f3 --- /dev/null +++ b/sortedcontainers/sortedcontainers/__init__.py @@ -0,0 +1,52 @@ +"""Sorted Container Types: SortedList, SortedDict, SortedSet + +SortedContainers is an Apache2 licensed containers library, written in +pure-Python, and fast as C-extensions. + + +Python's standard library is great until you need a sorted collections +type. Many will attest that you can get really far without one, but the moment +you **really need** a sorted list, dict, or set, you're faced with a dozen +different implementations, most using C-extensions without great documentation +and benchmarking. + +In Python, we can do better. And we can do it in pure-Python! + +:: + + >>> from sortedcontainers import SortedList, SortedDict, SortedSet + >>> sl = SortedList(xrange(10000000)) + >>> 1234567 in sl + True + >>> sl[7654321] + 7654321 + >>> sl.add(1234567) + >>> sl.count(1234567) + 2 + >>> sl *= 3 + >>> len(sl) + 30000003 + +SortedContainers takes all of the work out of Python sorted types - making your +deployment and use of Python easy. There's no need to install a C compiler or +pre-build and distribute custom extensions. 
Performance is a feature and +testing has 100% coverage with unit tests and hours of stress. + +:copyright: (c) 2016 by Grant Jenks. +:license: Apache 2.0, see LICENSE for more details. + +""" + + +from .sortedlist import SortedList, SortedListWithKey +from .sortedset import SortedSet +from .sorteddict import SortedDict + +__all__ = ['SortedList', 'SortedSet', 'SortedDict', 'SortedListWithKey'] + +__title__ = 'sortedcontainers' +__version__ = '1.5.7' +__build__ = 0x010507 +__author__ = 'Grant Jenks' +__license__ = 'Apache 2.0' +__copyright__ = 'Copyright 2016 Grant Jenks' diff --git a/sortedcontainers/sortedcontainers/sorteddict.py b/sortedcontainers/sortedcontainers/sorteddict.py new file mode 100644 index 0000000000..d82aa86bca --- /dev/null +++ b/sortedcontainers/sortedcontainers/sorteddict.py @@ -0,0 +1,720 @@ +"""Sorted dictionary implementation. + +""" + +from collections import Set, Sequence +from sys import hexversion + +from .sortedlist import SortedList, recursive_repr, SortedListWithKey +from .sortedset import SortedSet + +NONE = object() + + +class _IlocWrapper(object): + "Positional indexing support for sorted dictionary objects." + # pylint: disable=protected-access, too-few-public-methods + def __init__(self, _dict): + self._dict = _dict + def __len__(self): + return len(self._dict) + def __getitem__(self, index): + """ + Very efficiently return the key at index *index* in iteration. Supports + negative indices and slice notation. Raises IndexError on invalid + *index*. + """ + return self._dict._list[index] + def __delitem__(self, index): + """ + Remove the ``sdict[sdict.iloc[index]]`` from *sdict*. Supports negative + indices and slice notation. Raises IndexError on invalid *index*. + """ + _dict = self._dict + _list = _dict._list + _delitem = _dict._delitem + + if isinstance(index, slice): + keys = _list[index] + del _list[index] + for key in keys: + _delitem(key) + else: + key = _list[index] + del _list[index] + _delitem(key) + + +class SortedDict: + """SortedDict provides the same methods as a dict. Additionally, SortedDict + efficiently maintains its keys in sorted order. Consequently, the keys + method will return the keys in sorted order, the popitem method will remove + the item with the highest key, etc. + + """ + def __init__(self, *args, **kwargs): + """SortedDict provides the same methods as a dict. Additionally, SortedDict + efficiently maintains its keys in sorted order. Consequently, the keys + method will return the keys in sorted order, the popitem method will + remove the item with the highest key, etc. + + An optional *key* argument defines a callable that, like the `key` + argument to Python's `sorted` function, extracts a comparison key from + each dict key. If no function is specified, the default compares the + dict keys directly. The `key` argument must be provided as a positional + argument and must come before all other arguments. + + An optional *load* argument defines the load factor of the internal list + used to maintain sort order. If present, this argument must come before + an iterable. The default load factor of '1000' works well for lists from + tens to tens of millions of elements. Good practice is to use a value + that is the cube root of the list size. With billions of elements, the + best load factor depends on your usage. It's best to leave the load + factor at the default until you start benchmarking. + + An optional *iterable* argument provides an initial series of items to + populate the SortedDict. 
Each item in the series must itself contain + two items. The first is used as a key in the new dictionary, and the + second as the key's value. If a given key is seen more than once, the + last value associated with it is retained in the new dictionary. + + If keyword arguments are given, the keywords themselves with their + associated values are added as items to the dictionary. If a key is + specified both in the positional argument and as a keyword argument, the + value associated with the keyword is retained in the dictionary. For + example, these all return a dictionary equal to ``{"one": 2, "two": + 3}``: + + * ``SortedDict(one=2, two=3)`` + * ``SortedDict({'one': 2, 'two': 3})`` + * ``SortedDict(zip(('one', 'two'), (2, 3)))`` + * ``SortedDict([['two', 3], ['one', 2]])`` + + The first example only works for keys that are valid Python + identifiers; the others work with any valid keys. + + """ + # pylint: disable=super-init-not-called, redefined-variable-type + if len(args) > 0 and (args[0] is None or callable(args[0])): + self._key = args[0] + args = args[1:] + else: + self._key = None + + if len(args) > 0 and isinstance(args[0], int): + self._load = args[0] + args = args[1:] + else: + self._load = 1000 + + if self._key is None: + self._list = SortedList(load=self._load) + else: + self._list = SortedListWithKey(key=self._key, load=self._load) + + # Cache function pointers to dict methods. + + self._dict = _dict = {} + self._clear = _dict.clear + self._delitem = _dict.__delitem__ + self._pop = _dict.pop + self._setdefault = _dict.setdefault + self._setitem = _dict.__setitem__ + self._dict_update = _dict.update + + # Cache function pointers to SortedList methods. + + _list = self._list + self._list_add = _list.add + self.bisect_left = _list.bisect_left + self.bisect = _list.bisect_right + self.bisect_right = _list.bisect_right + self._list_clear = _list.clear + self.index = _list.index + self._list_pop = _list.pop + self._list_remove = _list.remove + self._list_update = _list.update + self.irange = _list.irange + self.islice = _list.islice + + if self._key is not None: + self.bisect_key_left = _list.bisect_key_left + self.bisect_key_right = _list.bisect_key_right + self.bisect_key = _list.bisect_key + self.irange_key = _list.irange_key + + self.iloc = _IlocWrapper(self) + + self.update(*args, **kwargs) + + def _iter(self): + for k in self._dict: + yield k + + def clear(self): + """Remove all elements from the dictionary.""" + self._clear() + self._list_clear() + + def __delitem__(self, key): + """ + Remove ``d[key]`` from *d*. Raises a KeyError if *key* is not in the + dictionary. + """ + self._delitem(key) + self._list_remove(key) + + def __iter__(self): + """ + Return an iterator over the sorted keys of the dictionary. + + Iterating the Mapping while adding or deleting keys may raise a + `RuntimeError` or fail to iterate over all entries. + """ + return iter(self._list) + + def __reversed__(self): + """ + Return a reversed iterator over the sorted keys of the dictionary. + + Iterating the Mapping while adding or deleting keys may raise a + `RuntimeError` or fail to iterate over all entries. 
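+
+        For example (illustrative)::
+
+            >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3})
+            >>> list(reversed(sd))
+            ['c', 'b', 'a']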
+        """
+        return reversed(self._list)
+
+    def __setitem__(self, key, value):
+        """Set `d[key]` to *value*."""
+        if key not in self:
+            self._list_add(key)
+        self._setitem(key, value)
+
+    def copy(self):
+        """Return a shallow copy of the sorted dictionary."""
+        return self.__class__(self._key, self._load, self._iteritems())
+
+    __copy__ = copy
+
+    @classmethod
+    def fromkeys(cls, seq, value=None):
+        """
+        Create a new dictionary with keys from *seq* and values set to *value*.
+        """
+        return cls((key, value) for key in seq)
+
+    def items(self):
+        """
+        Return a list of the dictionary's items (``(key, value)`` pairs).
+        """
+        return list(self._iteritems())
+
+    def iteritems(self):
+        """
+        Return an iterator over the items (``(key, value)`` pairs).
+
+        Iterating the Mapping while adding or deleting keys may raise a
+        `RuntimeError` or fail to iterate over all entries.
+        """
+        return iter((key, self[key]) for key in self._list)
+
+    _iteritems = iteritems
+
+    def keys(self):
+        """Return a SortedSet of the dictionary's keys."""
+        return SortedSet(self._list, key=self._key, load=self._load)
+
+    def iterkeys(self):
+        """
+        Return an iterator over the sorted keys of the Mapping.
+
+        Iterating the Mapping while adding or deleting keys may raise a
+        `RuntimeError` or fail to iterate over all entries.
+        """
+        return iter(self._list)
+
+    def values(self):
+        """Return a list of the dictionary's values."""
+        return list(self._itervalues())
+
+    def itervalues(self):
+        """
+        Return an iterator over the values of the Mapping.
+
+        Iterating the Mapping while adding or deleting keys may raise a
+        `RuntimeError` or fail to iterate over all entries.
+        """
+        return iter(self[key] for key in self._list)
+
+    _itervalues = itervalues
+
+    def pop(self, key, default=NONE):
+        """
+        If *key* is in the dictionary, remove it and return its value,
+        else return *default*. If *default* is not given and *key* is not in
+        the dictionary, a KeyError is raised.
+        """
+        if key in self:
+            self._list_remove(key)
+            return self._pop(key)
+        else:
+            if default is NONE:
+                raise KeyError(key)
+            else:
+                return default
+
+    def popitem(self, last=True):
+        """
+        Remove and return a ``(key, value)`` pair from the dictionary. If
+        last=True (default) then remove the *greatest* `key` from the
+        dictionary. Else, remove the *least* key from the dictionary.
+
+        If the dictionary is empty, calling `popitem` raises a
+        `KeyError`.
+        """
+        if not len(self):
+            raise KeyError('popitem(): dictionary is empty')
+
+        key = self._list_pop(-1 if last else 0)
+        value = self._pop(key)
+
+        return (key, value)
+
+    def peekitem(self, index=-1):
+        """Return (key, value) item pair at index.
+
+        Unlike ``popitem``, the sorted dictionary is not modified. Index
+        defaults to -1, the last/greatest key in the dictionary. Specify
+        ``index=0`` to look up the first/least key in the dictionary.
+
+        If index is out of range, raise IndexError.
+
+        """
+        key = self._list[index]
+        return key, self[key]
+
+    def setdefault(self, key, default=None):
+        """
+        If *key* is in the dictionary, return its value. If not, insert *key*
+        with a value of *default* and return *default*. *default* defaults to
+        ``None``.
+        """
+        if key in self:
+            return self[key]
+        else:
+            self._setitem(key, default)
+            self._list_add(key)
+            return default
+
+    def __len__(self):
+        return len(self._dict)
+
+    def update(self, *args, **kwargs):
+        """
+        Update the dictionary with the key/value pairs from *other*, overwriting
+        existing keys.
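+
+        For example (illustrative)::
+
+            >>> d = SortedDict({'b': 2})
+            >>> d.update({'a': 1}, c=3)
+            >>> d.items()
+            [('a', 1), ('b', 2), ('c', 3)]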
+ + *update* accepts either another dictionary object or an iterable of + key/value pairs (as a tuple or other iterable of length two). If + keyword arguments are specified, the dictionary is then updated with + those key/value pairs: ``d.update(red=1, blue=2)``. + """ + if not len(self): + self._dict_update(*args, **kwargs) + self._list_update(self._iter()) + return + + if len(kwargs) == 0 and len(args) == 1 and isinstance(args[0], dict): + pairs = args[0] + else: + pairs = dict(*args, **kwargs) + + if (10 * len(pairs)) > len(self): + self._dict_update(pairs) + self._list_clear() + self._list_update(self._iter()) + else: + for key in pairs: + self[key] = pairs[key] + + _update = update + + if hexversion >= 0x02070000: + def viewkeys(self): + "Return ``KeysView`` of dictionary keys." + return KeysView(self) + + def viewvalues(self): + "Return ``ValuesView`` of dictionary values." + return ValuesView(self) + + def viewitems(self): + "Return ``ItemsView`` of dictionary (key, value) item pairs." + return ItemsView(self) + + def __reduce__(self): + return (self.__class__, (self._key, self._load, list(self._iteritems()))) + + @recursive_repr + def __repr__(self): + temp = '{0}({1}, {2}, {{{3}}})' + items = ', '.join('{0}: {1}'.format(repr(key), repr(self[key])) + for key in self._list) + return temp.format( + self.__class__.__name__, + repr(self._key), + repr(self._load), + items + ) + + def _check(self): + # pylint: disable=protected-access + self._list._check() + assert len(self) == len(self._list) + assert all(key in self for key in self._list) + + +class KeysView(Set, Sequence): + """ + A KeysView object is a dynamic view of the dictionary's keys, which + means that when the dictionary's keys change, the view reflects + those changes. + + The KeysView class implements the Set and Sequence Abstract Base Classes. + """ + if hexversion < 0x03000000: + def __init__(self, sorted_dict): + """ + Initialize a KeysView from a SortedDict container as *sorted_dict*. + """ + # pylint: disable=super-init-not-called, protected-access + self._list = sorted_dict._list + self._view = sorted_dict._dict.viewkeys() + else: + def __init__(self, sorted_dict): + """ + Initialize a KeysView from a SortedDict container as *sorted_dict*. + """ + # pylint: disable=super-init-not-called, protected-access + self._list = sorted_dict._list + self._view = sorted_dict._dict.keys() + def __len__(self): + """Return the number of entries in the dictionary.""" + return len(self._view) + def __contains__(self, key): + """ + Return True if and only if *key* is one of the underlying dictionary's + keys. + """ + return key in self._view + def __iter__(self): + """ + Return an iterable over the keys in the dictionary. Keys are iterated + over in their sorted order. + + Iterating views while adding or deleting entries in the dictionary may + raise a `RuntimeError` or fail to iterate over all entries. + """ + return iter(self._list) + def __getitem__(self, index): + """Return the key at position *index*.""" + return self._list[index] + def __reversed__(self): + """ + Return a reversed iterable over the keys in the dictionary. Keys are + iterated over in their reverse sort order. + + Iterating views while adding or deleting entries in the dictionary may + raise a RuntimeError or fail to iterate over all entries. + """ + return reversed(self._list) + def index(self, value, start=None, stop=None): + """ + Return the smallest *k* such that `keysview[k] == value` and `start <= k + < end`. Raises `KeyError` if *value* is not present. 
*stop* defaults + to the end of the set. *start* defaults to the beginning. Negative + indexes are supported, as for slice indices. + """ + # pylint: disable=arguments-differ + return self._list.index(value, start, stop) + def count(self, value): + """Return the number of occurrences of *value* in the set.""" + return 1 if value in self._view else 0 + def __eq__(self, that): + """Test set-like equality with *that*.""" + return self._view == that + def __ne__(self, that): + """Test set-like inequality with *that*.""" + return self._view != that + def __lt__(self, that): + """Test whether self is a proper subset of *that*.""" + return self._view < that + def __gt__(self, that): + """Test whether self is a proper superset of *that*.""" + return self._view > that + def __le__(self, that): + """Test whether self is contained within *that*.""" + return self._view <= that + def __ge__(self, that): + """Test whether *that* is contained within self.""" + return self._view >= that + def __and__(self, that): + """Return a SortedSet of the intersection of self and *that*.""" + return SortedSet(self._view & that) + def __or__(self, that): + """Return a SortedSet of the union of self and *that*.""" + return SortedSet(self._view | that) + def __sub__(self, that): + """Return a SortedSet of the difference of self and *that*.""" + return SortedSet(self._view - that) + def __xor__(self, that): + """Return a SortedSet of the symmetric difference of self and *that*.""" + return SortedSet(self._view ^ that) + if hexversion < 0x03000000: + def isdisjoint(self, that): + """Return True if and only if *that* is disjoint with self.""" + return not any(key in self._list for key in that) + else: + def isdisjoint(self, that): + """Return True if and only if *that* is disjoint with self.""" + return self._view.isdisjoint(that) + @recursive_repr + def __repr__(self): + return 'SortedDict_keys({0})'.format(repr(list(self))) + + +class ValuesView(Sequence): + """ + A ValuesView object is a dynamic view of the dictionary's values, which + means that when the dictionary's values change, the view reflects those + changes. + + The ValuesView class implements the Sequence Abstract Base Class. + """ + if hexversion < 0x03000000: + def __init__(self, sorted_dict): + """ + Initialize a ValuesView from a SortedDict container as + *sorted_dict*. + """ + # pylint: disable=super-init-not-called, protected-access + self._dict = sorted_dict + self._list = sorted_dict._list + self._view = sorted_dict._dict.viewvalues() + else: + def __init__(self, sorted_dict): + """ + Initialize a ValuesView from a SortedDict container as + *sorted_dict*. + """ + # pylint: disable=super-init-not-called, protected-access + self._dict = sorted_dict + self._list = sorted_dict._list + self._view = sorted_dict._dict.values() + def __len__(self): + """Return the number of entries in the dictionary.""" + return len(self._dict) + def __contains__(self, value): + """ + Return True if and only if *value* is in the underlying Mapping's + values. + """ + return value in self._view + def __iter__(self): + """ + Return an iterator over the values in the dictionary. Values are + iterated over in sorted order of the keys. + + Iterating views while adding or deleting entries in the dictionary may + raise a `RuntimeError` or fail to iterate over all entries. + """ + _dict = self._dict + return iter(_dict[key] for key in self._list) + def __getitem__(self, index): + """ + Efficiently return value at *index* in iteration. + + Supports slice notation and negative indexes. 
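+
+        For example (illustrative; ``viewvalues`` is defined on SortedDict
+        above)::
+
+            >>> vv = SortedDict({'a': 1, 'b': 2}).viewvalues()
+            >>> vv[0], vv[-1]
+            (1, 2)
+            >>> vv[:]
+            [1, 2]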
+ """ + _dict, _list = self._dict, self._list + if isinstance(index, slice): + return [_dict[key] for key in _list[index]] + else: + return _dict[_list[index]] + def __reversed__(self): + """ + Return a reverse iterator over the values in the dictionary. Values are + iterated over in reverse sort order of the keys. + + Iterating views while adding or deleting entries in the dictionary may + raise a `RuntimeError` or fail to iterate over all entries. + """ + _dict = self._dict + return iter(_dict[key] for key in reversed(self._list)) + def index(self, value): + """ + Return index of *value* in self. + + Raises ValueError if *value* is not found. + """ + for idx, val in enumerate(self): + if value == val: + return idx + raise ValueError('{0} is not in dict'.format(repr(value))) + if hexversion < 0x03000000: + def count(self, value): + """Return the number of occurrences of *value* in self.""" + return sum(1 for val in self._dict.itervalues() if val == value) + else: + def count(self, value): + """Return the number of occurrences of *value* in self.""" + return sum(1 for val in self._dict.values() if val == value) + def __lt__(self, that): + raise TypeError + def __gt__(self, that): + raise TypeError + def __le__(self, that): + raise TypeError + def __ge__(self, that): + raise TypeError + def __and__(self, that): + raise TypeError + def __or__(self, that): + raise TypeError + def __sub__(self, that): + raise TypeError + def __xor__(self, that): + raise TypeError + @recursive_repr + def __repr__(self): + return 'SortedDict_values({0})'.format(repr(list(self))) + + +class ItemsView(Set, Sequence): + """ + An ItemsView object is a dynamic view of the dictionary's ``(key, + value)`` pairs, which means that when the dictionary changes, the + view reflects those changes. + + The ItemsView class implements the Set and Sequence Abstract Base Classes. + However, the set-like operations (``&``, ``|``, ``-``, ``^``) will only + operate correctly if all of the dictionary's values are hashable. + """ + if hexversion < 0x03000000: + def __init__(self, sorted_dict): + """ + Initialize an ItemsView from a SortedDict container as + *sorted_dict*. + """ + # pylint: disable=super-init-not-called, protected-access + self._dict = sorted_dict + self._list = sorted_dict._list + self._view = sorted_dict._dict.viewitems() + else: + def __init__(self, sorted_dict): + """ + Initialize an ItemsView from a SortedDict container as + *sorted_dict*. + """ + # pylint: disable=super-init-not-called, protected-access + self._dict = sorted_dict + self._list = sorted_dict._list + self._view = sorted_dict._dict.items() + def __len__(self): + """Return the number of entries in the dictionary.""" + return len(self._view) + def __contains__(self, key): + """ + Return True if and only if *key* is one of the underlying dictionary's + items. + """ + return key in self._view + def __iter__(self): + """ + Return an iterable over the items in the dictionary. Items are iterated + over in their sorted order. + + Iterating views while adding or deleting entries in the dictionary may + raise a `RuntimeError` or fail to iterate over all entries. 
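+
+        For example (illustrative)::
+
+            >>> iv = SortedDict({'b': 2, 'a': 1}).viewitems()
+            >>> list(iv)
+            [('a', 1), ('b', 2)]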
+ """ + _dict = self._dict + return iter((key, _dict[key]) for key in self._list) + def __getitem__(self, index): + """Return the item as position *index*.""" + _dict, _list = self._dict, self._list + if isinstance(index, slice): + return [(key, _dict[key]) for key in _list[index]] + else: + key = _list[index] + return (key, _dict[key]) + def __reversed__(self): + """ + Return a reversed iterable over the items in the dictionary. Items are + iterated over in their reverse sort order. + + Iterating views while adding or deleting entries in the dictionary may + raise a RuntimeError or fail to iterate over all entries. + """ + _dict = self._dict + return iter((key, _dict[key]) for key in reversed(self._list)) + def index(self, key, start=None, stop=None): + """ + Return the smallest *k* such that `itemssview[k] == key` and `start <= k + < end`. Raises `KeyError` if *key* is not present. *stop* defaults + to the end of the set. *start* defaults to the beginning. Negative + indexes are supported, as for slice indices. + """ + # pylint: disable=arguments-differ + temp, value = key + pos = self._list.index(temp, start, stop) + if value == self._dict[temp]: + return pos + else: + raise ValueError('{0} is not in dict'.format(repr(key))) + def count(self, item): + """Return the number of occurrences of *item* in the set.""" + key, value = item + return 1 if key in self._dict and self._dict[key] == value else 0 + def __eq__(self, that): + """Test set-like equality with *that*.""" + return self._view == that + def __ne__(self, that): + """Test set-like inequality with *that*.""" + return self._view != that + def __lt__(self, that): + """Test whether self is a proper subset of *that*.""" + return self._view < that + def __gt__(self, that): + """Test whether self is a proper superset of *that*.""" + return self._view > that + def __le__(self, that): + """Test whether self is contained within *that*.""" + return self._view <= that + def __ge__(self, that): + """Test whether *that* is contained within self.""" + return self._view >= that + def __and__(self, that): + """Return a SortedSet of the intersection of self and *that*.""" + return SortedSet(self._view & that) + def __or__(self, that): + """Return a SortedSet of the union of self and *that*.""" + return SortedSet(self._view | that) + def __sub__(self, that): + """Return a SortedSet of the difference of self and *that*.""" + return SortedSet(self._view - that) + def __xor__(self, that): + """Return a SortedSet of the symmetric difference of self and *that*.""" + return SortedSet(self._view ^ that) + if hexversion < 0x03000000: + def isdisjoint(self, that): + """Return True if and only if *that* is disjoint with self.""" + _dict = self._dict + for key, value in that: + if key in _dict and _dict[key] == value: + return False + return True + else: + def isdisjoint(self, that): + """Return True if and only if *that* is disjoint with self.""" + return self._view.isdisjoint(that) + @recursive_repr + def __repr__(self): + return 'SortedDict_items({0})'.format(repr(list(self))) diff --git a/sortedcontainers/sortedcontainers/sortedlist.py b/sortedcontainers/sortedcontainers/sortedlist.py new file mode 100644 index 0000000000..5d1245f972 --- /dev/null +++ b/sortedcontainers/sortedcontainers/sortedlist.py @@ -0,0 +1,2481 @@ +"""Sorted list implementation. 
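+
+For example (illustrative)::
+
+    >>> from sortedcontainers import SortedList
+    >>> sl = SortedList([3, 1, 2])
+    >>> list(sl)
+    [1, 2, 3]
+    >>> sl.add(0)
+    >>> sl[0]
+    0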
+ +""" +# pylint: disable=redefined-builtin, ungrouped-imports + +from __future__ import print_function + +from bisect import bisect_left, bisect_right, insort +from collections import Sequence, MutableSequence +from functools import wraps +from itertools import chain, repeat, starmap +from math import log as log_e +import operator as op +from operator import iadd, add +from sys import hexversion + +if hexversion < 0x03000000: + from itertools import izip as zip + from itertools import imap as map + try: + from thread import get_ident + except ImportError: + from dummy_thread import get_ident +else: + from functools import reduce + try: + from _thread import get_ident + except ImportError: + from _dummy_thread import get_ident # pylint: disable=import-error + +def recursive_repr(func): + """Decorator to prevent infinite repr recursion.""" + repr_running = set() + + @wraps(func) + def wrapper(self): + "Return ellipsis on recursive re-entry to function." + key = id(self), get_ident() + + if key in repr_running: + return '...' + + repr_running.add(key) + + try: + return func(self) + finally: + repr_running.discard(key) + + return wrapper + +class SortedList(MutableSequence): + """ + SortedList provides most of the same methods as a list but keeps the items + in sorted order. + """ + + def __init__(self, iterable=None, load=1000): + """ + SortedList provides most of the same methods as a list but keeps the + items in sorted order. + + An optional *iterable* provides an initial series of items to populate + the SortedList. + + An optional *load* specifies the load-factor of the list. The default + load factor of '1000' works well for lists from tens to tens of millions + of elements. Good practice is to use a value that is the cube root of + the list size. With billions of elements, the best load factor depends + on your usage. It's best to leave the load factor at the default until + you start benchmarking. + """ + self._len = 0 + self._lists = [] + self._maxes = [] + self._index = [] + self._load = load + self._twice = load * 2 + self._half = load >> 1 + self._offset = 0 + + if iterable is not None: + self._update(iterable) + + def __new__(cls, iterable=None, key=None, load=1000): + """ + SortedList provides most of the same methods as a list but keeps the + items in sorted order. + + An optional *iterable* provides an initial series of items to populate + the SortedList. + + An optional *key* argument will return an instance of subtype + SortedListWithKey. + + An optional *load* specifies the load-factor of the list. The default + load factor of '1000' works well for lists from tens to tens of millions + of elements. Good practice is to use a value that is the cube root of + the list size. With billions of elements, the best load factor depends + on your usage. It's best to leave the load factor at the default until + you start benchmarking. 
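+
+        For example (illustrative)::
+
+            >>> sl = SortedList([1, 3, 2], key=lambda x: -x)
+            >>> list(sl)
+            [3, 2, 1]
+            >>> type(sl).__name__
+            'SortedListWithKey'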
+ """ + # pylint: disable=unused-argument + if key is None: + return object.__new__(cls) + else: + if cls is SortedList: + return object.__new__(SortedListWithKey) + else: + raise TypeError('inherit SortedListWithKey for key argument') + + def clear(self): + """Remove all the elements from the list.""" + self._len = 0 + del self._lists[:] + del self._maxes[:] + del self._index[:] + + _clear = clear + + def add(self, val): + """Add the element *val* to the list.""" + _lists = self._lists + _maxes = self._maxes + + if _maxes: + pos = bisect_right(_maxes, val) + + if pos == len(_maxes): + pos -= 1 + _lists[pos].append(val) + _maxes[pos] = val + else: + insort(_lists[pos], val) + + self._expand(pos) + else: + _lists.append([val]) + _maxes.append(val) + + self._len += 1 + + def _expand(self, pos): + """Splits sublists that are more than double the load level. + + Updates the index when the sublist length is less than double the load + level. This requires incrementing the nodes in a traversal from the + leaf node to the root. For an example traversal see self._loc. + + """ + _lists = self._lists + _index = self._index + + if len(_lists[pos]) > self._twice: + _maxes = self._maxes + _load = self._load + + _lists_pos = _lists[pos] + half = _lists_pos[_load:] + del _lists_pos[_load:] + _maxes[pos] = _lists_pos[-1] + + _lists.insert(pos + 1, half) + _maxes.insert(pos + 1, half[-1]) + + del _index[:] + else: + if _index: + child = self._offset + pos + while child: + _index[child] += 1 + child = (child - 1) >> 1 + _index[0] += 1 + + def update(self, iterable): + """Update the list by adding all elements from *iterable*.""" + _lists = self._lists + _maxes = self._maxes + values = sorted(iterable) + + if _maxes: + if len(values) * 4 >= self._len: + values.extend(chain.from_iterable(_lists)) + values.sort() + self._clear() + else: + _add = self.add + for val in values: + _add(val) + return + + _load = self._load + _lists.extend(values[pos:(pos + _load)] + for pos in range(0, len(values), _load)) + _maxes.extend(sublist[-1] for sublist in _lists) + self._len = len(values) + del self._index[:] + + _update = update + + def __contains__(self, val): + """Return True if and only if *val* is an element in the list.""" + _maxes = self._maxes + + if not _maxes: + return False + + pos = bisect_left(_maxes, val) + + if pos == len(_maxes): + return False + + _lists = self._lists + idx = bisect_left(_lists[pos], val) + + return _lists[pos][idx] == val + + def discard(self, val): + """ + Remove the first occurrence of *val*. + + If *val* is not a member, does nothing. + """ + _maxes = self._maxes + + if not _maxes: + return + + pos = bisect_left(_maxes, val) + + if pos == len(_maxes): + return + + _lists = self._lists + idx = bisect_left(_lists[pos], val) + + if _lists[pos][idx] == val: + self._delete(pos, idx) + + def remove(self, val): + """ + Remove first occurrence of *val*. + + Raises ValueError if *val* is not present. + """ + _maxes = self._maxes + + if not _maxes: + raise ValueError('{0} not in list'.format(repr(val))) + + pos = bisect_left(_maxes, val) + + if pos == len(_maxes): + raise ValueError('{0} not in list'.format(repr(val))) + + _lists = self._lists + idx = bisect_left(_lists[pos], val) + + if _lists[pos][idx] == val: + self._delete(pos, idx) + else: + raise ValueError('{0} not in list'.format(repr(val))) + + def _delete(self, pos, idx): + """Delete the item at the given (pos, idx). + + Combines lists that are less than half the load level. 
+ + Updates the index when the sublist length is more than half the load + level. This requires decrementing the nodes in a traversal from the leaf + node to the root. For an example traversal see self._loc. + """ + _lists = self._lists + _maxes = self._maxes + _index = self._index + + _lists_pos = _lists[pos] + + del _lists_pos[idx] + self._len -= 1 + + len_lists_pos = len(_lists_pos) + + if len_lists_pos > self._half: + + _maxes[pos] = _lists_pos[-1] + + if _index: + child = self._offset + pos + while child > 0: + _index[child] -= 1 + child = (child - 1) >> 1 + _index[0] -= 1 + + elif len(_lists) > 1: + + if not pos: + pos += 1 + + prev = pos - 1 + _lists[prev].extend(_lists[pos]) + _maxes[prev] = _lists[prev][-1] + + del _lists[pos] + del _maxes[pos] + del _index[:] + + self._expand(prev) + + elif len_lists_pos: + + _maxes[pos] = _lists_pos[-1] + + else: + + del _lists[pos] + del _maxes[pos] + del _index[:] + + def _loc(self, pos, idx): + """Convert an index pair (alpha, beta) into a single index that corresponds to + the position of the value in the sorted list. + + Most queries require the index be built. Details of the index are + described in self._build_index. + + Indexing requires traversing the tree from a leaf node to the root. The + parent of each node is easily computable at (pos - 1) // 2. + + Left-child nodes are always at odd indices and right-child nodes are + always at even indices. + + When traversing up from a right-child node, increment the total by the + left-child node. + + The final index is the sum from traversal and the index in the sublist. + + For example, using the index from self._build_index: + + _index = 14 5 9 3 2 4 5 + _offset = 3 + + Tree: + + 14 + 5 9 + 3 2 4 5 + + Converting index pair (2, 3) into a single index involves iterating like + so: + + 1. Starting at the leaf node: offset + alpha = 3 + 2 = 5. We identify + the node as a left-child node. At such nodes, we simply traverse to + the parent. + + 2. At node 9, position 2, we recognize the node as a right-child node + and accumulate the left-child in our total. Total is now 5 and we + traverse to the parent at position 0. + + 3. Iteration ends at the root. + + Computing the index is the sum of the total and beta: 5 + 3 = 8. + """ + if not pos: + return idx + + _index = self._index + + if not _index: + self._build_index() + + total = 0 + + # Increment pos to point in the index to len(self._lists[pos]). + + pos += self._offset + + # Iterate until reaching the root of the index tree at pos = 0. + + while pos: + + # Right-child nodes are at odd indices. At such indices + # account the total below the left child node. + + if not pos & 1: + total += _index[pos - 1] + + # Advance pos to the parent node. + + pos = (pos - 1) >> 1 + + return total + idx + + def _pos(self, idx): + """Convert an index into a pair (alpha, beta) that can be used to access + the corresponding _lists[alpha][beta] position. + + Most queries require the index be built. Details of the index are + described in self._build_index. + + Indexing requires traversing the tree to a leaf node. Each node has + two children which are easily computable. Given an index, pos, the + left-child is at pos * 2 + 1 and the right-child is at pos * 2 + 2. + + When the index is less than the left-child, traversal moves to the + left sub-tree. Otherwise, the index is decremented by the left-child + and traversal moves to the right sub-tree. 
+ + At a child node, the indexing pair is computed from the relative + position of the child node as compared with the offset and the remaining + index. + + For example, using the index from self._build_index: + + _index = 14 5 9 3 2 4 5 + _offset = 3 + + Tree: + + 14 + 5 9 + 3 2 4 5 + + Indexing position 8 involves iterating like so: + + 1. Starting at the root, position 0, 8 is compared with the left-child + node (5) which it is greater than. When greater the index is + decremented and the position is updated to the right child node. + + 2. At node 9 with index 3, we again compare the index to the left-child + node with value 4. Because the index is the less than the left-child + node, we simply traverse to the left. + + 3. At node 4 with index 3, we recognize that we are at a leaf node and + stop iterating. + + 4. To compute the sublist index, we subtract the offset from the index + of the leaf node: 5 - 3 = 2. To compute the index in the sublist, we + simply use the index remaining from iteration. In this case, 3. + + The final index pair from our example is (2, 3) which corresponds to + index 8 in the sorted list. + """ + if idx < 0: + last_len = len(self._lists[-1]) + + if (-idx) <= last_len: + return len(self._lists) - 1, last_len + idx + + idx += self._len + + if idx < 0: + raise IndexError('list index out of range') + elif idx >= self._len: + raise IndexError('list index out of range') + + if idx < len(self._lists[0]): + return 0, idx + + _index = self._index + + if not _index: + self._build_index() + + pos = 0 + child = 1 + len_index = len(_index) + + while child < len_index: + index_child = _index[child] + + if idx < index_child: + pos = child + else: + idx -= index_child + pos = child + 1 + + child = (pos << 1) + 1 + + return (pos - self._offset, idx) + + def _build_index(self): + """Build an index for indexing the sorted list. + + Indexes are represented as binary trees in a dense array notation + similar to a binary heap. + + For example, given a _lists representation storing integers: + + [0]: 1 2 3 + [1]: 4 5 + [2]: 6 7 8 9 + [3]: 10 11 12 13 14 + + The first transformation maps the sub-lists by their length. The + first row of the index is the length of the sub-lists. + + [0]: 3 2 4 5 + + Each row after that is the sum of consecutive pairs of the previous row: + + [1]: 5 9 + [2]: 14 + + Finally, the index is built by concatenating these lists together: + + _index = 14 5 9 3 2 4 5 + + An offset storing the start of the first row is also stored: + + _offset = 3 + + When built, the index can be used for efficient indexing into the list. + See the comment and notes on self._pos for details. + """ + row0 = list(map(len, self._lists)) + + if len(row0) == 1: + self._index[:] = row0 + self._offset = 0 + return + + head = iter(row0) + tail = iter(head) + row1 = list(starmap(add, zip(head, tail))) + + if len(row0) & 1: + row1.append(row0[-1]) + + if len(row1) == 1: + self._index[:] = row1 + row0 + self._offset = 1 + return + + size = 2 ** (int(log_e(len(row1) - 1, 2)) + 1) + row1.extend(repeat(0, size - len(row1))) + tree = [row0, row1] + + while len(tree[-1]) > 1: + head = iter(tree[-1]) + tail = iter(head) + row = list(starmap(add, zip(head, tail))) + tree.append(row) + + reduce(iadd, reversed(tree), self._index) + self._offset = size * 2 - 1 + + def __delitem__(self, idx): + """Remove the element at *idx*. 
Supports slicing.""" + if isinstance(idx, slice): + start, stop, step = idx.indices(self._len) + + if step == 1 and start < stop: + if start == 0 and stop == self._len: + return self._clear() + elif self._len <= 8 * (stop - start): + values = self._getitem(slice(None, start)) + if stop < self._len: + values += self._getitem(slice(stop, None)) + self._clear() + return self._update(values) + + indices = range(start, stop, step) + + # Delete items from greatest index to least so + # that the indices remain valid throughout iteration. + + if step > 0: + indices = reversed(indices) + + _pos, _delete = self._pos, self._delete + + for index in indices: + pos, idx = _pos(index) + _delete(pos, idx) + else: + pos, idx = self._pos(idx) + self._delete(pos, idx) + + _delitem = __delitem__ + + def __getitem__(self, idx): + """Return the element at *idx*. Supports slicing.""" + _lists = self._lists + + if isinstance(idx, slice): + start, stop, step = idx.indices(self._len) + + if step == 1 and start < stop: + if start == 0 and stop == self._len: + return reduce(iadd, self._lists, []) + + start_pos, start_idx = self._pos(start) + + if stop == self._len: + stop_pos = len(_lists) - 1 + stop_idx = len(_lists[stop_pos]) + else: + stop_pos, stop_idx = self._pos(stop) + + if start_pos == stop_pos: + return _lists[start_pos][start_idx:stop_idx] + + prefix = _lists[start_pos][start_idx:] + middle = _lists[(start_pos + 1):stop_pos] + result = reduce(iadd, middle, prefix) + result += _lists[stop_pos][:stop_idx] + + return result + + if step == -1 and start > stop: + result = self._getitem(slice(stop + 1, start + 1)) + result.reverse() + return result + + # Return a list because a negative step could + # reverse the order of the items and this could + # be the desired behavior. + + indices = range(start, stop, step) + return list(self._getitem(index) for index in indices) + else: + if self._len: + if idx == 0: + return _lists[0][0] + elif idx == -1: + return _lists[-1][-1] + else: + raise IndexError('list index out of range') + + if 0 <= idx < len(_lists[0]): + return _lists[0][idx] + + len_last = len(_lists[-1]) + + if -len_last < idx < 0: + return _lists[-1][len_last + idx] + + pos, idx = self._pos(idx) + return _lists[pos][idx] + + _getitem = __getitem__ + + def _check_order(self, idx, val): + _len = self._len + _lists = self._lists + + pos, loc = self._pos(idx) + + if idx < 0: + idx += _len + + # Check that the inserted value is not less than the + # previous value. + + if idx > 0: + idx_prev = loc - 1 + pos_prev = pos + + if idx_prev < 0: + pos_prev -= 1 + idx_prev = len(_lists[pos_prev]) - 1 + + if _lists[pos_prev][idx_prev] > val: + msg = '{0} not in sort order at index {1}'.format(repr(val), idx) + raise ValueError(msg) + + # Check that the inserted value is not greater than + # the previous value. + + if idx < (_len - 1): + idx_next = loc + 1 + pos_next = pos + + if idx_next == len(_lists[pos_next]): + pos_next += 1 + idx_next = 0 + + if _lists[pos_next][idx_next] < val: + msg = '{0} not in sort order at index {1}'.format(repr(val), idx) + raise ValueError(msg) + + def __setitem__(self, index, value): + """Replace item at position *index* with *value*. + + Supports slice notation. Raises :exc:`ValueError` if the sort order + would be violated. When used with a slice and iterable, the + :exc:`ValueError` is raised before the list is mutated if the sort + order would be violated by the operation. 
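+
+        A small illustrative example:
+
+        >>> sl = SortedList([1, 2, 4])
+        >>> sl[2] = 3
+        >>> list(sl)
+        [1, 2, 3]
+        >>> sl[0] = 10
+        Traceback (most recent call last):
+          ...
+        ValueError: 10 not in sort order at index 0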
+
+        """
+        _lists = self._lists
+        _maxes = self._maxes
+        _check_order = self._check_order
+        _pos = self._pos
+
+        if isinstance(index, slice):
+            _len = self._len
+            start, stop, step = index.indices(_len)
+            indices = range(start, stop, step)
+
+            # Copy value to avoid aliasing issues with self and cases where an
+            # iterator is given.
+
+            values = tuple(value)
+
+            if step != 1:
+                if len(values) != len(indices):
+                    raise ValueError(
+                        'attempt to assign sequence of size %s'
+                        ' to extended slice of size %s'
+                        % (len(values), len(indices)))
+
+                # Keep a log of values that are set so that we can
+                # roll back changes if ordering is violated.
+
+                log = []
+                _append = log.append
+
+                for idx, val in zip(indices, values):
+                    pos, loc = _pos(idx)
+                    _append((idx, _lists[pos][loc], val))
+                    _lists[pos][loc] = val
+                    if len(_lists[pos]) == (loc + 1):
+                        _maxes[pos] = val
+
+                try:
+                    # Validate ordering of new values.
+
+                    for idx, _, newval in log:
+                        _check_order(idx, newval)
+
+                except ValueError:
+
+                    # Roll back changes from log.
+
+                    for idx, oldval, _ in log:
+                        pos, loc = _pos(idx)
+                        _lists[pos][loc] = oldval
+                        if len(_lists[pos]) == (loc + 1):
+                            _maxes[pos] = oldval
+
+                    raise
+            else:
+                if start == 0 and stop == _len:
+                    self._clear()
+                    return self._update(values)
+
+                if stop < start:
+                    # When calculating indices, stop may be less than start.
+                    # For example: ...[5:3:1] results in slice(5, 3, 1) which
+                    # is a valid but not useful stop index.
+                    stop = start
+
+                if values:
+
+                    # Check that given values are ordered properly.
+
+                    alphas = iter(values)
+                    betas = iter(values)
+                    next(betas)
+                    pairs = zip(alphas, betas)
+
+                    if not all(alpha <= beta for alpha, beta in pairs):
+                        raise ValueError('given values not in sort order')
+
+                    # Check ordering in context of sorted list.
+
+                    if start and self._getitem(start - 1) > values[0]:
+                        message = '%s not in sort order at index %s' % (
+                            repr(values[0]), start)
+                        raise ValueError(message)
+
+                    if stop != _len and self._getitem(stop) < values[-1]:
+                        message = '%s not in sort order at index %s' % (
+                            repr(values[-1]), stop)
+                        raise ValueError(message)
+
+                    # Delete the existing values.
+
+                    self._delitem(index)
+
+                    # Insert the new values.
+
+                    _insert = self.insert
+                    for idx, val in enumerate(values):
+                        _insert(start + idx, val)
+        else:
+            pos, loc = _pos(index)
+            _check_order(index, value)
+            _lists[pos][loc] = value
+            if len(_lists[pos]) == (loc + 1):
+                _maxes[pos] = value
+
+    def __iter__(self):
+        """
+        Return an iterator over the Sequence.
+
+        Iterating the Sequence while adding or deleting values may raise a
+        `RuntimeError` or fail to iterate over all entries.
+        """
+        return chain.from_iterable(self._lists)
+
+    def __reversed__(self):
+        """
+        Return an iterator to traverse the Sequence in reverse.
+
+        Iterating the Sequence while adding or deleting values may raise a
+        `RuntimeError` or fail to iterate over all entries.
+        """
+        return chain.from_iterable(map(reversed, reversed(self._lists)))
+
+    def islice(self, start=None, stop=None, reverse=False):
+        """
+        Returns an iterator that slices `self` from `start` to `stop` index,
+        inclusive and exclusive respectively.
+
+        When `reverse` is `True`, values are yielded from the iterator in
+        reverse order.
+
+        Both `start` and `stop` default to `None` which is automatically
+        inclusive of the beginning and end.
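+
+        A small illustrative example:
+
+        >>> sl = SortedList('abcde')
+        >>> list(sl.islice(1, 4))
+        ['b', 'c', 'd']
+        >>> list(sl.islice(1, 4, reverse=True))
+        ['d', 'c', 'b']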
+ """ + _len = self._len + + if not _len: + return iter(()) + + start, stop, _ = slice(start, stop).indices(self._len) + + if start >= stop: + return iter(()) + + _pos = self._pos + + min_pos, min_idx = _pos(start) + + if stop == _len: + max_pos = len(self._lists) - 1 + max_idx = len(self._lists[-1]) + else: + max_pos, max_idx = _pos(stop) + + return self._islice(min_pos, min_idx, max_pos, max_idx, reverse) + + def _islice(self, min_pos, min_idx, max_pos, max_idx, reverse): + """ + Returns an iterator that slices `self` using two index pairs, + `(min_pos, min_idx)` and `(max_pos, max_idx)`; the first inclusive + and the latter exclusive. See `_pos` for details on how an index + is converted to an index pair. + + When `reverse` is `True`, values are yielded from the iterator in + reverse order. + """ + _lists = self._lists + + if min_pos > max_pos: + return iter(()) + elif min_pos == max_pos and not reverse: + return iter(_lists[min_pos][min_idx:max_idx]) + elif min_pos == max_pos and reverse: + return reversed(_lists[min_pos][min_idx:max_idx]) + elif min_pos + 1 == max_pos and not reverse: + return chain(_lists[min_pos][min_idx:], _lists[max_pos][:max_idx]) + elif min_pos + 1 == max_pos and reverse: + return chain( + reversed(_lists[max_pos][:max_idx]), + reversed(_lists[min_pos][min_idx:]), + ) + elif not reverse: + return chain( + _lists[min_pos][min_idx:], + chain.from_iterable(_lists[(min_pos + 1):max_pos]), + _lists[max_pos][:max_idx], + ) + else: + temp = map(reversed, reversed(_lists[(min_pos + 1):max_pos])) + return chain( + reversed(_lists[max_pos][:max_idx]), + chain.from_iterable(temp), + reversed(_lists[min_pos][min_idx:]), + ) + + def irange(self, minimum=None, maximum=None, inclusive=(True, True), + reverse=False): + """ + Create an iterator of values between `minimum` and `maximum`. + + `inclusive` is a pair of booleans that indicates whether the minimum + and maximum ought to be included in the range, respectively. The + default is (True, True) such that the range is inclusive of both + minimum and maximum. + + Both `minimum` and `maximum` default to `None` which is automatically + inclusive of the start and end of the list, respectively. + + When `reverse` is `True` the values are yielded from the iterator in + reverse order; `reverse` defaults to `False`. + """ + _maxes = self._maxes + + if not _maxes: + return iter(()) + + _lists = self._lists + + # Calculate the minimum (pos, idx) pair. By default this location + # will be inclusive in our calculation. + + if minimum is None: + min_pos = 0 + min_idx = 0 + else: + if inclusive[0]: + min_pos = bisect_left(_maxes, minimum) + + if min_pos == len(_maxes): + return iter(()) + + min_idx = bisect_left(_lists[min_pos], minimum) + else: + min_pos = bisect_right(_maxes, minimum) + + if min_pos == len(_maxes): + return iter(()) + + min_idx = bisect_right(_lists[min_pos], minimum) + + # Calculate the maximum (pos, idx) pair. By default this location + # will be exclusive in our calculation. 
+ + if maximum is None: + max_pos = len(_maxes) - 1 + max_idx = len(_lists[max_pos]) + else: + if inclusive[1]: + max_pos = bisect_right(_maxes, maximum) + + if max_pos == len(_maxes): + max_pos -= 1 + max_idx = len(_lists[max_pos]) + else: + max_idx = bisect_right(_lists[max_pos], maximum) + else: + max_pos = bisect_left(_maxes, maximum) + + if max_pos == len(_maxes): + max_pos -= 1 + max_idx = len(_lists[max_pos]) + else: + max_idx = bisect_left(_lists[max_pos], maximum) + + return self._islice(min_pos, min_idx, max_pos, max_idx, reverse) + + def __len__(self): + """Return the number of elements in the list.""" + return self._len + + def bisect_left(self, val): + """ + Similar to the *bisect* module in the standard library, this returns an + appropriate index to insert *val*. If *val* is already present, the + insertion point will be before (to the left of) any existing entries. + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos = bisect_left(_maxes, val) + + if pos == len(_maxes): + return self._len + + idx = bisect_left(self._lists[pos], val) + + return self._loc(pos, idx) + + def bisect_right(self, val): + """ + Same as *bisect_left*, but if *val* is already present, the insertion + point will be after (to the right of) any existing entries. + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos = bisect_right(_maxes, val) + + if pos == len(_maxes): + return self._len + + idx = bisect_right(self._lists[pos], val) + + return self._loc(pos, idx) + + bisect = bisect_right + _bisect_right = bisect_right + + def count(self, val): + """Return the number of occurrences of *val* in the list.""" + _maxes = self._maxes + + if not _maxes: + return 0 + + pos_left = bisect_left(_maxes, val) + + if pos_left == len(_maxes): + return 0 + + _lists = self._lists + idx_left = bisect_left(_lists[pos_left], val) + pos_right = bisect_right(_maxes, val) + + if pos_right == len(_maxes): + return self._len - self._loc(pos_left, idx_left) + + idx_right = bisect_right(_lists[pos_right], val) + + if pos_left == pos_right: + return idx_right - idx_left + + right = self._loc(pos_right, idx_right) + left = self._loc(pos_left, idx_left) + + return right - left + + def copy(self): + """Return a shallow copy of the sorted list.""" + return self.__class__(self, load=self._load) + + __copy__ = copy + + def append(self, val): + """ + Append the element *val* to the list. Raises a ValueError if the *val* + would violate the sort order. + """ + _lists = self._lists + _maxes = self._maxes + + if not _maxes: + _maxes.append(val) + _lists.append([val]) + self._len = 1 + return + + pos = len(_lists) - 1 + + if val < _lists[pos][-1]: + msg = '{0} not in sort order at index {1}'.format(repr(val), self._len) + raise ValueError(msg) + + _maxes[pos] = val + _lists[pos].append(val) + self._len += 1 + self._expand(pos) + + def extend(self, values): + """ + Extend the list by appending all elements from the *values*. Raises a + ValueError if the sort order would be violated. 
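+
+        A small illustrative example; the bad sequence is rejected before
+        the list is mutated:
+
+        >>> sl = SortedList([1, 2, 3])
+        >>> sl.extend([3, 4, 5])
+        >>> len(sl)
+        6
+        >>> sl.extend([2, 1])
+        Traceback (most recent call last):
+          ...
+        ValueError: given sequence not in sort order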
+ """ + _lists = self._lists + _maxes = self._maxes + _load = self._load + + if not isinstance(values, list): + values = list(values) + + if not values: + return + + if any(values[pos - 1] > values[pos] + for pos in range(1, len(values))): + raise ValueError('given sequence not in sort order') + + offset = 0 + + if _maxes: + if values[0] < _lists[-1][-1]: + msg = '{0} not in sort order at index {1}'.format(repr(values[0]), self._len) + raise ValueError(msg) + + if len(_lists[-1]) < self._half: + _lists[-1].extend(values[:_load]) + _maxes[-1] = _lists[-1][-1] + offset = _load + + len_lists = len(_lists) + + for idx in range(offset, len(values), _load): + _lists.append(values[idx:(idx + _load)]) + _maxes.append(_lists[-1][-1]) + + _index = self._index + + if len_lists == len(_lists): + len_index = len(_index) + if len_index > 0: + len_values = len(values) + child = len_index - 1 + while child: + _index[child] += len_values + child = (child - 1) >> 1 + _index[0] += len_values + else: + del _index[:] + + self._len += len(values) + + def insert(self, idx, val): + """ + Insert the element *val* into the list at *idx*. Raises a ValueError if + the *val* at *idx* would violate the sort order. + """ + _len = self._len + _lists = self._lists + _maxes = self._maxes + + if idx < 0: + idx += _len + if idx < 0: + idx = 0 + if idx > _len: + idx = _len + + if not _maxes: + # The idx must be zero by the inequalities above. + _maxes.append(val) + _lists.append([val]) + self._len = 1 + return + + if not idx: + if val > _lists[0][0]: + msg = '{0} not in sort order at index {1}'.format(repr(val), 0) + raise ValueError(msg) + else: + _lists[0].insert(0, val) + self._expand(0) + self._len += 1 + return + + if idx == _len: + pos = len(_lists) - 1 + if _lists[pos][-1] > val: + msg = '{0} not in sort order at index {1}'.format(repr(val), _len) + raise ValueError(msg) + else: + _lists[pos].append(val) + _maxes[pos] = _lists[pos][-1] + self._expand(pos) + self._len += 1 + return + + pos, idx = self._pos(idx) + idx_before = idx - 1 + if idx_before < 0: + pos_before = pos - 1 + idx_before = len(_lists[pos_before]) - 1 + else: + pos_before = pos + + before = _lists[pos_before][idx_before] + if before <= val <= _lists[pos][idx]: + _lists[pos].insert(idx, val) + self._expand(pos) + self._len += 1 + else: + msg = '{0} not in sort order at index {1}'.format(repr(val), idx) + raise ValueError(msg) + + def pop(self, idx=-1): + """ + Remove and return item at *idx* (default last). Raises IndexError if + list is empty or index is out of range. Negative indices are supported, + as for slice indices. + """ + if not self._len: + raise IndexError('pop index out of range') + + _lists = self._lists + + if idx == 0: + val = _lists[0][0] + self._delete(0, 0) + return val + + if idx == -1: + pos = len(_lists) - 1 + loc = len(_lists[pos]) - 1 + val = _lists[pos][loc] + self._delete(pos, loc) + return val + + if 0 <= idx < len(_lists[0]): + val = _lists[0][idx] + self._delete(0, idx) + return val + + len_last = len(_lists[-1]) + + if -len_last < idx < 0: + pos = len(_lists) - 1 + loc = len_last + idx + val = _lists[pos][loc] + self._delete(pos, loc) + return val + + pos, idx = self._pos(idx) + val = _lists[pos][idx] + self._delete(pos, idx) + + return val + + def index(self, val, start=None, stop=None): + """ + Return the smallest *k* such that L[k] == val and i <= k < j`. Raises + ValueError if *val* is not present. *stop* defaults to the end of the + list. *start* defaults to the beginning. 
Negative indices are supported, + as for slice indices. + """ + # pylint: disable=arguments-differ + _len = self._len + + if not _len: + raise ValueError('{0} is not in list'.format(repr(val))) + + if start is None: + start = 0 + if start < 0: + start += _len + if start < 0: + start = 0 + + if stop is None: + stop = _len + if stop < 0: + stop += _len + if stop > _len: + stop = _len + + if stop <= start: + raise ValueError('{0} is not in list'.format(repr(val))) + + _maxes = self._maxes + pos_left = bisect_left(_maxes, val) + + if pos_left == len(_maxes): + raise ValueError('{0} is not in list'.format(repr(val))) + + _lists = self._lists + idx_left = bisect_left(_lists[pos_left], val) + + if _lists[pos_left][idx_left] != val: + raise ValueError('{0} is not in list'.format(repr(val))) + + stop -= 1 + left = self._loc(pos_left, idx_left) + + if start <= left: + if left <= stop: + return left + else: + right = self._bisect_right(val) - 1 + + if start <= right: + return start + + raise ValueError('{0} is not in list'.format(repr(val))) + + def __add__(self, that): + """ + Return a new sorted list containing all the elements in *self* and + *that*. Elements in *that* do not need to be properly ordered with + respect to *self*. + """ + values = reduce(iadd, self._lists, []) + values.extend(that) + return self.__class__(values, load=self._load) + + def __iadd__(self, that): + """ + Update *self* to include all values in *that*. Elements in *that* do not + need to be properly ordered with respect to *self*. + """ + self._update(that) + return self + + def __mul__(self, that): + """ + Return a new sorted list containing *that* shallow copies of each item + in SortedList. + """ + values = reduce(iadd, self._lists, []) * that + return self.__class__(values, load=self._load) + + def __imul__(self, that): + """ + Increase the length of the list by appending *that* shallow copies of + each item. + """ + values = reduce(iadd, self._lists, []) * that + self._clear() + self._update(values) + return self + + def _make_cmp(self, seq_op, doc): + "Make comparator method." + def comparer(self, that): + "Compare method for sorted list and sequence." + # pylint: disable=protected-access + if not isinstance(that, Sequence): + return NotImplemented + + self_len = self._len + len_that = len(that) + + if self_len != len_that: + if seq_op is op.eq: + return False + if seq_op is op.ne: + return True + + for alpha, beta in zip(self, that): + if alpha != beta: + return seq_op(alpha, beta) + + return seq_op(self_len, len_that) + + return comparer + + __eq__ = _make_cmp(None, op.eq, 'equal to') + __ne__ = _make_cmp(None, op.ne, 'not equal to') + __lt__ = _make_cmp(None, op.lt, 'less than') + __gt__ = _make_cmp(None, op.gt, 'greater than') + __le__ = _make_cmp(None, op.le, 'less than or equal to') + __ge__ = _make_cmp(None, op.ge, 'greater than or equal to') + + @recursive_repr + def __repr__(self): + """Return string representation of sequence.""" + temp = '{0}({1}, load={2})' + return temp.format( + self.__class__.__name__, + repr(list(self)), + repr(self._load) + ) + + def _check(self): + try: + # Check load parameters. + + assert self._load >= 4 + assert self._half == (self._load >> 1) + assert self._twice == (self._load * 2) + + # Check empty sorted list case. + + if self._maxes == []: + assert self._lists == [] + return + + assert len(self._maxes) > 0 and len(self._lists) > 0 + + # Check all sublists are sorted. 
+ + assert all(sublist[pos - 1] <= sublist[pos] + for sublist in self._lists + for pos in range(1, len(sublist))) + + # Check beginning/end of sublists are sorted. + + for pos in range(1, len(self._lists)): + assert self._lists[pos - 1][-1] <= self._lists[pos][0] + + # Check length of _maxes and _lists match. + + assert len(self._maxes) == len(self._lists) + + # Check _maxes is a map of _lists. + + assert all(self._maxes[pos] == self._lists[pos][-1] + for pos in range(len(self._maxes))) + + # Check load level is less than _twice. + + assert all(len(sublist) <= self._twice for sublist in self._lists) + + # Check load level is greater than _half for all + # but the last sublist. + + assert all(len(self._lists[pos]) >= self._half + for pos in range(0, len(self._lists) - 1)) + + # Check length. + + assert self._len == sum(len(sublist) for sublist in self._lists) + + # Check index. + + if len(self._index): + assert len(self._index) == self._offset + len(self._lists) + assert self._len == self._index[0] + + def test_offset_pos(pos): + "Test positional indexing offset." + from_index = self._index[self._offset + pos] + return from_index == len(self._lists[pos]) + + assert all(test_offset_pos(pos) + for pos in range(len(self._lists))) + + for pos in range(self._offset): + child = (pos << 1) + 1 + if child >= len(self._index): + assert self._index[pos] == 0 + elif child + 1 == len(self._index): + assert self._index[pos] == self._index[child] + else: + child_sum = self._index[child] + self._index[child + 1] + assert self._index[pos] == child_sum + + except: + import sys + import traceback + + traceback.print_exc(file=sys.stdout) + + print('len', self._len) + print('load', self._load, self._half, self._twice) + print('offset', self._offset) + print('len_index', len(self._index)) + print('index', self._index) + print('len_maxes', len(self._maxes)) + print('maxes', self._maxes) + print('len_lists', len(self._lists)) + print('lists', self._lists) + + raise + +def identity(value): + "Identity function." + return value + +class SortedListWithKey(SortedList): + """ + SortedListWithKey provides most of the same methods as a list but keeps + the items in sorted order. + """ + + def __init__(self, iterable=None, key=identity, load=1000): + """SortedListWithKey provides most of the same methods as list but keeps the + items in sorted order. + + An optional *iterable* provides an initial series of items to populate + the SortedListWithKey. + + An optional *key* argument defines a callable that, like the `key` + argument to Python's `sorted` function, extracts a comparison key from + each element. The default is the identity function. + + An optional *load* specifies the load-factor of the list. The default + load factor of '1000' works well for lists from tens to tens of millions + of elements. Good practice is to use a value that is the cube root of + the list size. With billions of elements, the best load factor depends + on your usage. It's best to leave the load factor at the default until + you start benchmarking. 
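+
+        A small illustrative example using ``abs`` as the key:
+
+        >>> swk = SortedListWithKey([1, -3, 2], key=abs)
+        >>> list(swk)
+        [1, 2, -3]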
+ + """ + # pylint: disable=super-init-not-called + self._len = 0 + self._lists = [] + self._keys = [] + self._maxes = [] + self._index = [] + self._key = key + self._load = load + self._twice = load * 2 + self._half = load >> 1 + self._offset = 0 + + if iterable is not None: + self._update(iterable) + + def __new__(cls, iterable=None, key=identity, load=1000): + return object.__new__(cls) + + def clear(self): + """Remove all the elements from the list.""" + self._len = 0 + del self._lists[:] + del self._keys[:] + del self._maxes[:] + del self._index[:] + + _clear = clear + + def add(self, val): + """Add the element *val* to the list.""" + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + + key = self._key(val) + + if _maxes: + pos = bisect_right(_maxes, key) + + if pos == len(_maxes): + pos -= 1 + _lists[pos].append(val) + _keys[pos].append(key) + _maxes[pos] = key + else: + idx = bisect_right(_keys[pos], key) + _lists[pos].insert(idx, val) + _keys[pos].insert(idx, key) + + self._expand(pos) + else: + _lists.append([val]) + _keys.append([key]) + _maxes.append(key) + + self._len += 1 + + def _expand(self, pos): + """Splits sublists that are more than double the load level. + + Updates the index when the sublist length is less than double the load + level. This requires incrementing the nodes in a traversal from the + leaf node to the root. For an example traversal see self._loc. + + """ + _lists = self._lists + _keys = self._keys + _index = self._index + + if len(_keys[pos]) > self._twice: + _maxes = self._maxes + _load = self._load + + _lists_pos = _lists[pos] + _keys_pos = _keys[pos] + half = _lists_pos[_load:] + half_keys = _keys_pos[_load:] + del _lists_pos[_load:] + del _keys_pos[_load:] + _maxes[pos] = _keys_pos[-1] + + _lists.insert(pos + 1, half) + _keys.insert(pos + 1, half_keys) + _maxes.insert(pos + 1, half_keys[-1]) + + del _index[:] + else: + if _index: + child = self._offset + pos + while child: + _index[child] += 1 + child = (child - 1) >> 1 + _index[0] += 1 + + def update(self, iterable): + """Update the list by adding all elements from *iterable*.""" + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + values = sorted(iterable, key=self._key) + + if _maxes: + if len(values) * 4 >= self._len: + values.extend(chain.from_iterable(_lists)) + values.sort(key=self._key) + self._clear() + else: + _add = self.add + for val in values: + _add(val) + return + + _load = self._load + _lists.extend(values[pos:(pos + _load)] + for pos in range(0, len(values), _load)) + _keys.extend(list(map(self._key, _list)) for _list in _lists) + _maxes.extend(sublist[-1] for sublist in _keys) + self._len = len(values) + del self._index[:] + + _update = update + + def __contains__(self, val): + """Return True if and only if *val* is an element in the list.""" + _maxes = self._maxes + + if not _maxes: + return False + + key = self._key(val) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + return False + + _lists = self._lists + _keys = self._keys + + idx = bisect_left(_keys[pos], key) + + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + return False + if _lists[pos][idx] == val: + return True + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + return False + len_sublist = len(_keys[pos]) + idx = 0 + + def discard(self, val): + """ + Remove the first occurrence of *val*. + + If *val* is not a member, does nothing. 
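+
+        A small illustrative example using ``abs`` as the key:
+
+        >>> swk = SortedListWithKey([1, -2, 3], key=abs)
+        >>> swk.discard(-2)
+        >>> swk.discard(100)
+        >>> len(swk)
+        2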
+ """ + _maxes = self._maxes + + if not _maxes: + return + + key = self._key(val) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + return + + _lists = self._lists + _keys = self._keys + idx = bisect_left(_keys[pos], key) + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + return + if _lists[pos][idx] == val: + self._delete(pos, idx) + return + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + return + len_sublist = len(_keys[pos]) + idx = 0 + + def remove(self, val): + """ + Remove first occurrence of *val*. + + Raises ValueError if *val* is not present. + """ + _maxes = self._maxes + + if not _maxes: + raise ValueError('{0} not in list'.format(repr(val))) + + key = self._key(val) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + raise ValueError('{0} not in list'.format(repr(val))) + + _lists = self._lists + _keys = self._keys + idx = bisect_left(_keys[pos], key) + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + raise ValueError('{0} not in list'.format(repr(val))) + if _lists[pos][idx] == val: + self._delete(pos, idx) + return + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + raise ValueError('{0} not in list'.format(repr(val))) + len_sublist = len(_keys[pos]) + idx = 0 + + def _delete(self, pos, idx): + """ + Delete the item at the given (pos, idx). + + Combines lists that are less than half the load level. + + Updates the index when the sublist length is more than half the load + level. This requires decrementing the nodes in a traversal from the leaf + node to the root. For an example traversal see self._loc. + """ + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + _index = self._index + keys_pos = _keys[pos] + lists_pos = _lists[pos] + + del keys_pos[idx] + del lists_pos[idx] + self._len -= 1 + + len_keys_pos = len(keys_pos) + + if len_keys_pos > self._half: + + _maxes[pos] = keys_pos[-1] + + if _index: + child = self._offset + pos + while child > 0: + _index[child] -= 1 + child = (child - 1) >> 1 + _index[0] -= 1 + + elif len(_keys) > 1: + + if not pos: + pos += 1 + + prev = pos - 1 + _keys[prev].extend(_keys[pos]) + _lists[prev].extend(_lists[pos]) + _maxes[prev] = _keys[prev][-1] + + del _lists[pos] + del _keys[pos] + del _maxes[pos] + del _index[:] + + self._expand(prev) + + elif len_keys_pos: + + _maxes[pos] = keys_pos[-1] + + else: + + del _lists[pos] + del _keys[pos] + del _maxes[pos] + del _index[:] + + def _check_order(self, idx, key, val): + # pylint: disable=arguments-differ + _len = self._len + _keys = self._keys + + pos, loc = self._pos(idx) + + if idx < 0: + idx += _len + + # Check that the inserted value is not less than the + # previous value. + + if idx > 0: + idx_prev = loc - 1 + pos_prev = pos + + if idx_prev < 0: + pos_prev -= 1 + idx_prev = len(_keys[pos_prev]) - 1 + + if _keys[pos_prev][idx_prev] > key: + msg = '{0} not in sort order at index {1}'.format(repr(val), idx) + raise ValueError(msg) + + # Check that the inserted value is not greater than + # the previous value. + + if idx < (_len - 1): + idx_next = loc + 1 + pos_next = pos + + if idx_next == len(_keys[pos_next]): + pos_next += 1 + idx_next = 0 + + if _keys[pos_next][idx_next] < key: + msg = '{0} not in sort order at index {1}'.format(repr(val), idx) + raise ValueError(msg) + + def __setitem__(self, index, value): + """ + Replace the item at position *index* with *value*. + + Supports slice notation. 
Raises a :exc:`ValueError` if the sort order + would be violated. When used with a slice and iterable, the + :exc:`ValueError` is raised before the list is mutated if the sort order + would be violated by the operation. + """ + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + _check_order = self._check_order + _pos = self._pos + + if isinstance(index, slice): + start, stop, step = index.indices(self._len) + indices = range(start, stop, step) + + if step != 1: + if not hasattr(value, '__len__'): + value = list(value) + + indices = list(indices) + + if len(value) != len(indices): + raise ValueError( + 'attempt to assign sequence of size {0}' + ' to extended slice of size {1}' + .format(len(value), len(indices))) + + # Keep a log of values that are set so that we can + # roll back changes if ordering is violated. + + log = [] + _append = log.append + + for idx, val in zip(indices, value): + pos, loc = _pos(idx) + key = self._key(val) + _append((idx, _keys[pos][loc], key, _lists[pos][loc], val)) + _keys[pos][loc] = key + _lists[pos][loc] = val + if len(_keys[pos]) == (loc + 1): + _maxes[pos] = key + + try: + # Validate ordering of new values. + + for idx, oldkey, newkey, oldval, newval in log: + _check_order(idx, newkey, newval) + + except ValueError: + + # Roll back changes from log. + + for idx, oldkey, newkey, oldval, newval in log: + pos, loc = _pos(idx) + _keys[pos][loc] = oldkey + _lists[pos][loc] = oldval + if len(_keys[pos]) == (loc + 1): + _maxes[pos] = oldkey + + raise + else: + if start == 0 and stop == self._len: + self._clear() + return self._update(value) + + # Test ordering using indexing. If the given value + # isn't a Sequence, convert it to a tuple. + + if not isinstance(value, Sequence): + value = tuple(value) # pylint: disable=redefined-variable-type + + # Check that the given values are ordered properly. + + keys = tuple(map(self._key, value)) + iterator = range(1, len(keys)) + + if not all(keys[pos - 1] <= keys[pos] for pos in iterator): + raise ValueError('given sequence not in sort order') + + # Check ordering in context of sorted list. + + if not start or not len(value): + # Nothing to check on the lhs. + pass + else: + pos, loc = _pos(start - 1) + if _keys[pos][loc] > keys[0]: + msg = '{0} not in sort order at index {1}'.format(repr(value[0]), start) + raise ValueError(msg) + + if stop == len(self) or not len(value): + # Nothing to check on the rhs. + pass + else: + # "stop" is exclusive so we don't need + # to add one for the index. + pos, loc = _pos(stop) + if _keys[pos][loc] < keys[-1]: + msg = '{0} not in sort order at index {1}'.format(repr(value[-1]), stop) + raise ValueError(msg) + + # Delete the existing values. + + self._delitem(index) + + # Insert the new values. + + _insert = self.insert + for idx, val in enumerate(value): + _insert(start + idx, val) + else: + pos, loc = _pos(index) + key = self._key(value) + _check_order(index, key, value) + _lists[pos][loc] = value + _keys[pos][loc] = key + if len(_lists[pos]) == (loc + 1): + _maxes[pos] = key + + def irange(self, minimum=None, maximum=None, inclusive=(True, True), + reverse=False): + """ + Create an iterator of values between `minimum` and `maximum`. + + `inclusive` is a pair of booleans that indicates whether the minimum + and maximum ought to be included in the range, respectively. The + default is (True, True) such that the range is inclusive of both + minimum and maximum. 
+ + Both `minimum` and `maximum` default to `None` which is automatically + inclusive of the start and end of the list, respectively. + + When `reverse` is `True` the values are yielded from the iterator in + reverse order; `reverse` defaults to `False`. + """ + minimum = self._key(minimum) if minimum is not None else None + maximum = self._key(maximum) if maximum is not None else None + return self._irange_key( + min_key=minimum, max_key=maximum, + inclusive=inclusive, reverse=reverse, + ) + + def irange_key(self, min_key=None, max_key=None, inclusive=(True, True), + reverse=False): + """ + Create an iterator of values between `min_key` and `max_key`. + + `inclusive` is a pair of booleans that indicates whether the min_key + and max_key ought to be included in the range, respectively. The + default is (True, True) such that the range is inclusive of both + `min_key` and `max_key`. + + Both `min_key` and `max_key` default to `None` which is automatically + inclusive of the start and end of the list, respectively. + + When `reverse` is `True` the values are yielded from the iterator in + reverse order; `reverse` defaults to `False`. + """ + _maxes = self._maxes + + if not _maxes: + return iter(()) + + _keys = self._keys + + # Calculate the minimum (pos, idx) pair. By default this location + # will be inclusive in our calculation. + + if min_key is None: + min_pos = 0 + min_idx = 0 + else: + if inclusive[0]: + min_pos = bisect_left(_maxes, min_key) + + if min_pos == len(_maxes): + return iter(()) + + min_idx = bisect_left(_keys[min_pos], min_key) + else: + min_pos = bisect_right(_maxes, min_key) + + if min_pos == len(_maxes): + return iter(()) + + min_idx = bisect_right(_keys[min_pos], min_key) + + # Calculate the maximum (pos, idx) pair. By default this location + # will be exclusive in our calculation. + + if max_key is None: + max_pos = len(_maxes) - 1 + max_idx = len(_keys[max_pos]) + else: + if inclusive[1]: + max_pos = bisect_right(_maxes, max_key) + + if max_pos == len(_maxes): + max_pos -= 1 + max_idx = len(_keys[max_pos]) + else: + max_idx = bisect_right(_keys[max_pos], max_key) + else: + max_pos = bisect_left(_maxes, max_key) + + if max_pos == len(_maxes): + max_pos -= 1 + max_idx = len(_keys[max_pos]) + else: + max_idx = bisect_left(_keys[max_pos], max_key) + + return self._islice(min_pos, min_idx, max_pos, max_idx, reverse) + + _irange_key = irange_key + + def bisect_left(self, val): + """ + Similar to the *bisect* module in the standard library, this returns an + appropriate index to insert *val*. If *val* is already present, the + insertion point will be before (to the left of) any existing entries. + """ + return self._bisect_key_left(self._key(val)) + + def bisect_right(self, val): + """ + Same as *bisect_left*, but if *val* is already present, the insertion + point will be after (to the right of) any existing entries. + """ + return self._bisect_key_right(self._key(val)) + + bisect = bisect_right + + def bisect_key_left(self, key): + """ + Similar to the *bisect* module in the standard library, this returns an + appropriate index to insert a value with a given *key*. If values with + *key* are already present, the insertion point will be before (to the + left of) any existing entries. 
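+
+        A small illustrative example using ``len`` as the key:
+
+        >>> swk = SortedListWithKey(['a', 'bb', 'ccc'], key=len)
+        >>> swk.bisect_key_left(2)
+        1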
+ """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + return self._len + + idx = bisect_left(self._keys[pos], key) + + return self._loc(pos, idx) + + _bisect_key_left = bisect_key_left + + def bisect_key_right(self, key): + """ + Same as *bisect_key_left*, but if *key* is already present, the insertion + point will be after (to the right of) any existing entries. + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos = bisect_right(_maxes, key) + + if pos == len(_maxes): + return self._len + + idx = bisect_right(self._keys[pos], key) + + return self._loc(pos, idx) + + bisect_key = bisect_key_right + _bisect_key_right = bisect_key_right + + def count(self, val): + """Return the number of occurrences of *val* in the list.""" + _maxes = self._maxes + + if not _maxes: + return 0 + + key = self._key(val) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + return 0 + + _lists = self._lists + _keys = self._keys + idx = bisect_left(_keys[pos], key) + total = 0 + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + return total + if _lists[pos][idx] == val: + total += 1 + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + return total + len_sublist = len(_keys[pos]) + idx = 0 + + def copy(self): + """Return a shallow copy of the sorted list.""" + return self.__class__(self, key=self._key, load=self._load) + + __copy__ = copy + + def append(self, val): + """ + Append the element *val* to the list. Raises a ValueError if the *val* + would violate the sort order. + """ + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + key = self._key(val) + + if not _maxes: + _maxes.append(key) + _keys.append([key]) + _lists.append([val]) + self._len = 1 + return + + pos = len(_keys) - 1 + + if key < _keys[pos][-1]: + msg = '{0} not in sort order at index {1}'.format(repr(val), self._len) + raise ValueError(msg) + + _lists[pos].append(val) + _keys[pos].append(key) + _maxes[pos] = key + self._len += 1 + self._expand(pos) + + def extend(self, values): + """ + Extend the list by appending all elements from the *values*. Raises a + ValueError if the sort order would be violated. + """ + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + _load = self._load + + if not isinstance(values, list): + values = list(values) + + keys = list(map(self._key, values)) + + if any(keys[pos - 1] > keys[pos] + for pos in range(1, len(keys))): + raise ValueError('given sequence not in sort order') + + offset = 0 + + if _maxes: + if keys[0] < _keys[-1][-1]: + msg = '{0} not in sort order at index {1}'.format(repr(values[0]), self._len) + raise ValueError(msg) + + if len(_keys[-1]) < self._half: + _lists[-1].extend(values[:_load]) + _keys[-1].extend(keys[:_load]) + _maxes[-1] = _keys[-1][-1] + offset = _load + + len_keys = len(_keys) + + for idx in range(offset, len(keys), _load): + _lists.append(values[idx:(idx + _load)]) + _keys.append(keys[idx:(idx + _load)]) + _maxes.append(_keys[-1][-1]) + + _index = self._index + + if len_keys == len(_keys): + len_index = len(_index) + if len_index > 0: + len_values = len(values) + child = len_index - 1 + while child: + _index[child] += len_values + child = (child - 1) >> 1 + _index[0] += len_values + else: + del _index[:] + + self._len += len(values) + + def insert(self, idx, val): + """ + Insert the element *val* into the list at *idx*. Raises a ValueError if + the *val* at *idx* would violate the sort order. 
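+
+        A small illustrative example using ``len`` as the key:
+
+        >>> swk = SortedListWithKey(['a', 'ccc'], key=len)
+        >>> swk.insert(1, 'bb')
+        >>> list(swk)
+        ['a', 'bb', 'ccc']
+        >>> swk.insert(0, 'dddd')
+        Traceback (most recent call last):
+          ...
+        ValueError: 'dddd' not in sort order at index 0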
+ """ + _len = self._len + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + + if idx < 0: + idx += _len + if idx < 0: + idx = 0 + if idx > _len: + idx = _len + + key = self._key(val) + + if not _maxes: + self._len = 1 + _lists.append([val]) + _keys.append([key]) + _maxes.append(key) + return + + if not idx: + if key > _keys[0][0]: + msg = '{0} not in sort order at index {1}'.format(repr(val), 0) + raise ValueError(msg) + else: + self._len += 1 + _lists[0].insert(0, val) + _keys[0].insert(0, key) + self._expand(0) + return + + if idx == _len: + pos = len(_keys) - 1 + if _keys[pos][-1] > key: + msg = '{0} not in sort order at index {1}'.format(repr(val), _len) + raise ValueError(msg) + else: + self._len += 1 + _lists[pos].append(val) + _keys[pos].append(key) + _maxes[pos] = _keys[pos][-1] + self._expand(pos) + return + + pos, idx = self._pos(idx) + idx_before = idx - 1 + if idx_before < 0: + pos_before = pos - 1 + idx_before = len(_keys[pos_before]) - 1 + else: + pos_before = pos + + before = _keys[pos_before][idx_before] + if before <= key <= _keys[pos][idx]: + self._len += 1 + _lists[pos].insert(idx, val) + _keys[pos].insert(idx, key) + self._expand(pos) + else: + msg = '{0} not in sort order at index {1}'.format(repr(val), idx) + raise ValueError(msg) + + def index(self, val, start=None, stop=None): + """ + Return the smallest *k* such that L[k] == val and i <= k < j`. Raises + ValueError if *val* is not present. *stop* defaults to the end of the + list. *start* defaults to the beginning. Negative indices are supported, + as for slice indices. + """ + _len = self._len + + if not _len: + raise ValueError('{0} is not in list'.format(repr(val))) + + if start is None: + start = 0 + if start < 0: + start += _len + if start < 0: + start = 0 + + if stop is None: + stop = _len + if stop < 0: + stop += _len + if stop > _len: + stop = _len + + if stop <= start: + raise ValueError('{0} is not in list'.format(repr(val))) + + _maxes = self._maxes + key = self._key(val) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + raise ValueError('{0} is not in list'.format(repr(val))) + + stop -= 1 + _lists = self._lists + _keys = self._keys + idx = bisect_left(_keys[pos], key) + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + raise ValueError('{0} is not in list'.format(repr(val))) + if _lists[pos][idx] == val: + loc = self._loc(pos, idx) + if start <= loc <= stop: + return loc + elif loc > stop: + break + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + raise ValueError('{0} is not in list'.format(repr(val))) + len_sublist = len(_keys[pos]) + idx = 0 + + raise ValueError('{0} is not in list'.format(repr(val))) + + def __add__(self, that): + """ + Return a new sorted list containing all the elements in *self* and + *that*. Elements in *that* do not need to be properly ordered with + respect to *self*. + """ + values = reduce(iadd, self._lists, []) + values.extend(that) + return self.__class__(values, key=self._key, load=self._load) + + def __mul__(self, that): + """ + Return a new sorted list containing *that* shallow copies of each item + in SortedListWithKey. + """ + values = reduce(iadd, self._lists, []) * that + return self.__class__(values, key=self._key, load=self._load) + + def __imul__(self, that): + """ + Increase the length of the list by appending *that* shallow copies of + each item. 
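+
+        A small illustrative example using ``abs`` as the key:
+
+        >>> swk = SortedListWithKey([1, 2], key=abs)
+        >>> swk *= 2
+        >>> list(swk)
+        [1, 1, 2, 2]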
+ """ + values = reduce(iadd, self._lists, []) * that + self._clear() + self._update(values) + return self + + @recursive_repr + def __repr__(self): + """Return string representation of sequence.""" + temp = '{0}({1}, key={2}, load={3})' + return temp.format( + self.__class__.__name__, + repr(list(self)), + repr(self._key), + repr(self._load) + ) + + def _check(self): + try: + # Check load parameters. + + assert self._load >= 4 + assert self._half == (self._load >> 1) + assert self._twice == (self._load * 2) + + # Check empty sorted list case. + + if self._maxes == []: + assert self._keys == [] + assert self._lists == [] + return + + assert len(self._maxes) > 0 and len(self._keys) > 0 and len(self._lists) > 0 + + # Check all sublists are sorted. + + assert all(sublist[pos - 1] <= sublist[pos] + for sublist in self._keys + for pos in range(1, len(sublist))) + + # Check beginning/end of sublists are sorted. + + for pos in range(1, len(self._keys)): + assert self._keys[pos - 1][-1] <= self._keys[pos][0] + + # Check length of _maxes and _lists match. + + assert len(self._maxes) == len(self._lists) == len(self._keys) + + # Check _keys matches _key mapped to _lists. + + assert all(len(val_list) == len(key_list) + for val_list, key_list in zip(self._lists, self._keys)) + assert all(self._key(val) == key for val, key in + zip((_val for _val_list in self._lists for _val in _val_list), + (_key for _key_list in self._keys for _key in _key_list))) + + # Check _maxes is a map of _keys. + + assert all(self._maxes[pos] == self._keys[pos][-1] + for pos in range(len(self._maxes))) + + # Check load level is less than _twice. + + assert all(len(sublist) <= self._twice for sublist in self._lists) + + # Check load level is greater than _half for all + # but the last sublist. + + assert all(len(self._lists[pos]) >= self._half + for pos in range(0, len(self._lists) - 1)) + + # Check length. + + assert self._len == sum(len(sublist) for sublist in self._lists) + + # Check index. + + if len(self._index): + assert len(self._index) == self._offset + len(self._lists) + assert self._len == self._index[0] + + def test_offset_pos(pos): + "Test positional indexing offset." + from_index = self._index[self._offset + pos] + return from_index == len(self._lists[pos]) + + assert all(test_offset_pos(pos) + for pos in range(len(self._lists))) + + for pos in range(self._offset): + child = (pos << 1) + 1 + if self._index[pos] == 0: + assert child >= len(self._index) + elif child + 1 == len(self._index): + assert self._index[pos] == self._index[child] + else: + child_sum = self._index[child] + self._index[child + 1] + assert self._index[pos] == child_sum + + except: + import sys + import traceback + + traceback.print_exc(file=sys.stdout) + + print('len', self._len) + print('load', self._load, self._half, self._twice) + print('offset', self._offset) + print('len_index', len(self._index)) + print('index', self._index) + print('len_maxes', len(self._maxes)) + print('maxes', self._maxes) + print('len_keys', len(self._keys)) + print('keys', self._keys) + print('len_lists', len(self._lists)) + print('lists', self._lists) + + raise diff --git a/sortedcontainers/sortedcontainers/sortedset.py b/sortedcontainers/sortedcontainers/sortedset.py new file mode 100644 index 0000000000..36b92c826f --- /dev/null +++ b/sortedcontainers/sortedcontainers/sortedset.py @@ -0,0 +1,323 @@ +"""Sorted set implementation. 
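+
+A small illustrative example; duplicates collapse and items come back
+in sorted order:
+
+>>> ss = SortedSet([3, 1, 2, 2])
+>>> list(ss)
+[1, 2, 3]
+>>> ss[0]
+1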
+ +""" + +from collections import Set, MutableSet, Sequence +from itertools import chain +import operator as op + +from .sortedlist import SortedList, recursive_repr, SortedListWithKey + +class SortedSet(MutableSet, Sequence): + """ + A `SortedSet` provides the same methods as a `set`. Additionally, a + `SortedSet` maintains its items in sorted order, allowing the `SortedSet` to + be indexed. + + Unlike a `set`, a `SortedSet` requires items be hashable and comparable. + """ + def __init__(self, iterable=None, key=None, load=1000, _set=None): + """ + A `SortedSet` provides the same methods as a `set`. Additionally, a + `SortedSet` maintains its items in sorted order, allowing the + `SortedSet` to be indexed. + + An optional *iterable* provides an initial series of items to populate + the `SortedSet`. + + An optional *key* argument defines a callable that, like the `key` + argument to Python's `sorted` function, extracts a comparison key from + each set item. If no function is specified, the default compares the + set items directly. + + An optional *load* specifies the load-factor of the set. The default + load factor of '1000' works well for sets from tens to tens of millions + of elements. Good practice is to use a value that is the cube root of + the set size. With billions of elements, the best load factor depends + on your usage. It's best to leave the load factor at the default until + you start benchmarking. + """ + # pylint: disable=redefined-variable-type + self._key = key + self._load = load + + self._set = set() if _set is None else _set + + _set = self._set + self.isdisjoint = _set.isdisjoint + self.issubset = _set.issubset + self.issuperset = _set.issuperset + + if key is None: + self._list = SortedList(self._set, load=load) + else: + self._list = SortedListWithKey(self._set, key=key, load=load) + + _list = self._list + self.bisect_left = _list.bisect_left + self.bisect = _list.bisect + self.bisect_right = _list.bisect_right + self.index = _list.index + self.irange = _list.irange + self.islice = _list.islice + + if key is not None: + self.bisect_key_left = _list.bisect_key_left + self.bisect_key_right = _list.bisect_key_right + self.bisect_key = _list.bisect_key + self.irange_key = _list.irange_key + + if iterable is not None: + self._update(iterable) + + def __contains__(self, value): + """Return True if and only if *value* is an element in the set.""" + return value in self._set + + def __getitem__(self, index): + """ + Return the element at position *index*. + + Supports slice notation and negative indexes. + """ + return self._list[index] + + def __delitem__(self, index): + """ + Remove the element at position *index*. + + Supports slice notation and negative indexes. + """ + _set = self._set + _list = self._list + if isinstance(index, slice): + values = _list[index] + _set.difference_update(values) + else: + value = _list[index] + _set.remove(value) + del _list[index] + + def _make_cmp(self, set_op, doc): + "Make comparator method." + def comparer(self, that): + "Compare method for sorted set and set-like object." 
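+            # Compare the backing sets directly when *that* is another
+            # SortedSet; fall back to the Set ABC for other set-like
+            # objects; otherwise let Python try the reflected operation.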
+            # pylint: disable=protected-access
+            if isinstance(that, SortedSet):
+                return set_op(self._set, that._set)
+            elif isinstance(that, Set):
+                return set_op(self._set, that)
+            else:
+                return NotImplemented
+
+        return comparer
+
+    __eq__ = _make_cmp(None, op.eq, 'equal to')
+    __ne__ = _make_cmp(None, op.ne, 'not equal to')
+    __lt__ = _make_cmp(None, op.lt, 'a proper subset of')
+    __gt__ = _make_cmp(None, op.gt, 'a proper superset of')
+    __le__ = _make_cmp(None, op.le, 'a subset of')
+    __ge__ = _make_cmp(None, op.ge, 'a superset of')
+
+    def __len__(self):
+        """Return the number of elements in the set."""
+        return len(self._set)
+
+    def __iter__(self):
+        """
+        Return an iterator over the Set. Elements are iterated in their sorted
+        order.
+
+        Iterating the Set while adding or deleting values may raise a
+        `RuntimeError` or fail to iterate over all entries.
+        """
+        return iter(self._list)
+
+    def __reversed__(self):
+        """
+        Return an iterator over the Set. Elements are iterated in their reverse
+        sorted order.
+
+        Iterating the Set while adding or deleting values may raise a
+        `RuntimeError` or fail to iterate over all entries.
+        """
+        return reversed(self._list)
+
+    def add(self, value):
+        """Add the element *value* to the set."""
+        _set = self._set
+        if value not in _set:
+            _set.add(value)
+            self._list.add(value)
+
+    def clear(self):
+        """Remove all elements from the set."""
+        self._set.clear()
+        self._list.clear()
+
+    def copy(self):
+        """Create a shallow copy of the sorted set."""
+        return self.__class__(key=self._key, load=self._load, _set=set(self._set))
+
+    __copy__ = copy
+
+    def count(self, value):
+        """Return the number of occurrences of *value* in the set."""
+        return 1 if value in self._set else 0
+
+    def discard(self, value):
+        """
+        Remove the first occurrence of *value*. If *value* is not a member,
+        does nothing.
+        """
+        _set = self._set
+        if value in _set:
+            _set.remove(value)
+            self._list.discard(value)
+
+    def pop(self, index=-1):
+        """
+        Remove and return item at *index* (default last). Raises IndexError if
+        set is empty or index is out of range. Negative indexes are supported,
+        as for slice indices.
+        """
+        # pylint: disable=arguments-differ
+        value = self._list.pop(index)
+        self._set.remove(value)
+        return value
+
+    def remove(self, value):
+        """
+        Remove first occurrence of *value*. Raises ValueError if
+        *value* is not present.
+        """
+        self._set.remove(value)
+        self._list.remove(value)
+
+    def difference(self, *iterables):
+        """
+        Return a new set with elements in the set that are not in the
+        *iterables*.
+        """
+        diff = self._set.difference(*iterables)
+        new_set = self.__class__(key=self._key, load=self._load, _set=diff)
+        return new_set
+
+    __sub__ = difference
+    __rsub__ = __sub__
+
+    def difference_update(self, *iterables):
+        """
+        Update the set, removing elements found in any of the *iterables*.
+        """
+        _set = self._set
+        values = set(chain(*iterables))
+        if (4 * len(values)) > len(_set):
+            _list = self._list
+            _set.difference_update(values)
+            _list.clear()
+            _list.update(_set)
+        else:
+            _discard = self.discard
+            for value in values:
+                _discard(value)
+        return self
+
+    __isub__ = difference_update
+
+    def intersection(self, *iterables):
+        """
+        Return a new set with elements common to the set and all *iterables*.
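+
+        A small illustrative example:
+
+        >>> ss = SortedSet([1, 2, 3, 4])
+        >>> list(ss.intersection([2, 3, 5]))
+        [2, 3]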
+ """ + comb = self._set.intersection(*iterables) + new_set = self.__class__(key=self._key, load=self._load, _set=comb) + return new_set + + __and__ = intersection + __rand__ = __and__ + + def intersection_update(self, *iterables): + """ + Update the set, keeping only elements found in it and all *iterables*. + """ + _set = self._set + _list = self._list + _set.intersection_update(*iterables) + _list.clear() + _list.update(_set) + return self + + __iand__ = intersection_update + + def symmetric_difference(self, that): + """ + Return a new set with elements in either *self* or *that* but not both. + """ + diff = self._set.symmetric_difference(that) + new_set = self.__class__(key=self._key, load=self._load, _set=diff) + return new_set + + __xor__ = symmetric_difference + __rxor__ = __xor__ + + def symmetric_difference_update(self, that): + """ + Update the set, keeping only elements found in either *self* or *that*, + but not in both. + """ + _set = self._set + _list = self._list + _set.symmetric_difference_update(that) + _list.clear() + _list.update(_set) + return self + + __ixor__ = symmetric_difference_update + + def union(self, *iterables): + """ + Return a new SortedSet with elements from the set and all *iterables*. + """ + return self.__class__(chain(iter(self), *iterables), key=self._key, load=self._load) + + __or__ = union + __ror__ = __or__ + + def update(self, *iterables): + """Update the set, adding elements from all *iterables*.""" + _set = self._set + values = set(chain(*iterables)) + if (4 * len(values)) > len(_set): + _list = self._list + _set.update(values) + _list.clear() + _list.update(_set) + else: + _add = self.add + for value in values: + _add(value) + return self + + __ior__ = update + _update = update + + def __reduce__(self): + return (self.__class__, ((), self._key, self._load, self._set)) + + @recursive_repr + def __repr__(self): + temp = '{0}({1}, key={2}, load={3})' + return temp.format( + self.__class__.__name__, + repr(list(self)), + repr(self._key), + repr(self._load) + ) + + def _check(self): + # pylint: disable=protected-access + self._list._check() + assert len(self._set) == len(self._list) + _set = self._set + assert all(val in _set for val in self._list) From 9a6f597b14d518fe4c4c400f2ffb53631b134905 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 19:23:44 +0100 Subject: [PATCH 08/65] sortedcontainers 2 --- .../sortedcontainers/sorteddict.py | 138 ++++++------------ .../sortedcontainers/sortedlist.py | 21 +-- 2 files changed, 49 insertions(+), 110 deletions(-) diff --git a/sortedcontainers/sortedcontainers/sorteddict.py b/sortedcontainers/sortedcontainers/sorteddict.py index d82aa86bca..d78f1099e0 100644 --- a/sortedcontainers/sortedcontainers/sorteddict.py +++ b/sortedcontainers/sortedcontainers/sorteddict.py @@ -3,7 +3,6 @@ """ from collections import Set, Sequence -from sys import hexversion from .sortedlist import SortedList, recursive_repr, SortedListWithKey from .sortedset import SortedSet @@ -337,18 +336,17 @@ def update(self, *args, **kwargs): _update = update - if hexversion >= 0x02070000: - def viewkeys(self): - "Return ``KeysView`` of dictionary keys." - return KeysView(self) + def viewkeys(self): + "Return ``KeysView`` of dictionary keys." + return KeysView(self) - def viewvalues(self): - "Return ``ValuesView`` of dictionary values." - return ValuesView(self) + def viewvalues(self): + "Return ``ValuesView`` of dictionary values." 
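+        # Like all SortedDict views, the view iterates in sorted-key order.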
+ return ValuesView(self) - def viewitems(self): - "Return ``ItemsView`` of dictionary (key, value) item pairs." - return ItemsView(self) + def viewitems(self): + "Return ``ItemsView`` of dictionary (key, value) item pairs." + return ItemsView(self) def __reduce__(self): return (self.__class__, (self._key, self._load, list(self._iteritems()))) @@ -380,22 +378,13 @@ class KeysView(Set, Sequence): The KeysView class implements the Set and Sequence Abstract Base Classes. """ - if hexversion < 0x03000000: - def __init__(self, sorted_dict): - """ - Initialize a KeysView from a SortedDict container as *sorted_dict*. - """ - # pylint: disable=super-init-not-called, protected-access - self._list = sorted_dict._list - self._view = sorted_dict._dict.viewkeys() - else: - def __init__(self, sorted_dict): - """ - Initialize a KeysView from a SortedDict container as *sorted_dict*. - """ - # pylint: disable=super-init-not-called, protected-access - self._list = sorted_dict._list - self._view = sorted_dict._dict.keys() + def __init__(self, sorted_dict): + """ + Initialize a KeysView from a SortedDict container as *sorted_dict*. + """ + # pylint: disable=super-init-not-called, protected-access + self._list = sorted_dict._list + self._view = sorted_dict._dict.keys() def __len__(self): """Return the number of entries in the dictionary.""" return len(self._view) @@ -468,14 +457,9 @@ def __sub__(self, that): def __xor__(self, that): """Return a SortedSet of the symmetric difference of self and *that*.""" return SortedSet(self._view ^ that) - if hexversion < 0x03000000: - def isdisjoint(self, that): - """Return True if and only if *that* is disjoint with self.""" - return not any(key in self._list for key in that) - else: - def isdisjoint(self, that): - """Return True if and only if *that* is disjoint with self.""" - return self._view.isdisjoint(that) + def isdisjoint(self, that): + """Return True if and only if *that* is disjoint with self.""" + return self._view.isdisjoint(that) @recursive_repr def __repr__(self): return 'SortedDict_keys({0})'.format(repr(list(self))) @@ -489,26 +473,15 @@ class ValuesView(Sequence): The ValuesView class implements the Sequence Abstract Base Class. """ - if hexversion < 0x03000000: - def __init__(self, sorted_dict): - """ - Initialize a ValuesView from a SortedDict container as - *sorted_dict*. - """ - # pylint: disable=super-init-not-called, protected-access - self._dict = sorted_dict - self._list = sorted_dict._list - self._view = sorted_dict._dict.viewvalues() - else: - def __init__(self, sorted_dict): - """ - Initialize a ValuesView from a SortedDict container as - *sorted_dict*. - """ - # pylint: disable=super-init-not-called, protected-access - self._dict = sorted_dict - self._list = sorted_dict._list - self._view = sorted_dict._dict.values() + def __init__(self, sorted_dict): + """ + Initialize a ValuesView from a SortedDict container as + *sorted_dict*. 
+ """ + # pylint: disable=super-init-not-called, protected-access + self._dict = sorted_dict + self._list = sorted_dict._list + self._view = sorted_dict._dict.values() def __len__(self): """Return the number of entries in the dictionary.""" return len(self._dict) @@ -559,14 +532,9 @@ def index(self, value): if value == val: return idx raise ValueError('{0} is not in dict'.format(repr(value))) - if hexversion < 0x03000000: - def count(self, value): - """Return the number of occurrences of *value* in self.""" - return sum(1 for val in self._dict.itervalues() if val == value) - else: - def count(self, value): - """Return the number of occurrences of *value* in self.""" - return sum(1 for val in self._dict.values() if val == value) + def count(self, value): + """Return the number of occurrences of *value* in self.""" + return sum(1 for val in self._dict.values() if val == value) def __lt__(self, that): raise TypeError def __gt__(self, that): @@ -598,26 +566,15 @@ class ItemsView(Set, Sequence): However, the set-like operations (``&``, ``|``, ``-``, ``^``) will only operate correctly if all of the dictionary's values are hashable. """ - if hexversion < 0x03000000: - def __init__(self, sorted_dict): - """ - Initialize an ItemsView from a SortedDict container as - *sorted_dict*. - """ - # pylint: disable=super-init-not-called, protected-access - self._dict = sorted_dict - self._list = sorted_dict._list - self._view = sorted_dict._dict.viewitems() - else: - def __init__(self, sorted_dict): - """ - Initialize an ItemsView from a SortedDict container as - *sorted_dict*. - """ - # pylint: disable=super-init-not-called, protected-access - self._dict = sorted_dict - self._list = sorted_dict._list - self._view = sorted_dict._dict.items() + def __init__(self, sorted_dict): + """ + Initialize an ItemsView from a SortedDict container as + *sorted_dict*. 
+        """
+        # pylint: disable=super-init-not-called, protected-access
+        self._dict = sorted_dict
+        self._list = sorted_dict._list
+        self._view = sorted_dict._dict.items()
     def __len__(self):
         """Return the number of entries in the dictionary."""
         return len(self._view)
@@ -703,18 +660,9 @@ def __sub__(self, that):
     def __xor__(self, that):
         """Return a SortedSet of the symmetric difference of self and *that*."""
         return SortedSet(self._view ^ that)
-    if hexversion < 0x03000000:
-        def isdisjoint(self, that):
-            """Return True if and only if *that* is disjoint with self."""
-            _dict = self._dict
-            for key, value in that:
-                if key in _dict and _dict[key] == value:
-                    return False
-            return True
-    else:
-        def isdisjoint(self, that):
-            """Return True if and only if *that* is disjoint with self."""
-            return self._view.isdisjoint(that)
+    def isdisjoint(self, that):
+        """Return True if and only if *that* is disjoint with self."""
+        return self._view.isdisjoint(that)
     @recursive_repr
     def __repr__(self):
         return 'SortedDict_items({0})'.format(repr(list(self)))
diff --git a/sortedcontainers/sortedcontainers/sortedlist.py b/sortedcontainers/sortedcontainers/sortedlist.py
index 5d1245f972..3a87fc75f3 100644
--- a/sortedcontainers/sortedcontainers/sortedlist.py
+++ b/sortedcontainers/sortedcontainers/sortedlist.py
@@ -12,21 +12,12 @@
 from math import log as log_e
 import operator as op
 from operator import iadd, add
-from sys import hexversion
-
-if hexversion < 0x03000000:
-    from itertools import izip as zip
-    from itertools import imap as map
-    try:
-        from thread import get_ident
-    except ImportError:
-        from dummy_thread import get_ident
-else:
-    from functools import reduce
-    try:
-        from _thread import get_ident
-    except ImportError:
-        from _dummy_thread import get_ident  # pylint: disable=import-error
+
+from functools import reduce
+try:
+    from _thread import get_ident
+except ImportError:
+    from _dummy_thread import get_ident  # pylint: disable=import-error
 
 def recursive_repr(func):
     """Decorator to prevent infinite repr recursion."""

From 894429635c81ce25d02417ff799ce0f4345cdd9d Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 24 Feb 2019 19:24:17 +0100
Subject: [PATCH 09/65] unittest: implement assertRaisesRegex

---
 unittest/unittest.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/unittest/unittest.py b/unittest/unittest.py
index af3d438f06..015b8669ce 100644
--- a/unittest/unittest.py
+++ b/unittest/unittest.py
@@ -165,6 +165,23 @@ def assertRaises(self, exc, func=None, *args, **kwargs):
     def assertWarns(self, warn):
         return NullContext()
 
+    def assertRaisesRegex(self, exc, regexp, func=None, *args, **kwargs):
+        import re
+        if func is None:
+            return AssertRaisesContext(exc)
+
+        try:
+            func(*args, **kwargs)
+            assert False, "%r not raised" % exc
+        except Exception as e:
+            if not isinstance(e, exc):
+                raise
+            # search repr(e) rather than str(e), which may be empty here
+            r = re.compile(regexp)
+            if r.search(repr(e)):
+                return
+            raise
+
 
 def skip(msg):
     def _decor(fun):

From e5adaf563192a40984da8011c55439b450e20f61 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 24 Feb 2019 19:25:03 +0100
Subject: [PATCH 10/65] do a real import of collections.abc

---
 collections.abc/collections/abc.py | 968 ++++++++++++++++++++++++++++-
 1 file changed, 964 insertions(+), 4 deletions(-)

diff --git a/collections.abc/collections/abc.py b/collections.abc/collections/abc.py
index 9ca6cacf93..83b16871a8 100644
--- a/collections.abc/collections/abc.py
+++ b/collections.abc/collections/abc.py
@@ -1,5 +1,965 @@
+# Copyright 2007 Google, Inc. All Rights Reserved.
+# Licensed to PSF under a Contributor Agreement. -# this is so not-right it's not even wrong -Mapping = dict -MutableMapping = dict -Sequence = (tuple, list) # only useful for simple isinstance tests +"""Abstract Base Classes (ABCs) for collections, according to PEP 3119. + +Unit tests are in test_collections. +""" + +from abc import abstractmethod +import sys + +__all__ = ["Awaitable", "Coroutine", + "AsyncIterable", "AsyncIterator", "AsyncGenerator", + "Hashable", "Iterable", "Iterator", "Generator", "Reversible", + "Sized", "Container", "Callable", "Collection", + "Set", "MutableSet", + "Mapping", "MutableMapping", + "MappingView", "KeysView", "ItemsView", "ValuesView", + "Sequence", "MutableSequence", + "ByteString", + ] + +# This module has been renamed from collections.abc to _collections_abc to +# speed up interpreter startup. Some of the types such as MutableMapping are +# required early but collections module imports a lot of other modules. +# See issue #19218 +__name__ = "collections.abc" + +# Private list of types that we want to register with the various ABCs +# so that they will pass tests like: +# it = iter(somebytearray) +# assert isinstance(it, Iterable) +# Note: in other implementations, these types might not be distinct +# and they may have their own implementation specific types that +# are not included on this list. +bytes_iterator = type(iter(b'')) +bytearray_iterator = type(iter(bytearray())) +#callable_iterator = ??? +dict_keyiterator = type(iter({}.keys())) +dict_valueiterator = type(iter({}.values())) +dict_itemiterator = type(iter({}.items())) +list_iterator = type(iter([])) +list_reverseiterator = type(iter(reversed([]))) +range_iterator = type(iter(range(0))) +longrange_iterator = range_iterator +set_iterator = type(iter(set())) +str_iterator = type(iter("")) +tuple_iterator = type(iter(())) +zip_iterator = type(iter(zip())) +## views ## +dict_keys = type({}.keys()) +dict_values = type({}.values()) +dict_items = type({}.items()) +## misc ## +mappingproxy = dict +generator = type((lambda: (yield))()) +## coroutine ## +async def _coro(): pass +_coro = _coro() +coroutine = type(_coro) +_coro.close() # Prevent ResourceWarning +del _coro +## asynchronous generator ## +async def _ag(): yield +_ag = _ag() +async_generator = type(_ag) +del _ag + + +### ONE-TRICK PONIES ### + +def _check_methods(C, *methods): + mro = C.__mro__ + for method in methods: + for B in mro: + if method in B.__dict__: + if B.__dict__[method] is None: + return NotImplemented + break + else: + return NotImplemented + return True + +class Hashable: + + __slots__ = () + + @abstractmethod + def __hash__(self): + return 0 + + @classmethod + def __subclasshook__(cls, C): + if cls is Hashable: + return _check_methods(C, "__hash__") + return NotImplemented + + +class Awaitable: + + __slots__ = () + + @abstractmethod + def __await__(self): + yield + + @classmethod + def __subclasshook__(cls, C): + if cls is Awaitable: + return _check_methods(C, "__await__") + return NotImplemented + + +class Coroutine(Awaitable): + + __slots__ = () + + @abstractmethod + def send(self, value): + """Send a value into the coroutine. + Return next yielded value or raise StopIteration. + """ + raise StopIteration + + @abstractmethod + def throw(self, typ, val=None, tb=None): + """Raise an exception in the coroutine. + Return next yielded value or raise StopIteration. 
+ """ + if val is None: + if tb is None: + raise typ + val = typ() + if tb is not None: + val = val.with_traceback(tb) + raise val + + def close(self): + """Raise GeneratorExit inside coroutine. + """ + try: + self.throw(GeneratorExit) + except (GeneratorExit, StopIteration): + pass + else: + raise RuntimeError("coroutine ignored GeneratorExit") + + @classmethod + def __subclasshook__(cls, C): + if cls is Coroutine: + return _check_methods(C, '__await__', 'send', 'throw', 'close') + return NotImplemented + + +class AsyncIterable: + + __slots__ = () + + @abstractmethod + def __aiter__(self): + return AsyncIterator() + + @classmethod + def __subclasshook__(cls, C): + if cls is AsyncIterable: + return _check_methods(C, "__aiter__") + return NotImplemented + + +class AsyncIterator(AsyncIterable): + + __slots__ = () + + @abstractmethod + async def __anext__(self): + """Return the next item or raise StopAsyncIteration when exhausted.""" + raise StopAsyncIteration + + def __aiter__(self): + return self + + @classmethod + def __subclasshook__(cls, C): + if cls is AsyncIterator: + return _check_methods(C, "__anext__", "__aiter__") + return NotImplemented + + +class AsyncGenerator(AsyncIterator): + + __slots__ = () + + async def __anext__(self): + """Return the next item from the asynchronous generator. + When exhausted, raise StopAsyncIteration. + """ + return await self.asend(None) + + @abstractmethod + async def asend(self, value): + """Send a value into the asynchronous generator. + Return next yielded value or raise StopAsyncIteration. + """ + raise StopAsyncIteration + + @abstractmethod + async def athrow(self, typ, val=None, tb=None): + """Raise an exception in the asynchronous generator. + Return next yielded value or raise StopAsyncIteration. + """ + if val is None: + if tb is None: + raise typ + val = typ() + if tb is not None: + val = val.with_traceback(tb) + raise val + + async def aclose(self): + """Raise GeneratorExit inside coroutine. + """ + try: + await self.athrow(GeneratorExit) + except (GeneratorExit, StopAsyncIteration): + pass + else: + raise RuntimeError("asynchronous generator ignored GeneratorExit") + + @classmethod + def __subclasshook__(cls, C): + if cls is AsyncGenerator: + return _check_methods(C, '__aiter__', '__anext__', + 'asend', 'athrow', 'aclose') + return NotImplemented + + + +class Iterable: + + __slots__ = () + + @abstractmethod + def __iter__(self): + while False: + yield None + + @classmethod + def __subclasshook__(cls, C): + if cls is Iterable: + return _check_methods(C, "__iter__") + return NotImplemented + + +class Iterator(Iterable): + + __slots__ = () + + @abstractmethod + def __next__(self): + 'Return the next item from the iterator. When exhausted, raise StopIteration' + raise StopIteration + + def __iter__(self): + return self + + @classmethod + def __subclasshook__(cls, C): + if cls is Iterator: + return _check_methods(C, '__iter__', '__next__') + return NotImplemented + + +class Reversible(Iterable): + + __slots__ = () + + @abstractmethod + def __reversed__(self): + while False: + yield None + + @classmethod + def __subclasshook__(cls, C): + if cls is Reversible: + return _check_methods(C, "__reversed__", "__iter__") + return NotImplemented + + +class Generator(Iterator): + + __slots__ = () + + def __next__(self): + """Return the next item from the generator. + When exhausted, raise StopIteration. + """ + return self.send(None) + + @abstractmethod + def send(self, value): + """Send a value into the generator. 
+ Return next yielded value or raise StopIteration. + """ + raise StopIteration + + @abstractmethod + def throw(self, typ, val=None, tb=None): + """Raise an exception in the generator. + Return next yielded value or raise StopIteration. + """ + if val is None: + if tb is None: + raise typ + val = typ() + if tb is not None: + val = val.with_traceback(tb) + raise val + + def close(self): + """Raise GeneratorExit inside generator. + """ + try: + self.throw(GeneratorExit) + except (GeneratorExit, StopIteration): + pass + else: + raise RuntimeError("generator ignored GeneratorExit") + + @classmethod + def __subclasshook__(cls, C): + if cls is Generator: + return _check_methods(C, '__iter__', '__next__', + 'send', 'throw', 'close') + return NotImplemented + + +class Sized: + + __slots__ = () + + @abstractmethod + def __len__(self): + return 0 + + @classmethod + def __subclasshook__(cls, C): + if cls is Sized: + return _check_methods(C, "__len__") + return NotImplemented + + +class Container: + + __slots__ = () + + @abstractmethod + def __contains__(self, x): + return False + + @classmethod + def __subclasshook__(cls, C): + if cls is Container: + return _check_methods(C, "__contains__") + return NotImplemented + +class Collection(Sized, Iterable, Container): + + __slots__ = () + + @classmethod + def __subclasshook__(cls, C): + if cls is Collection: + return _check_methods(C, "__len__", "__iter__", "__contains__") + return NotImplemented + +class Callable: + + __slots__ = () + + @abstractmethod + def __call__(self, *args, **kwds): + return False + + @classmethod + def __subclasshook__(cls, C): + if cls is Callable: + return _check_methods(C, "__call__") + return NotImplemented + + +### SETS ### + + +class Set(Collection): + + """A set is a finite, iterable container. + + This class provides concrete generic implementations of all + methods except for __contains__, __iter__ and __len__. + + To override the comparisons (presumably for speed, as the + semantics are fixed), redefine __le__ and __ge__, + then the other operations will automatically follow suit. + """ + + __slots__ = () + + def __le__(self, other): + if not isinstance(other, Set): + return NotImplemented + if len(self) > len(other): + return False + for elem in self: + if elem not in other: + return False + return True + + def __lt__(self, other): + if not isinstance(other, Set): + return NotImplemented + return len(self) < len(other) and self.__le__(other) + + def __gt__(self, other): + if not isinstance(other, Set): + return NotImplemented + return len(self) > len(other) and self.__ge__(other) + + def __ge__(self, other): + if not isinstance(other, Set): + return NotImplemented + if len(self) < len(other): + return False + for elem in other: + if elem not in self: + return False + return True + + def __eq__(self, other): + if not isinstance(other, Set): + return NotImplemented + return len(self) == len(other) and self.__le__(other) + + @classmethod + def _from_iterable(cls, it): + '''Construct an instance of the class from any iterable input. + + Must override this method if the class constructor signature + does not accept an iterable for an input. + ''' + return cls(it) + + def __and__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + return self._from_iterable(value for value in other if value in self) + + __rand__ = __and__ + + def isdisjoint(self, other): + 'Return True if two sets have a null intersection.' 
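+        # Scans `other` and returns False at the first shared element,
+        # so disjoint checks can stop early.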
+ for value in other: + if value in self: + return False + return True + + def __or__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + chain = (e for s in (self, other) for e in s) + return self._from_iterable(chain) + + __ror__ = __or__ + + def __sub__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return self._from_iterable(value for value in self + if value not in other) + + def __rsub__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return self._from_iterable(value for value in other + if value not in self) + + def __xor__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return (self - other) | (other - self) + + __rxor__ = __xor__ + + def _hash(self): + """Compute the hash value of a set. + + Note that we don't define __hash__: not all sets are hashable. + But if you define a hashable set type, its __hash__ should + call this function. + + This must be compatible __eq__. + + All sets ought to compare equal if they contain the same + elements, regardless of how they are implemented, and + regardless of the order of the elements; so there's not much + freedom for __eq__ or __hash__. We match the algorithm used + by the built-in frozenset type. + """ + MAX = sys.maxsize + MASK = 2 * MAX + 1 + n = len(self) + h = 1927868237 * (n + 1) + h &= MASK + for x in self: + hx = hash(x) + h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167 + h &= MASK + h = h * 69069 + 907133923 + h &= MASK + if h > MAX: + h -= MASK + 1 + if h == -1: + h = 590923713 + return h + + +class MutableSet(Set): + """A mutable set is a finite, iterable container. + + This class provides concrete generic implementations of all + methods except for __contains__, __iter__, __len__, + add(), and discard(). + + To override the comparisons (presumably for speed, as the + semantics are fixed), all you have to do is redefine __le__ and + then the other operations will automatically follow suit. + """ + + __slots__ = () + + @abstractmethod + def add(self, value): + """Add an element.""" + raise NotImplementedError + + @abstractmethod + def discard(self, value): + """Remove an element. Do not raise an exception if absent.""" + raise NotImplementedError + + def remove(self, value): + """Remove an element. If not a member, raise a KeyError.""" + if value not in self: + raise KeyError(value) + self.discard(value) + + def pop(self): + """Return the popped value. Raise KeyError if empty.""" + it = iter(self) + try: + value = next(it) + except StopIteration: + raise KeyError from None + self.discard(value) + return value + + def clear(self): + """This is slow (creates N new iterators!) 
but effective.""" + try: + while True: + self.pop() + except KeyError: + pass + + def __ior__(self, it): + for value in it: + self.add(value) + return self + + def __iand__(self, it): + for value in (self - it): + self.discard(value) + return self + + def __ixor__(self, it): + if it is self: + self.clear() + else: + if not isinstance(it, Set): + it = self._from_iterable(it) + for value in it: + if value in self: + self.discard(value) + else: + self.add(value) + return self + + def __isub__(self, it): + if it is self: + self.clear() + else: + for value in it: + self.discard(value) + return self + + +### MAPPINGS ### + + +class Mapping(Collection): + + __slots__ = () + + """A Mapping is a generic container for associating key/value + pairs. + + This class provides concrete generic implementations of all + methods except for __getitem__, __iter__, and __len__. + + """ + + @abstractmethod + def __getitem__(self, key): + raise KeyError + + def get(self, key, default=None): + 'D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.' + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True + + def keys(self): + "D.keys() -> a set-like object providing a view on D's keys" + return KeysView(self) + + def items(self): + "D.items() -> a set-like object providing a view on D's items" + return ItemsView(self) + + def values(self): + "D.values() -> an object providing a view on D's values" + return ValuesView(self) + + def __eq__(self, other): + if not isinstance(other, Mapping): + return NotImplemented + return dict(self.items()) == dict(other.items()) + + __reversed__ = None + + + +class MappingView(Sized): + + __slots__ = '_mapping', + + def __init__(self, mapping): + self._mapping = mapping + + def __len__(self): + return len(self._mapping) + + def __repr__(self): + return '{0.__class__.__name__}({0._mapping!r})'.format(self) + + +class KeysView(MappingView, Set): + + __slots__ = () + + @classmethod + def _from_iterable(self, it): + return set(it) + + def __contains__(self, key): + return key in self._mapping + + def __iter__(self): + yield from self._mapping + + +class ItemsView(MappingView, Set): + + __slots__ = () + + @classmethod + def _from_iterable(self, it): + return set(it) + + def __contains__(self, item): + key, value = item + try: + v = self._mapping[key] + except KeyError: + return False + else: + return v is value or v == value + + def __iter__(self): + for key in self._mapping: + yield (key, self._mapping[key]) + + +class ValuesView(MappingView, Collection): + + __slots__ = () + + def __contains__(self, value): + for key in self._mapping: + v = self._mapping[key] + if v is value or v == value: + return True + return False + + def __iter__(self): + for key in self._mapping: + yield self._mapping[key] + + +class MutableMapping(Mapping): + + __slots__ = () + + """A MutableMapping is a generic container for associating + key/value pairs. + + This class provides concrete generic implementations of all + methods except for __getitem__, __setitem__, __delitem__, + __iter__, and __len__. + + """ + + @abstractmethod + def __setitem__(self, key, value): + raise KeyError + + @abstractmethod + def __delitem__(self, key): + raise KeyError + + __marker = object() + + def pop(self, key, default=__marker): + '''D.pop(k[,d]) -> v, remove specified key and return the corresponding value. + If key is not found, d is returned if given, otherwise KeyError is raised. 
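+
+          For example (illustrative): ``D.pop('missing', 0)`` returns ``0``
+          instead of raising KeyError.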
+ ''' + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def popitem(self): + '''D.popitem() -> (k, v), remove and return some (key, value) pair + as a 2-tuple; but raise KeyError if D is empty. + ''' + try: + key = next(iter(self)) + except StopIteration: + raise KeyError from None + value = self[key] + del self[key] + return key, value + + def clear(self): + 'D.clear() -> None. Remove all items from D.' + try: + while True: + self.popitem() + except KeyError: + pass + + def update(*args, **kwds): + ''' D.update([E, ]**F) -> None. Update D from mapping/iterable E and F. + If E present and has a .keys() method, does: for k in E: D[k] = E[k] + If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v + In either case, this is followed by: for k, v in F.items(): D[k] = v + ''' + if not args: + raise TypeError("descriptor 'update' of 'MutableMapping' object " + "needs an argument") + self, *args = args + if len(args) > 1: + raise TypeError('update expected at most 1 arguments, got %d' % + len(args)) + if args: + other = args[0] + if isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + def setdefault(self, key, default=None): + 'D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D' + try: + return self[key] + except KeyError: + self[key] = default + return default + + +### SEQUENCES ### + + +class Sequence(Reversible, Collection): + + """All the operations on a read-only sequence. + + Concrete subclasses must override __new__ or __init__, + __getitem__, and __len__. + """ + + __slots__ = () + + @abstractmethod + def __getitem__(self, index): + raise IndexError + + def __iter__(self): + i = 0 + try: + while True: + v = self[i] + yield v + i += 1 + except IndexError: + return + + def __contains__(self, value): + for v in self: + if v is value or v == value: + return True + return False + + def __reversed__(self): + for i in reversed(range(len(self))): + yield self[i] + + def index(self, value, start=0, stop=None): + '''S.index(value, [start, [stop]]) -> integer -- return first index of value. + Raises ValueError if the value is not present. + + Supporting start and stop arguments is optional, but + recommended. + ''' + if start is not None and start < 0: + start = max(len(self) + start, 0) + if stop is not None and stop < 0: + stop += len(self) + + i = start + while stop is None or i < stop: + try: + v = self[i] + if v is value or v == value: + return i + except IndexError: + break + i += 1 + raise ValueError + + def count(self, value): + 'S.count(value) -> integer -- return number of occurrences of value' + return sum(1 for v in self if v is value or v == value) + + +class ByteString(Sequence): + + """This unifies bytes and bytearray. + + XXX Should add all their methods. + """ + + __slots__ = () + + +class MutableSequence(Sequence): + + __slots__ = () + + """All the operations on a read-write sequence. + + Concrete subclasses must provide __new__ or __init__, + __getitem__, __setitem__, __delitem__, __len__, and insert(). 
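+
+    The mixin methods defined below (append, clear, reverse, extend,
+    pop, remove and __iadd__) are all derived from those five
+    operations.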
+ + """ + + @abstractmethod + def __setitem__(self, index, value): + raise IndexError + + @abstractmethod + def __delitem__(self, index): + raise IndexError + + @abstractmethod + def insert(self, index, value): + 'S.insert(index, value) -- insert value before index' + raise IndexError + + def append(self, value): + 'S.append(value) -- append value to the end of the sequence' + self.insert(len(self), value) + + def clear(self): + 'S.clear() -> None -- remove all items from S' + try: + while True: + self.pop() + except IndexError: + pass + + def reverse(self): + 'S.reverse() -- reverse *IN PLACE*' + n = len(self) + for i in range(n//2): + self[i], self[n-i-1] = self[n-i-1], self[i] + + def extend(self, values): + 'S.extend(iterable) -- extend sequence by appending elements from the iterable' + for v in values: + self.append(v) + + def pop(self, index=-1): + '''S.pop([index]) -> item -- remove and return item at index (default last). + Raise IndexError if list is empty or index is out of range. + ''' + v = self[index] + del self[index] + return v + + def remove(self, value): + '''S.remove(value) -- remove first occurrence of value. + Raise ValueError if the value is not present. + ''' + del self[self.index(value)] + + def __iadd__(self, values): + self.extend(values) + return self From 0e201a93debf28552d4eefea33dee8d71f2986de Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 19:26:02 +0100 Subject: [PATCH 11/65] uptick --- collections.abc/metadata.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/collections.abc/metadata.txt b/collections.abc/metadata.txt index a8ff4f11f6..8f2fec2d30 100644 --- a/collections.abc/metadata.txt +++ b/collections.abc/metadata.txt @@ -1,3 +1,3 @@ srctype = micropython-lib type = package -version = 0.1.0 +version = 0.2.0 From dba4a2606f256c208130e06609221db52ddafcf1 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 19:26:45 +0100 Subject: [PATCH 12/65] Add deque.clear() --- collections.deque/collections/deque.py | 3 +++ collections.deque/metadata.txt | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/collections.deque/collections/deque.py b/collections.deque/collections/deque.py index b284e5f385..bb9f015c17 100644 --- a/collections.deque/collections/deque.py +++ b/collections.deque/collections/deque.py @@ -6,6 +6,9 @@ def __init__(self, iterable=None): else: self.q = list(iterable) + def clear(self): + self.q = [] + def popleft(self): return self.q.pop(0) diff --git a/collections.deque/metadata.txt b/collections.deque/metadata.txt index 7c6b2011d4..95a408767a 100644 --- a/collections.deque/metadata.txt +++ b/collections.deque/metadata.txt @@ -1,3 +1,3 @@ srctype = pycopy-lib type = package -version = 0.1.3 +version = 0.1.4 From 89e5836badee0623a1b76e5578f51991d8f17482 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 22:01:01 +0100 Subject: [PATCH 13/65] immutables port to micropython --- immutables/immutables/__init__.py | 11 + immutables/immutables/_map.c | 4167 +++++++++++++++++++++++++++++ immutables/immutables/_map.h | 107 + immutables/immutables/map.py | 813 ++++++ immutables/metadata.txt | 3 + 5 files changed, 5101 insertions(+) create mode 100644 immutables/immutables/__init__.py create mode 100644 immutables/immutables/_map.c create mode 100644 immutables/immutables/_map.h create mode 100644 immutables/immutables/map.py create mode 100644 immutables/metadata.txt diff --git a/immutables/immutables/__init__.py b/immutables/immutables/__init__.py new file mode 
100644
index 0000000000..60a071c1f0
--- /dev/null
+++ b/immutables/immutables/__init__.py
@@ -0,0 +1,11 @@
+try:
+    from ._map import Map
+except ImportError:
+    from .map import Map
+else:
+    import collections.abc as _abc
+    _abc.Mapping.register(Map)
+
+
+__all__ = 'Map',
+__version__ = '0.9'
diff --git a/immutables/immutables/_map.c b/immutables/immutables/_map.c
new file mode 100644
index 0000000000..f9c87723dc
--- /dev/null
+++ b/immutables/immutables/_map.c
@@ -0,0 +1,4167 @@
+#include <stddef.h>  /* For offsetof */
+#include "_map.h"
+
+
+/*
+This file provides an implementation of an immutable mapping using the
+Hash Array Mapped Trie (or HAMT) data structure.
+
+This design allows us to have:
+
+1. Efficient copy: immutable mappings can be copied by reference,
+   making it an O(1) operation.
+
+2. Efficient mutations: due to structural sharing, only a portion of
+   the trie needs to be copied when the collection is mutated. The
+   cost of set/delete operations is O(log N).
+
+3. Efficient lookups: O(log N).
+
+(where N is the number of key/value items in the immutable mapping.)
+
+
+HAMT
+====
+
+The core idea of HAMT is that the shape of the trie is encoded into the
+hashes of keys.
+
+Say we want to store a K/V pair in our mapping. First, we calculate the
+hash of K, let's say it's 19830128, or in binary:
+
+    0b1001011101001010101110000 = 19830128
+
+Now let's partition this bit representation of the hash into blocks of
+5 bits each:
+
+    0b00_00000_10010_11101_00101_01011_10000 = 19830128
+          (6)   (5)   (4)   (3)   (2)   (1)
+
+Each block of 5 bits represents a number between 0 and 31. So if we have
+a tree that consists of nodes, each of which is an array of 32 pointers,
+those 5-bit blocks will encode a position on a single tree level.
+
+For example, storing the key K with hash 19830128 results in the following
+tree structure:
+
+                             (array of 32 pointers)
+                             +---+ -- +----+----+----+ -- +----+
+  root node                 | 0 | .. | 15 | 16 | 17 | .. | 31 |   0b10000 = 16  (1)
+  (level 1)                  +---+ -- +----+----+----+ -- +----+
+                                           |
+                             +---+ -- +----+----+----+ -- +----+
+  a 2nd level node          | 0 | .. | 10 | 11 | 12 | .. | 31 |   0b01011 = 11  (2)
+                             +---+ -- +----+----+----+ -- +----+
+                                           |
+                             +---+ -- +----+----+----+ -- +----+
+  a 3rd level node          | 0 | .. | 04 | 05 | 06 | .. | 31 |   0b00101 = 5   (3)
+                             +---+ -- +----+----+----+ -- +----+
+                                           |
+                             +---+ -- +----+----+----+----+
+  a 4th level node          | 0 | .. | 04 | 29 | 30 | 31 |        0b11101 = 29  (4)
+                             +---+ -- +----+----+----+----+
+                                           |
+                             +---+ -- +----+----+----+ -- +----+
+  a 5th level node          | 0 | .. | 17 | 18 | 19 | .. | 31 |   0b10010 = 18  (5)
+                             +---+ -- +----+----+----+ -- +----+
+                                           |
+                              +------------+
+                              |
+                             +---+ -- +----+----+----+ -- +----+
+  a 6th level node          | 0 | .. | 15 | 16 | 17 | .. | 31 |   0b00000 = 0   (6)
+                             +---+ -- +----+----+----+ -- +----+
+                                           |
+                                           V -- our value (or collision)
+
+To rehash: for a K/V pair, the hash of K encodes where in the tree V will
+be stored.
+
+To optimize memory footprint and handle hash collisions, our
+implementation uses three different types of nodes:
+
+ * A Bitmap node;
+ * An Array node;
+ * A Collision node.
+
+Because we implement an immutable dictionary, our nodes are also
+immutable. Therefore, when we need to modify a node, we copy it, and
+do that modification to the copy.
+
+
+Array Nodes
+-----------
+
+These nodes are very simple. Essentially they are arrays of 32 pointers
+we used to illustrate the high-level idea in the previous section.
+
+We use Array nodes only when we need to store more than 16 pointers
+in a single node.
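+
+For example (a sketch, not code from this file), descending one level
+through an Array node is just an index into its 32 child pointers:
+
+    child = node->a_array[((uint32_t)hash >> shift) & 0x1f];
+
+where `shift` grows by 5 for each level, mirroring the 5-bit blocks
+shown above.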
+ +Array nodes do not store key objects or value objects. They are used +only as an indirection level - their pointers point to other nodes in +the tree. + + +Bitmap Node +----------- + +Allocating a new 32-pointers array for every node of our tree would be +very expensive. Unless we store millions of keys, most of tree nodes would +be very sparse. + +When we have less than 16 elements in a node, we don't want to use the +Array node, that would mean that we waste a lot of memory. Instead, +we can use bitmap compression and can have just as many pointers +as we need! + +Bitmap nodes consist of two fields: + +1. An array of pointers. If a Bitmap node holds N elements, the + array will be of N pointers. + +2. A 32bit integer -- a bitmap field. If an N-th bit is set in the + bitmap, it means that the node has an N-th element. + +For example, say we need to store a 3 elements sparse array: + + +---+ -- +---+ -- +----+ -- +----+ + | 0 | .. | 4 | .. | 11 | .. | 17 | + +---+ -- +---+ -- +----+ -- +----+ + | | | + o1 o2 o3 + +We allocate a three-pointer Bitmap node. Its bitmap field will be +then set to: + + 0b_00100_00010_00000_10000 == (1 << 17) | (1 << 11) | (1 << 4) + +To check if our Bitmap node has an I-th element we can do: + + bitmap & (1 << I) + + +And here's a formula to calculate a position in our pointer array +which would correspond to an I-th element: + + popcount(bitmap & ((1 << I) - 1)) + + +Let's break it down: + + * `popcount` is a function that returns a number of bits set to 1; + + * `((1 << I) - 1)` is a mask to filter the bitmask to contain bits + set to the *right* of our bit. + + +So for our 17, 11, and 4 indexes: + + * bitmap & ((1 << 17) - 1) == 0b100000010000 => 2 bits are set => index is 2. + + * bitmap & ((1 << 11) - 1) == 0b10000 => 1 bit is set => index is 1. + + * bitmap & ((1 << 4) - 1) == 0b0 => 0 bits are set => index is 0. + + +To conclude: Bitmap nodes are just like Array nodes -- they can store +a number of pointers, but use bitmap compression to eliminate unused +pointers. + + +Bitmap nodes have two pointers for each item: + + +----+----+----+----+ -- +----+----+ + | k1 | v1 | k2 | v2 | .. | kN | vN | + +----+----+----+----+ -- +----+----+ + +When kI == NULL, vI points to another tree level. + +When kI != NULL, the actual key object is stored in kI, and its +value is stored in vI. + + +Collision Nodes +--------------- + +Collision nodes are simple arrays of pointers -- two pointers per +key/value. When there's a hash collision, say for k1/v1 and k2/v2 +we have `hash(k1)==hash(k2)`. Then our collision node will be: + + +----+----+----+----+ + | k1 | v1 | k2 | v2 | + +----+----+----+----+ + + +Tree Structure +-------------- + +All nodes are PyObjects. + +The `MapObject` object has a pointer to the root node (h_root), +and has a length field (h_count). + +High-level functions accept a MapObject object and dispatch to +lower-level functions depending on what kind of node h_root points to. + + +Operations +========== + +There are three fundamental operations on an immutable dictionary: + +1. "o.assoc(k, v)" will return a new immutable dictionary, that will be + a copy of "o", but with the "k/v" item set. + + Functions in this file: + + map_node_assoc, map_node_bitmap_assoc, + map_node_array_assoc, map_node_collision_assoc + + `map_node_assoc` function accepts a node object, and calls + other functions depending on its actual type. + +2. "o.find(k)" will lookup key "k" in "o". 
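+   (the result is reported via the map_find_t enum declared below).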
+ + Functions: + + map_node_find, map_node_bitmap_find, + map_node_array_find, map_node_collision_find + +3. "o.without(k)" will return a new immutable dictionary, that will be + a copy of "o", buth without the "k" key. + + Functions: + + map_node_without, map_node_bitmap_without, + map_node_array_without, map_node_collision_without + + +Further Reading +=============== + +1. http://blog.higher-order.net/2009/09/08/understanding-clojures-persistenthashmap-deftwice.html + +2. http://blog.higher-order.net/2010/08/16/assoc-and-clojures-persistenthashmap-part-ii.html + +3. Clojure's PersistentHashMap implementation: + https://github.com/clojure/clojure/blob/master/src/jvm/clojure/lang/PersistentHashMap.java +*/ + + +#define IS_ARRAY_NODE(node) (Py_TYPE(node) == &_Map_ArrayNode_Type) +#define IS_BITMAP_NODE(node) (Py_TYPE(node) == &_Map_BitmapNode_Type) +#define IS_COLLISION_NODE(node) (Py_TYPE(node) == &_Map_CollisionNode_Type) + + +/* Return type for 'find' (lookup a key) functions. + + * F_ERROR - an error occurred; + * F_NOT_FOUND - the key was not found; + * F_FOUND - the key was found. +*/ +typedef enum {F_ERROR, F_NOT_FOUND, F_FOUND} map_find_t; + + +/* Return type for 'without' (delete a key) functions. + + * W_ERROR - an error occurred; + * W_NOT_FOUND - the key was not found: there's nothing to delete; + * W_EMPTY - the key was found: the node/tree would be empty + if the key is deleted; + * W_NEWNODE - the key was found: a new node/tree is returned + without that key. +*/ +typedef enum {W_ERROR, W_NOT_FOUND, W_EMPTY, W_NEWNODE} map_without_t; + + +/* Low-level iterator protocol type. + + * I_ITEM - a new item has been yielded; + * I_END - the whole tree was visited (similar to StopIteration). +*/ +typedef enum {I_ITEM, I_END} map_iter_t; + + +#define HAMT_ARRAY_NODE_SIZE 32 + + +typedef struct { + PyObject_HEAD + MapNode *a_array[HAMT_ARRAY_NODE_SIZE]; + Py_ssize_t a_count; + uint64_t a_mutid; +} MapNode_Array; + + +typedef struct { + PyObject_VAR_HEAD + uint64_t b_mutid; + uint32_t b_bitmap; + PyObject *b_array[1]; +} MapNode_Bitmap; + + +typedef struct { + PyObject_VAR_HEAD + uint64_t c_mutid; + int32_t c_hash; + PyObject *c_array[1]; +} MapNode_Collision; + + +static volatile uint64_t mutid_counter = 1; + +static MapNode_Bitmap *_empty_bitmap_node; + + +/* Create a new HAMT immutable mapping. */ +static MapObject * +map_new(void); + +/* Return a new collection based on "o", but with an additional + key/val pair. */ +static MapObject * +map_assoc(MapObject *o, PyObject *key, PyObject *val); + +/* Return a new collection based on "o", but without "key". */ +static MapObject * +map_without(MapObject *o, PyObject *key); + +/* Check if "v" is equal to "w". + + Return: + - 0: v != w + - 1: v == w + - -1: An error occurred. +*/ +static int +map_eq(BaseMapObject *v, BaseMapObject *w); + +static map_find_t +map_find(BaseMapObject *o, PyObject *key, PyObject **val); + +/* Return the size of "o"; equivalent of "len(o)". 
*/
+static Py_ssize_t
+map_len(BaseMapObject *o);
+
+
+static MapObject *
+map_alloc(void);
+
+static MapNode *
+map_node_assoc(MapNode *node,
+               uint32_t shift, int32_t hash,
+               PyObject *key, PyObject *val, int* added_leaf,
+               uint64_t mutid);
+
+static map_without_t
+map_node_without(MapNode *node,
+                 uint32_t shift, int32_t hash,
+                 PyObject *key,
+                 MapNode **new_node,
+                 uint64_t mutid);
+
+static map_find_t
+map_node_find(MapNode *node,
+              uint32_t shift, int32_t hash,
+              PyObject *key, PyObject **val);
+
+static int
+map_node_dump(MapNode *node,
+              _PyUnicodeWriter *writer, int level);
+
+static MapNode *
+map_node_array_new(Py_ssize_t, uint64_t mutid);
+
+static MapNode *
+map_node_collision_new(int32_t hash, Py_ssize_t size, uint64_t mutid);
+
+static inline Py_ssize_t
+map_node_collision_count(MapNode_Collision *node);
+
+static int
+map_node_update(uint64_t mutid,
+                PyObject *seq,
+                MapNode *root, Py_ssize_t count,
+                MapNode **new_root, Py_ssize_t *new_count);
+
+
+static int
+map_update_inplace(uint64_t mutid, BaseMapObject *o, PyObject *src);
+
+static MapObject *
+map_update(uint64_t mutid, MapObject *o, PyObject *src);
+
+
+#ifndef NDEBUG
+static void
+_map_node_array_validate(void *o)
+{
+    assert(IS_ARRAY_NODE(o));
+    MapNode_Array *node = (MapNode_Array*)(o);
+    Py_ssize_t i = 0, count = 0;
+    for (; i < HAMT_ARRAY_NODE_SIZE; i++) {
+        if (node->a_array[i] != NULL) {
+            count++;
+        }
+    }
+    assert(count == node->a_count);
+}
+
+#define VALIDATE_ARRAY_NODE(NODE) \
+    do { _map_node_array_validate(NODE); } while (0);
+#else
+#define VALIDATE_ARRAY_NODE(NODE)
+#endif
+
+
+/* Returns -1 on error */
+static inline int32_t
+map_hash(PyObject *o)
+{
+    Py_hash_t hash = PyObject_Hash(o);
+
+#if SIZEOF_PY_HASH_T <= 4
+    return hash;
+#else
+    if (hash == -1) {
+        /* exception */
+        return -1;
+    }
+
+    /* While it's suboptimal to reduce Python's 64 bit hash to
+       32 bits via XOR, it seems that the resulting hash function
+       is good enough (this is also how Long type is hashed in Java.)
+       Storing 10, 100, 1000 Python strings results in a relatively
+       shallow and uniform tree structure.
+
+       Please don't change this hashing algorithm, as there are many
+       tests that test some exact tree shape to cover all code paths.
+    */
+    int32_t xored = (int32_t)(hash & 0xffffffffl) ^ (int32_t)(hash >> 32);
+    return xored == -1 ? -2 : xored;
+#endif
+}
+
+static inline uint32_t
+map_mask(int32_t hash, uint32_t shift)
+{
+    return (((uint32_t)hash >> shift) & 0x01f);
+}
+
+static inline uint32_t
+map_bitpos(int32_t hash, uint32_t shift)
+{
+    return (uint32_t)1 << map_mask(hash, shift);
+}
+
+static inline uint32_t
+map_bitcount(uint32_t i)
+{
+    /* We could use a native popcount instruction, but that would
+       require either adding configure flags to enable SSE4.2
+       support or detecting it dynamically. Otherwise, we risk
+       CPython not working properly on older hardware.
+
+       In practice, there's no observable difference in
+       performance between using a popcount instruction or the
+       following fallback code.
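+
+       (For reference: on GCC/Clang the native route would be the
+       __builtin_popcount() intrinsic; the portable bit-twiddling
+       below computes the same 32-bit population count.)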
+ + The algorithm is copied from: + https://graphics.stanford.edu/~seander/bithacks.html + */ + i = i - ((i >> 1) & 0x55555555); + i = (i & 0x33333333) + ((i >> 2) & 0x33333333); + return (((i + (i >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; +} + +static inline uint32_t +map_bitindex(uint32_t bitmap, uint32_t bit) +{ + return map_bitcount(bitmap & (bit - 1)); +} + + +/////////////////////////////////// Dump Helpers + +static int +_map_dump_ident(_PyUnicodeWriter *writer, int level) +{ + /* Write `' ' * level` to the `writer` */ + PyObject *str = NULL; + PyObject *num = NULL; + PyObject *res = NULL; + int ret = -1; + + str = PyUnicode_FromString(" "); + if (str == NULL) { + goto error; + } + + num = PyLong_FromLong((long)level); + if (num == NULL) { + goto error; + } + + res = PyNumber_Multiply(str, num); + if (res == NULL) { + goto error; + } + + ret = _PyUnicodeWriter_WriteStr(writer, res); + +error: + Py_XDECREF(res); + Py_XDECREF(str); + Py_XDECREF(num); + return ret; +} + +static int +_map_dump_format(_PyUnicodeWriter *writer, const char *format, ...) +{ + /* A convenient helper combining _PyUnicodeWriter_WriteStr and + PyUnicode_FromFormatV. + */ + PyObject* msg; + int ret; + + va_list vargs; +#ifdef HAVE_STDARG_PROTOTYPES + va_start(vargs, format); +#else + va_start(vargs); +#endif + msg = PyUnicode_FromFormatV(format, vargs); + va_end(vargs); + + if (msg == NULL) { + return -1; + } + + ret = _PyUnicodeWriter_WriteStr(writer, msg); + Py_DECREF(msg); + return ret; +} + +/////////////////////////////////// Bitmap Node + + +static MapNode * +map_node_bitmap_new(Py_ssize_t size, uint64_t mutid) +{ + /* Create a new bitmap node of size 'size' */ + + MapNode_Bitmap *node; + Py_ssize_t i; + + assert(size >= 0); + assert(size % 2 == 0); + + if (size == 0 && _empty_bitmap_node != NULL && mutid == 0) { + Py_INCREF(_empty_bitmap_node); + return (MapNode *)_empty_bitmap_node; + } + + /* No freelist; allocate a new bitmap node */ + node = PyObject_GC_NewVar( + MapNode_Bitmap, &_Map_BitmapNode_Type, size); + if (node == NULL) { + return NULL; + } + + Py_SIZE(node) = size; + + for (i = 0; i < size; i++) { + node->b_array[i] = NULL; + } + + node->b_bitmap = 0; + node->b_mutid = mutid; + + PyObject_GC_Track(node); + + if (size == 0 && _empty_bitmap_node == NULL && mutid == 0) { + /* Since bitmap nodes are immutable, we can cache the instance + for size=0 and reuse it whenever we need an empty bitmap node. + */ + _empty_bitmap_node = node; + Py_INCREF(_empty_bitmap_node); + } + + return (MapNode *)node; +} + +static inline Py_ssize_t +map_node_bitmap_count(MapNode_Bitmap *node) +{ + return Py_SIZE(node) / 2; +} + +static MapNode_Bitmap * +map_node_bitmap_clone(MapNode_Bitmap *node, uint64_t mutid) +{ + /* Clone a bitmap node; return a new one with the same child notes. 
*/ + + MapNode_Bitmap *clone; + Py_ssize_t i; + + clone = (MapNode_Bitmap *)map_node_bitmap_new( + Py_SIZE(node), mutid); + if (clone == NULL) { + return NULL; + } + + for (i = 0; i < Py_SIZE(node); i++) { + Py_XINCREF(node->b_array[i]); + clone->b_array[i] = node->b_array[i]; + } + + clone->b_bitmap = node->b_bitmap; + return clone; +} + +static MapNode_Bitmap * +map_node_bitmap_clone_without(MapNode_Bitmap *o, uint32_t bit, uint64_t mutid) +{ + assert(bit & o->b_bitmap); + assert(map_node_bitmap_count(o) > 1); + + MapNode_Bitmap *new = (MapNode_Bitmap *)map_node_bitmap_new( + Py_SIZE(o) - 2, mutid); + if (new == NULL) { + return NULL; + } + + uint32_t idx = map_bitindex(o->b_bitmap, bit); + uint32_t key_idx = 2 * idx; + uint32_t val_idx = key_idx + 1; + uint32_t i; + + for (i = 0; i < key_idx; i++) { + Py_XINCREF(o->b_array[i]); + new->b_array[i] = o->b_array[i]; + } + + assert(Py_SIZE(o) >= 0 && Py_SIZE(o) <= 32); + for (i = val_idx + 1; i < (uint32_t)Py_SIZE(o); i++) { + Py_XINCREF(o->b_array[i]); + new->b_array[i - 2] = o->b_array[i]; + } + + new->b_bitmap = o->b_bitmap & ~bit; + return new; +} + +static MapNode * +map_node_new_bitmap_or_collision(uint32_t shift, + PyObject *key1, PyObject *val1, + int32_t key2_hash, + PyObject *key2, PyObject *val2, + uint64_t mutid) +{ + /* Helper method. Creates a new node for key1/val and key2/val2 + pairs. + + If key1 hash is equal to the hash of key2, a Collision node + will be created. If they are not equal, a Bitmap node is + created. + */ + + int32_t key1_hash = map_hash(key1); + if (key1_hash == -1) { + return NULL; + } + + if (key1_hash == key2_hash) { + MapNode_Collision *n; + n = (MapNode_Collision *)map_node_collision_new(key1_hash, 4, mutid); + if (n == NULL) { + return NULL; + } + + Py_INCREF(key1); + n->c_array[0] = key1; + Py_INCREF(val1); + n->c_array[1] = val1; + + Py_INCREF(key2); + n->c_array[2] = key2; + Py_INCREF(val2); + n->c_array[3] = val2; + + return (MapNode *)n; + } + else { + int added_leaf = 0; + MapNode *n = map_node_bitmap_new(0, mutid); + if (n == NULL) { + return NULL; + } + + MapNode *n2 = map_node_assoc( + n, shift, key1_hash, key1, val1, &added_leaf, mutid); + Py_DECREF(n); + if (n2 == NULL) { + return NULL; + } + + n = map_node_assoc( + n2, shift, key2_hash, key2, val2, &added_leaf, mutid); + Py_DECREF(n2); + if (n == NULL) { + return NULL; + } + + return n; + } +} + +static MapNode * +map_node_bitmap_assoc(MapNode_Bitmap *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject *val, int* added_leaf, + uint64_t mutid) +{ + /* assoc operation for bitmap nodes. + + Return: a new node, or self if key/val already is in the + collection. + + 'added_leaf' is later used in 'map_assoc' to determine if + `map.set(key, val)` increased the size of the collection. + */ + + uint32_t bit = map_bitpos(hash, shift); + uint32_t idx = map_bitindex(self->b_bitmap, bit); + + /* Bitmap node layout: + + +------+------+------+------+ --- +------+------+ + | key1 | val1 | key2 | val2 | ... | keyN | valN | + +------+------+------+------+ --- +------+------+ + where `N < Py_SIZE(node)`. + + The `node->b_bitmap` field is a bitmap. For a given + `(shift, hash)` pair we can determine: + + - If this node has the corresponding key/val slots. + - The index of key/val slots. 
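+
+       For example, with b_bitmap == 0b1010 the entry for mask index 3
+       sits at idx = popcount(0b1010 & 0b0111) = 1, i.e. its key is in
+       b_array[2] and its value in b_array[3].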
+ */ + + if (self->b_bitmap & bit) { + /* The key is set in this node */ + + uint32_t key_idx = 2 * idx; + uint32_t val_idx = key_idx + 1; + + assert(val_idx < (size_t)Py_SIZE(self)); + + PyObject *key_or_null = self->b_array[key_idx]; + PyObject *val_or_node = self->b_array[val_idx]; + + if (key_or_null == NULL) { + /* key is NULL. This means that we have a few keys + that have the same (hash, shift) pair. */ + + assert(val_or_node != NULL); + + MapNode *sub_node = map_node_assoc( + (MapNode *)val_or_node, + shift + 5, hash, key, val, added_leaf, + mutid); + if (sub_node == NULL) { + return NULL; + } + + if (val_or_node == (PyObject *)sub_node) { + Py_DECREF(sub_node); + Py_INCREF(self); + return (MapNode *)self; + } + + if (mutid != 0 && self->b_mutid == mutid) { + Py_SETREF(self->b_array[val_idx], (PyObject*)sub_node); + Py_INCREF(self); + return (MapNode *)self; + } + else { + MapNode_Bitmap *ret = map_node_bitmap_clone(self, mutid); + if (ret == NULL) { + return NULL; + } + Py_SETREF(ret->b_array[val_idx], (PyObject*)sub_node); + return (MapNode *)ret; + } + } + + assert(key != NULL); + /* key is not NULL. This means that we have only one other + key in this collection that matches our hash for this shift. */ + + int comp_err = PyObject_RichCompareBool(key, key_or_null, Py_EQ); + if (comp_err < 0) { /* exception in __eq__ */ + return NULL; + } + if (comp_err == 1) { /* key == key_or_null */ + if (val == val_or_node) { + /* we already have the same key/val pair; return self. */ + Py_INCREF(self); + return (MapNode *)self; + } + + /* We're setting a new value for the key we had before. */ + if (mutid != 0 && self->b_mutid == mutid) { + /* We've been mutating this node before: update inplace. */ + Py_INCREF(val); + Py_SETREF(self->b_array[val_idx], val); + Py_INCREF(self); + return (MapNode *)self; + } + else { + /* Make a new bitmap node with a replaced value, + and return it. */ + MapNode_Bitmap *ret = map_node_bitmap_clone(self, mutid); + if (ret == NULL) { + return NULL; + } + Py_INCREF(val); + Py_SETREF(ret->b_array[val_idx], val); + return (MapNode *)ret; + } + } + + /* It's a new key, and it has the same index as *one* another key. + We have a collision. We need to create a new node which will + combine the existing key and the key we're adding. + + `map_node_new_bitmap_or_collision` will either create a new + Collision node if the keys have identical hashes, or + a new Bitmap node. + */ + MapNode *sub_node = map_node_new_bitmap_or_collision( + shift + 5, + key_or_null, val_or_node, /* existing key/val */ + hash, + key, val, /* new key/val */ + self->b_mutid + ); + if (sub_node == NULL) { + return NULL; + } + + if (mutid != 0 && self->b_mutid == mutid) { + Py_SETREF(self->b_array[key_idx], NULL); + Py_SETREF(self->b_array[val_idx], (PyObject *)sub_node); + Py_INCREF(self); + + *added_leaf = 1; + return (MapNode *)self; + } + else { + MapNode_Bitmap *ret = map_node_bitmap_clone(self, mutid); + if (ret == NULL) { + Py_DECREF(sub_node); + return NULL; + } + Py_SETREF(ret->b_array[key_idx], NULL); + Py_SETREF(ret->b_array[val_idx], (PyObject *)sub_node); + + *added_leaf = 1; + return (MapNode *)ret; + } + } + else { + /* There was no key before with the same (shift,hash). */ + + uint32_t n = map_bitcount(self->b_bitmap); + + if (n >= 16) { + /* When we have a situation where we want to store more + than 16 nodes at one level of the tree, we no longer + want to use the Bitmap node with bitmap encoding. 
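+
+               (At that point a Bitmap node would already hold 16
+               key/value slot pairs -- 32 pointers, as many as a full
+               Array node -- so the bitmap compression no longer saves
+               memory.)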
+
+               Instead we start using an Array node, which has
+               simpler (faster) implementation at the expense of
+               having preallocated 32 pointers for its keys/values
+               pairs.
+
+               Small map objects (<30 keys) usually don't have any
+               Array nodes at all.  Between ~30 and ~400 keys map
+               objects usually have one Array node, and usually it's
+               a root node.
+            */
+
+            uint32_t jdx = map_mask(hash, shift);
+            /* 'jdx' is the index of where the new key should be added
+               in the new Array node we're about to create. */
+
+            MapNode *empty = NULL;
+            MapNode_Array *new_node = NULL;
+            MapNode *res = NULL;
+
+            /* Create a new Array node. */
+            new_node = (MapNode_Array *)map_node_array_new(n + 1, mutid);
+            if (new_node == NULL) {
+                goto fin;
+            }
+
+            /* Create an empty bitmap node for the next
+               map_node_assoc call. */
+            empty = map_node_bitmap_new(0, mutid);
+            if (empty == NULL) {
+                goto fin;
+            }
+
+            /* Make a new bitmap node for the key/val we're adding.
+               Set that bitmap node to new-array-node[jdx]. */
+            new_node->a_array[jdx] = map_node_assoc(
+                empty, shift + 5, hash, key, val, added_leaf, mutid);
+            if (new_node->a_array[jdx] == NULL) {
+                goto fin;
+            }
+
+            /* Copy existing key/value pairs from the current Bitmap
+               node to the new Array node we've just created. */
+            Py_ssize_t i, j;
+            for (i = 0, j = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+                if (((self->b_bitmap >> i) & 1) != 0) {
+                    /* Ensure we don't accidentally override the `jdx`
+                       element we set a few lines above.
+                    */
+                    assert(new_node->a_array[i] == NULL);
+
+                    if (self->b_array[j] == NULL) {
+                        new_node->a_array[i] =
+                            (MapNode *)self->b_array[j + 1];
+                        Py_INCREF(new_node->a_array[i]);
+                    }
+                    else {
+                        int32_t rehash = map_hash(self->b_array[j]);
+                        if (rehash == -1) {
+                            goto fin;
+                        }
+
+                        new_node->a_array[i] = map_node_assoc(
+                            empty, shift + 5,
+                            rehash,
+                            self->b_array[j],
+                            self->b_array[j + 1],
+                            added_leaf,
+                            mutid);
+
+                        if (new_node->a_array[i] == NULL) {
+                            goto fin;
+                        }
+                    }
+                    j += 2;
+                }
+            }
+
+            VALIDATE_ARRAY_NODE(new_node)
+
+            /* That's it! */
+            res = (MapNode *)new_node;
+
+        fin:
+            Py_XDECREF(empty);
+            if (res == NULL) {
+                Py_XDECREF(new_node);
+            }
+            return res;
+        }
+        else {
+            /* We have fewer than 16 keys at this level; let's just
+               create a new bitmap node out of this node with the
+               new key/val pair added. */
+
+            uint32_t key_idx = 2 * idx;
+            uint32_t val_idx = key_idx + 1;
+            uint32_t i;
+
+            *added_leaf = 1;
+
+            /* Allocate a new Bitmap node which can hold one more
+               key/val pair in addition to what we have already. */
+            MapNode_Bitmap *new_node =
+                (MapNode_Bitmap *)map_node_bitmap_new(2 * (n + 1), mutid);
+            if (new_node == NULL) {
+                return NULL;
+            }
+
+            /* Copy all keys/values that will be before the new key/value
+               we are adding. */
+            for (i = 0; i < key_idx; i++) {
+                Py_XINCREF(self->b_array[i]);
+                new_node->b_array[i] = self->b_array[i];
+            }
+
+            /* Set the new key/value to the new Bitmap node. */
+            Py_INCREF(key);
+            new_node->b_array[key_idx] = key;
+            Py_INCREF(val);
+            new_node->b_array[val_idx] = val;
+
+            /* Copy all keys/values that will be after the new key/value
+               we are adding.
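+
+               (Illustrative sketch: with key_idx == 2 and an old
+               array [k0, v0, k1, v1], the new array becomes
+               [k0, v0, key, val, k1, v1] -- entries at and after
+               key_idx move up by two slots, hence the `i + 2`
+               store in the loop below.)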
+            */
+            assert(Py_SIZE(self) >= 0 && Py_SIZE(self) <= 32);
+            for (i = key_idx; i < (uint32_t)Py_SIZE(self); i++) {
+                Py_XINCREF(self->b_array[i]);
+                new_node->b_array[i + 2] = self->b_array[i];
+            }
+
+            new_node->b_bitmap = self->b_bitmap | bit;
+            return (MapNode *)new_node;
+        }
+    }
+}
+
+static map_without_t
+map_node_bitmap_without(MapNode_Bitmap *self,
+                        uint32_t shift, int32_t hash,
+                        PyObject *key,
+                        MapNode **new_node,
+                        uint64_t mutid)
+{
+    uint32_t bit = map_bitpos(hash, shift);
+    if ((self->b_bitmap & bit) == 0) {
+        return W_NOT_FOUND;
+    }
+
+    uint32_t idx = map_bitindex(self->b_bitmap, bit);
+
+    uint32_t key_idx = 2 * idx;
+    uint32_t val_idx = key_idx + 1;
+
+    PyObject *key_or_null = self->b_array[key_idx];
+    PyObject *val_or_node = self->b_array[val_idx];
+
+    if (key_or_null == NULL) {
+        /* key == NULL means that 'value' is another tree node. */
+
+        MapNode *sub_node = NULL;
+        MapNode_Bitmap *target = NULL;
+
+        map_without_t res = map_node_without(
+            (MapNode *)val_or_node,
+            shift + 5, hash, key, &sub_node,
+            mutid);
+
+        switch (res) {
+            case W_EMPTY:
+                /* It's impossible for us to receive a W_EMPTY here:
+
+                   - Array nodes are converted to Bitmap nodes when
+                     we delete the 16th item from them;
+
+                   - Collision nodes are converted to Bitmap nodes
+                     when only one item is left in them;
+
+                   - Bitmap node's without() inlines single-item
+                     sub-nodes.
+
+                   So in no situation can we have a single-item
+                   Bitmap child of another Bitmap node.
+                */
+                abort();
+
+            case W_NEWNODE: {
+                assert(sub_node != NULL);
+
+                if (IS_BITMAP_NODE(sub_node)) {
+                    MapNode_Bitmap *sub_tree = (MapNode_Bitmap *)sub_node;
+                    if (map_node_bitmap_count(sub_tree) == 1 &&
+                            sub_tree->b_array[0] != NULL)
+                    {
+                        /* A bitmap node with one key/value pair.  Just
+                           merge it into this node.
+
+                           Note that we don't inline Bitmap nodes that
+                           have a NULL key -- those nodes point to another
+                           tree level, and we cannot simply move tree levels
+                           up or down.
+                        */
+
+                        if (mutid != 0 && self->b_mutid == mutid) {
+                            target = self;
+                            Py_INCREF(target);
+                        }
+                        else {
+                            target = map_node_bitmap_clone(self, mutid);
+                            if (target == NULL) {
+                                Py_DECREF(sub_node);
+                                return W_ERROR;
+                            }
+                        }
+
+                        PyObject *key = sub_tree->b_array[0];
+                        PyObject *val = sub_tree->b_array[1];
+
+                        Py_INCREF(key);
+                        Py_XSETREF(target->b_array[key_idx], key);
+                        Py_INCREF(val);
+                        Py_SETREF(target->b_array[val_idx], val);
+
+                        Py_DECREF(sub_tree);
+
+                        *new_node = (MapNode *)target;
+                        return W_NEWNODE;
+                    }
+                }
+
+#ifndef NDEBUG
+                /* Ensure that the Collision.without implementation
+                   converts to Bitmap nodes itself.
+ */ + if (IS_COLLISION_NODE(sub_node)) { + assert(map_node_collision_count( + (MapNode_Collision*)sub_node) > 1); + } +#endif + + if (mutid != 0 && self->b_mutid == mutid) { + target = self; + Py_INCREF(target); + } + else { + target = map_node_bitmap_clone(self, mutid); + if (target == NULL) { + return W_ERROR; + } + } + + Py_SETREF(target->b_array[val_idx], + (PyObject *)sub_node); /* borrow */ + + *new_node = (MapNode *)target; + return W_NEWNODE; + } + + case W_ERROR: + case W_NOT_FOUND: + assert(sub_node == NULL); + return res; + + default: + abort(); + } + } + else { + /* We have a regular key/value pair */ + + int cmp = PyObject_RichCompareBool(key_or_null, key, Py_EQ); + if (cmp < 0) { + return W_ERROR; + } + if (cmp == 0) { + return W_NOT_FOUND; + } + + if (map_node_bitmap_count(self) == 1) { + return W_EMPTY; + } + + *new_node = (MapNode *) + map_node_bitmap_clone_without(self, bit, mutid); + if (*new_node == NULL) { + return W_ERROR; + } + + return W_NEWNODE; + } +} + +static map_find_t +map_node_bitmap_find(MapNode_Bitmap *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject **val) +{ + /* Lookup a key in a Bitmap node. */ + + uint32_t bit = map_bitpos(hash, shift); + uint32_t idx; + uint32_t key_idx; + uint32_t val_idx; + PyObject *key_or_null; + PyObject *val_or_node; + int comp_err; + + if ((self->b_bitmap & bit) == 0) { + return F_NOT_FOUND; + } + + idx = map_bitindex(self->b_bitmap, bit); + key_idx = idx * 2; + val_idx = key_idx + 1; + + assert(val_idx < (size_t)Py_SIZE(self)); + + key_or_null = self->b_array[key_idx]; + val_or_node = self->b_array[val_idx]; + + if (key_or_null == NULL) { + /* There are a few keys that have the same hash at the current shift + that match our key. Dispatch the lookup further down the tree. */ + assert(val_or_node != NULL); + return map_node_find((MapNode *)val_or_node, + shift + 5, hash, key, val); + } + + /* We have only one key -- a potential match. Let's compare if the + key we are looking at is equal to the key we are looking for. */ + assert(key != NULL); + comp_err = PyObject_RichCompareBool(key, key_or_null, Py_EQ); + if (comp_err < 0) { /* exception in __eq__ */ + return F_ERROR; + } + if (comp_err == 1) { /* key == key_or_null */ + *val = val_or_node; + return F_FOUND; + } + + return F_NOT_FOUND; +} + +static int +map_node_bitmap_traverse(MapNode_Bitmap *self, visitproc visit, void *arg) +{ + /* Bitmap's tp_traverse */ + + Py_ssize_t i; + + for (i = Py_SIZE(self); --i >= 0; ) { + Py_VISIT(self->b_array[i]); + } + + return 0; +} + +static void +map_node_bitmap_dealloc(MapNode_Bitmap *self) +{ + /* Bitmap's tp_dealloc */ + + Py_ssize_t len = Py_SIZE(self); + Py_ssize_t i; + + PyObject_GC_UnTrack(self); + Py_TRASHCAN_SAFE_BEGIN(self) + + if (len > 0) { + i = len; + while (--i >= 0) { + Py_XDECREF(self->b_array[i]); + } + } + + Py_TYPE(self)->tp_free((PyObject *)self); + Py_TRASHCAN_SAFE_END(self) +} + +static int +map_node_bitmap_dump(MapNode_Bitmap *node, + _PyUnicodeWriter *writer, int level) +{ + /* Debug build: __dump__() method implementation for Bitmap nodes. 
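+
+       (For orientation, the output shape produced by the format
+       strings below looks roughly like
+       "BitmapNode(size=4 count=2 bitmap=0b100100 id=0x...):",
+       followed by one indented "key: value" line per pair, or a
+       "NULL:" line plus a nested sub-node dump.)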
*/ + + Py_ssize_t i; + PyObject *tmp1; + PyObject *tmp2; + + if (_map_dump_ident(writer, level + 1)) { + goto error; + } + + if (_map_dump_format(writer, "BitmapNode(size=%zd count=%zd ", + Py_SIZE(node), Py_SIZE(node) / 2)) + { + goto error; + } + + tmp1 = PyLong_FromUnsignedLong(node->b_bitmap); + if (tmp1 == NULL) { + goto error; + } + tmp2 = _PyLong_Format(tmp1, 2); + Py_DECREF(tmp1); + if (tmp2 == NULL) { + goto error; + } + if (_map_dump_format(writer, "bitmap=%S id=%p):\n", tmp2, node)) { + Py_DECREF(tmp2); + goto error; + } + Py_DECREF(tmp2); + + for (i = 0; i < Py_SIZE(node); i += 2) { + PyObject *key_or_null = node->b_array[i]; + PyObject *val_or_node = node->b_array[i + 1]; + + if (_map_dump_ident(writer, level + 2)) { + goto error; + } + + if (key_or_null == NULL) { + if (_map_dump_format(writer, "NULL:\n")) { + goto error; + } + + if (map_node_dump((MapNode *)val_or_node, + writer, level + 2)) + { + goto error; + } + } + else { + if (_map_dump_format(writer, "%R: %R", key_or_null, + val_or_node)) + { + goto error; + } + } + + if (_map_dump_format(writer, "\n")) { + goto error; + } + } + + return 0; +error: + return -1; +} + + +/////////////////////////////////// Collision Node + + +static MapNode * +map_node_collision_new(int32_t hash, Py_ssize_t size, uint64_t mutid) +{ + /* Create a new Collision node. */ + + MapNode_Collision *node; + Py_ssize_t i; + + assert(size >= 4); + assert(size % 2 == 0); + + node = PyObject_GC_NewVar( + MapNode_Collision, &_Map_CollisionNode_Type, size); + if (node == NULL) { + return NULL; + } + + for (i = 0; i < size; i++) { + node->c_array[i] = NULL; + } + + Py_SIZE(node) = size; + node->c_hash = hash; + + node->c_mutid = mutid; + + PyObject_GC_Track(node); + return (MapNode *)node; +} + +static map_find_t +map_node_collision_find_index(MapNode_Collision *self, PyObject *key, + Py_ssize_t *idx) +{ + /* Lookup `key` in the Collision node `self`. Set the index of the + found key to 'idx'. */ + + Py_ssize_t i; + PyObject *el; + + for (i = 0; i < Py_SIZE(self); i += 2) { + el = self->c_array[i]; + + assert(el != NULL); + int cmp = PyObject_RichCompareBool(key, el, Py_EQ); + if (cmp < 0) { + return F_ERROR; + } + if (cmp == 1) { + *idx = i; + return F_FOUND; + } + } + + return F_NOT_FOUND; +} + +static MapNode * +map_node_collision_assoc(MapNode_Collision *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject *val, int* added_leaf, + uint64_t mutid) +{ + /* Set a new key to this level (currently a Collision node) + of the tree. */ + + if (hash == self->c_hash) { + /* The hash of the 'key' we are adding matches the hash of + other keys in this Collision node. */ + + Py_ssize_t key_idx = -1; + map_find_t found; + MapNode_Collision *new_node; + Py_ssize_t i; + + /* Let's try to lookup the new 'key', maybe we already have it. */ + found = map_node_collision_find_index(self, key, &key_idx); + switch (found) { + case F_ERROR: + /* Exception. */ + return NULL; + + case F_NOT_FOUND: + /* This is a totally new key. Clone the current node, + add a new key/value to the cloned node. 
+                */
+
+                new_node = (MapNode_Collision *)map_node_collision_new(
+                    self->c_hash, Py_SIZE(self) + 2, mutid);
+                if (new_node == NULL) {
+                    return NULL;
+                }
+
+                for (i = 0; i < Py_SIZE(self); i++) {
+                    Py_INCREF(self->c_array[i]);
+                    new_node->c_array[i] = self->c_array[i];
+                }
+
+                Py_INCREF(key);
+                new_node->c_array[i] = key;
+                Py_INCREF(val);
+                new_node->c_array[i + 1] = val;
+
+                *added_leaf = 1;
+                return (MapNode *)new_node;
+
+            case F_FOUND:
+                /* There's a key which is equal to the key we are adding. */
+
+                assert(key_idx >= 0);
+                assert(key_idx < Py_SIZE(self));
+                Py_ssize_t val_idx = key_idx + 1;
+
+                if (self->c_array[val_idx] == val) {
+                    /* We're setting a key/value pair that's already set. */
+                    Py_INCREF(self);
+                    return (MapNode *)self;
+                }
+
+                /* We need to replace the old value for the key with
+                   a new value. */
+
+                if (mutid != 0 && self->c_mutid == mutid) {
+                    new_node = self;
+                    Py_INCREF(self);
+                }
+                else {
+                    /* Create a new Collision node. */
+                    new_node = (MapNode_Collision *)map_node_collision_new(
+                        self->c_hash, Py_SIZE(self), mutid);
+                    if (new_node == NULL) {
+                        return NULL;
+                    }
+
+                    /* Copy all elements of the old node to the new one. */
+                    for (i = 0; i < Py_SIZE(self); i++) {
+                        Py_INCREF(self->c_array[i]);
+                        new_node->c_array[i] = self->c_array[i];
+                    }
+                }
+
+                /* Replace the old value with the new value for our key. */
+                Py_DECREF(new_node->c_array[val_idx]);
+                Py_INCREF(val);
+                new_node->c_array[val_idx] = val;
+
+                return (MapNode *)new_node;
+
+            default:
+                abort();
+        }
+    }
+    else {
+        /* The hash of the new key is different from the hash that
+           all keys of this Collision node have.
+
+           Create a new Bitmap node with two children: the key/value
+           pair that we're adding, and the Collision node we're
+           replacing at this tree level.
+        */
+
+        MapNode_Bitmap *new_node;
+        MapNode *assoc_res;
+
+        new_node = (MapNode_Bitmap *)map_node_bitmap_new(2, mutid);
+        if (new_node == NULL) {
+            return NULL;
+        }
+        new_node->b_bitmap = map_bitpos(self->c_hash, shift);
+        Py_INCREF(self);
+        new_node->b_array[1] = (PyObject*) self;
+
+        assoc_res = map_node_bitmap_assoc(
+            new_node, shift, hash, key, val, added_leaf, mutid);
+        Py_DECREF(new_node);
+        return assoc_res;
+    }
+}
+
+static inline Py_ssize_t
+map_node_collision_count(MapNode_Collision *node)
+{
+    return Py_SIZE(node) / 2;
+}
+
+static map_without_t
+map_node_collision_without(MapNode_Collision *self,
+                           uint32_t shift, int32_t hash,
+                           PyObject *key,
+                           MapNode **new_node,
+                           uint64_t mutid)
+{
+    if (hash != self->c_hash) {
+        return W_NOT_FOUND;
+    }
+
+    Py_ssize_t key_idx = -1;
+    map_find_t found = map_node_collision_find_index(self, key, &key_idx);
+
+    switch (found) {
+        case F_ERROR:
+            return W_ERROR;
+
+        case F_NOT_FOUND:
+            return W_NOT_FOUND;
+
+        case F_FOUND:
+            assert(key_idx >= 0);
+            assert(key_idx < Py_SIZE(self));
+
+            Py_ssize_t new_count = map_node_collision_count(self) - 1;
+
+            if (new_count == 0) {
+                /* The node has only one key/value pair and it's for the
+                   key we're trying to delete.  So a new node would be
+                   empty after the removal.
+                */
+                return W_EMPTY;
+            }
+
+            if (new_count == 1) {
+                /* The node has two keys, and after deletion the
+                   new Collision node would have one.  Collision nodes
+                   with one key shouldn't exist, so convert it to a
+                   Bitmap node.
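+
+                   (Sketch: deleting k1 from a Collision node holding
+                   [k1, v1, k2, v2] yields a 2-slot Bitmap node
+                   holding [k2, v2], whose b_bitmap has the single
+                   map_bitpos(hash, shift) bit set -- exactly what
+                   the code below builds.)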
+ */ + MapNode_Bitmap *node = (MapNode_Bitmap *) + map_node_bitmap_new(2, mutid); + if (node == NULL) { + return W_ERROR; + } + + if (key_idx == 0) { + Py_INCREF(self->c_array[2]); + node->b_array[0] = self->c_array[2]; + Py_INCREF(self->c_array[3]); + node->b_array[1] = self->c_array[3]; + } + else { + assert(key_idx == 2); + Py_INCREF(self->c_array[0]); + node->b_array[0] = self->c_array[0]; + Py_INCREF(self->c_array[1]); + node->b_array[1] = self->c_array[1]; + } + + node->b_bitmap = map_bitpos(hash, shift); + + *new_node = (MapNode *)node; + return W_NEWNODE; + } + + /* Allocate a new Collision node with capacity for one + less key/value pair */ + MapNode_Collision *new = (MapNode_Collision *) + map_node_collision_new( + self->c_hash, Py_SIZE(self) - 2, mutid); + if (new == NULL) { + return W_ERROR; + } + + /* Copy all other keys from `self` to `new` */ + Py_ssize_t i; + for (i = 0; i < key_idx; i++) { + Py_INCREF(self->c_array[i]); + new->c_array[i] = self->c_array[i]; + } + for (i = key_idx + 2; i < Py_SIZE(self); i++) { + Py_INCREF(self->c_array[i]); + new->c_array[i - 2] = self->c_array[i]; + } + + *new_node = (MapNode*)new; + return W_NEWNODE; + + default: + abort(); + } +} + +static map_find_t +map_node_collision_find(MapNode_Collision *self, + uint32_t shift, int32_t hash, + PyObject *key, PyObject **val) +{ + /* Lookup `key` in the Collision node `self`. Set the value + for the found key to 'val'. */ + + Py_ssize_t idx = -1; + map_find_t res; + + res = map_node_collision_find_index(self, key, &idx); + if (res == F_ERROR || res == F_NOT_FOUND) { + return res; + } + + assert(idx >= 0); + assert(idx + 1 < Py_SIZE(self)); + + *val = self->c_array[idx + 1]; + assert(*val != NULL); + + return F_FOUND; +} + + +static int +map_node_collision_traverse(MapNode_Collision *self, + visitproc visit, void *arg) +{ + /* Collision's tp_traverse */ + + Py_ssize_t i; + + for (i = Py_SIZE(self); --i >= 0; ) { + Py_VISIT(self->c_array[i]); + } + + return 0; +} + +static void +map_node_collision_dealloc(MapNode_Collision *self) +{ + /* Collision's tp_dealloc */ + + Py_ssize_t len = Py_SIZE(self); + + PyObject_GC_UnTrack(self); + Py_TRASHCAN_SAFE_BEGIN(self) + + if (len > 0) { + + while (--len >= 0) { + Py_XDECREF(self->c_array[len]); + } + } + + Py_TYPE(self)->tp_free((PyObject *)self); + Py_TRASHCAN_SAFE_END(self) +} + +static int +map_node_collision_dump(MapNode_Collision *node, + _PyUnicodeWriter *writer, int level) +{ + /* Debug build: __dump__() method implementation for Collision nodes. 
+    */
+
+    Py_ssize_t i;
+
+    if (_map_dump_ident(writer, level + 1)) {
+        goto error;
+    }
+
+    if (_map_dump_format(writer, "CollisionNode(size=%zd id=%p):\n",
+                         Py_SIZE(node), node))
+    {
+        goto error;
+    }
+
+    for (i = 0; i < Py_SIZE(node); i += 2) {
+        PyObject *key = node->c_array[i];
+        PyObject *val = node->c_array[i + 1];
+
+        if (_map_dump_ident(writer, level + 2)) {
+            goto error;
+        }
+
+        if (_map_dump_format(writer, "%R: %R\n", key, val)) {
+            goto error;
+        }
+    }
+
+    return 0;
+error:
+    return -1;
+}
+
+
+/////////////////////////////////// Array Node
+
+
+static MapNode *
+map_node_array_new(Py_ssize_t count, uint64_t mutid)
+{
+    Py_ssize_t i;
+
+    MapNode_Array *node = PyObject_GC_New(
+        MapNode_Array, &_Map_ArrayNode_Type);
+    if (node == NULL) {
+        return NULL;
+    }
+
+    for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+        node->a_array[i] = NULL;
+    }
+
+    node->a_count = count;
+    node->a_mutid = mutid;
+
+    PyObject_GC_Track(node);
+    return (MapNode *)node;
+}
+
+static MapNode_Array *
+map_node_array_clone(MapNode_Array *node, uint64_t mutid)
+{
+    MapNode_Array *clone;
+    Py_ssize_t i;
+
+    VALIDATE_ARRAY_NODE(node)
+
+    /* Create a new Array node. */
+    clone = (MapNode_Array *)map_node_array_new(node->a_count, mutid);
+    if (clone == NULL) {
+        return NULL;
+    }
+
+    /* Copy all elements from the current Array node to the new one. */
+    for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+        Py_XINCREF(node->a_array[i]);
+        clone->a_array[i] = node->a_array[i];
+    }
+
+    clone->a_mutid = mutid;
+
+    VALIDATE_ARRAY_NODE(clone)
+    return clone;
+}
+
+static MapNode *
+map_node_array_assoc(MapNode_Array *self,
+                     uint32_t shift, int32_t hash,
+                     PyObject *key, PyObject *val, int* added_leaf,
+                     uint64_t mutid)
+{
+    /* Set a new key to this level (currently an Array node)
+       of the tree.
+
+       Array nodes don't store values; they can only point to
+       other nodes.  They are simple arrays of 32 MapNode pointers.
+     */
+
+    uint32_t idx = map_mask(hash, shift);
+    MapNode *node = self->a_array[idx];
+    MapNode *child_node;
+    MapNode_Array *new_node;
+    Py_ssize_t i;
+
+    if (node == NULL) {
+        /* There's no child node for the given hash.  Create a new
+           Bitmap node for this key. */
+
+        MapNode_Bitmap *empty = NULL;
+
+        /* Get an empty Bitmap node to work with. */
+        empty = (MapNode_Bitmap *)map_node_bitmap_new(0, mutid);
+        if (empty == NULL) {
+            return NULL;
+        }
+
+        /* Set key/val to the newly created empty Bitmap, thus
+           creating a new Bitmap node with our key/value pair. */
+        child_node = map_node_bitmap_assoc(
+            empty,
+            shift + 5, hash, key, val, added_leaf, mutid);
+        Py_DECREF(empty);
+        if (child_node == NULL) {
+            return NULL;
+        }
+
+        if (mutid != 0 && self->a_mutid == mutid) {
+            new_node = self;
+            /* The in-place path must also account for the added
+               child: a_count tracks the number of non-NULL entries
+               (see VALIDATE_ARRAY_NODE). */
+            new_node->a_count++;
+            Py_INCREF(self);
+        }
+        else {
+            /* Create a new Array node. */
+            new_node = (MapNode_Array *)map_node_array_new(
+                self->a_count + 1, mutid);
+            if (new_node == NULL) {
+                Py_DECREF(child_node);
+                return NULL;
+            }
+
+            /* Copy all elements from the current Array node to the
+               new one. */
+            for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+                Py_XINCREF(self->a_array[i]);
+                new_node->a_array[i] = self->a_array[i];
+            }
+        }
+
+        assert(new_node->a_array[idx] == NULL);
+        new_node->a_array[idx] = child_node;  /* borrow */
+        VALIDATE_ARRAY_NODE(new_node)
+    }
+    else {
+        /* There's a child node for the given hash.
+           Set the key to it. */
+
+        child_node = map_node_assoc(
+            node, shift + 5, hash, key, val, added_leaf, mutid);
+        if (child_node == NULL) {
+            return NULL;
+        }
+        else if (child_node == node) {
+            /* The child node didn't change (or was mutated in
+               place), so this Array node doesn't change either. */
+            Py_DECREF(child_node);
+            Py_INCREF(self);
+            return (MapNode *)self;
+        }
+
+        if (mutid != 0 && self->a_mutid == mutid) {
+            new_node = self;
+            Py_INCREF(self);
+        }
+        else {
+            new_node = map_node_array_clone(self, mutid);
+        }
+
+        if (new_node == NULL) {
+            Py_DECREF(child_node);
+            return NULL;
+        }
+
+        Py_SETREF(new_node->a_array[idx], child_node);  /* borrow */
+        VALIDATE_ARRAY_NODE(new_node)
+    }
+
+    return (MapNode *)new_node;
+}
+
+static map_without_t
+map_node_array_without(MapNode_Array *self,
+                       uint32_t shift, int32_t hash,
+                       PyObject *key,
+                       MapNode **new_node,
+                       uint64_t mutid)
+{
+    uint32_t idx = map_mask(hash, shift);
+    MapNode *node = self->a_array[idx];
+
+    if (node == NULL) {
+        return W_NOT_FOUND;
+    }
+
+    MapNode *sub_node = NULL;
+    MapNode_Array *target = NULL;
+    map_without_t res = map_node_without(
+        (MapNode *)node,
+        shift + 5, hash, key, &sub_node, mutid);
+
+    switch (res) {
+        case W_NOT_FOUND:
+        case W_ERROR:
+            assert(sub_node == NULL);
+            return res;
+
+        case W_NEWNODE: {
+            /* We need to replace a node at the `idx` index.
+               Clone this node and replace.
+            */
+            assert(sub_node != NULL);
+
+            if (mutid != 0 && self->a_mutid == mutid) {
+                target = self;
+                Py_INCREF(self);
+            }
+            else {
+                target = map_node_array_clone(self, mutid);
+                if (target == NULL) {
+                    Py_DECREF(sub_node);
+                    return W_ERROR;
+                }
+            }
+
+            Py_SETREF(target->a_array[idx], sub_node);  /* borrow */
+            *new_node = (MapNode*)target;  /* borrow */
+            return W_NEWNODE;
+        }
+
+        case W_EMPTY: {
+            assert(sub_node == NULL);
+            /* We need to remove a node at the `idx` index.
+               Calculate the size of the replacement Array node.
+            */
+            Py_ssize_t new_count = self->a_count - 1;
+
+            if (new_count == 0) {
+                return W_EMPTY;
+            }
+
+            if (new_count >= 16) {
+                /* We convert Bitmap nodes to Array nodes when a
+                   Bitmap node needs to store more than 15 key/value
+                   pairs.  So we will create a new Array node if the
+                   number of key/values after deletion is still
+                   greater than 15.
+                */
+
+                if (mutid != 0 && self->a_mutid == mutid) {
+                    target = self;
+                    Py_INCREF(self);
+                }
+                else {
+                    target = map_node_array_clone(self, mutid);
+                    if (target == NULL) {
+                        return W_ERROR;
+                    }
+                }
+
+                /* Update the count in both the in-place and the
+                   cloned case (the clone copies the old count). */
+                target->a_count = new_count;
+
+                Py_CLEAR(target->a_array[idx]);
+
+                *new_node = (MapNode*)target;  /* borrow */
+                return W_NEWNODE;
+            }
+
+            /* The new Array node would have fewer than 16 key/value
+               pairs.  We need to create a replacement Bitmap node. */
+
+            Py_ssize_t bitmap_size = new_count * 2;
+            uint32_t bitmap = 0;
+
+            MapNode_Bitmap *new = (MapNode_Bitmap *)
+                map_node_bitmap_new(bitmap_size, mutid);
+            if (new == NULL) {
+                return W_ERROR;
+            }
+
+            Py_ssize_t new_i = 0;
+            for (uint32_t i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+                if (i == idx) {
+                    /* Skip the node we are deleting. */
+                    continue;
+                }
+
+                MapNode *node = self->a_array[i];
+                if (node == NULL) {
+                    /* Skip any missing nodes. */
+                    continue;
+                }
+
+                bitmap |= 1u << i;
+
+                if (IS_BITMAP_NODE(node)) {
+                    MapNode_Bitmap *child = (MapNode_Bitmap *)node;
+
+                    if (map_node_bitmap_count(child) == 1 &&
+                            child->b_array[0] != NULL)
+                    {
+                        /* node is a Bitmap with one key/value pair, just
+                           merge it into the new Bitmap node we're building.
+
+                           Note that we don't inline Bitmap nodes that
+                           have a NULL key -- those nodes point to another
+                           tree level, and we cannot simply move tree levels
+                           up or down.
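+
+                           (The reason: entries inside a child node
+                           are placed by the hash bits taken at
+                           shift + 5.  Hoisting such a node one level
+                           up would index those entries by the wrong
+                           5-bit slice of the hash, whereas a plain
+                           key/value pair is simply re-indexed here
+                           by this level's own array index `i`.)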
+                        */
+                        PyObject *key = child->b_array[0];
+                        PyObject *val = child->b_array[1];
+
+                        Py_INCREF(key);
+                        new->b_array[new_i] = key;
+                        Py_INCREF(val);
+                        new->b_array[new_i + 1] = val;
+                    }
+                    else {
+                        new->b_array[new_i] = NULL;
+                        Py_INCREF(node);
+                        new->b_array[new_i + 1] = (PyObject*)node;
+                    }
+                }
+                else {
+
+#ifndef NDEBUG
+                    if (IS_COLLISION_NODE(node)) {
+                        assert(
+                            (map_node_collision_count(
+                                (MapNode_Collision*)node)) > 1);
+                    }
+                    else if (IS_ARRAY_NODE(node)) {
+                        assert(((MapNode_Array*)node)->a_count >= 16);
+                    }
+#endif
+
+                    /* Just copy the node into our new Bitmap */
+                    new->b_array[new_i] = NULL;
+                    Py_INCREF(node);
+                    new->b_array[new_i + 1] = (PyObject*)node;
+                }
+
+                new_i += 2;
+            }
+
+            new->b_bitmap = bitmap;
+            *new_node = (MapNode*)new;  /* borrow */
+            return W_NEWNODE;
+        }
+
+        default:
+            abort();
+    }
+}
+
+static map_find_t
+map_node_array_find(MapNode_Array *self,
+                    uint32_t shift, int32_t hash,
+                    PyObject *key, PyObject **val)
+{
+    /* Lookup `key` in the Array node `self`.  Set the value
+       for the found key to 'val'. */
+
+    uint32_t idx = map_mask(hash, shift);
+    MapNode *node;
+
+    node = self->a_array[idx];
+    if (node == NULL) {
+        return F_NOT_FOUND;
+    }
+
+    /* Dispatch to the generic map_node_find */
+    return map_node_find(node, shift + 5, hash, key, val);
+}
+
+static int
+map_node_array_traverse(MapNode_Array *self,
+                        visitproc visit, void *arg)
+{
+    /* Array's tp_traverse */
+
+    Py_ssize_t i;
+
+    for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+        Py_VISIT(self->a_array[i]);
+    }
+
+    return 0;
+}
+
+static void
+map_node_array_dealloc(MapNode_Array *self)
+{
+    /* Array's tp_dealloc */
+
+    Py_ssize_t i;
+
+    PyObject_GC_UnTrack(self);
+    Py_TRASHCAN_SAFE_BEGIN(self)
+
+    for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+        Py_XDECREF(self->a_array[i]);
+    }
+
+    Py_TYPE(self)->tp_free((PyObject *)self);
+    Py_TRASHCAN_SAFE_END(self)
+}
+
+static int
+map_node_array_dump(MapNode_Array *node,
+                    _PyUnicodeWriter *writer, int level)
+{
+    /* Debug build: __dump__() method implementation for Array nodes. */
+
+    Py_ssize_t i;
+
+    if (_map_dump_ident(writer, level + 1)) {
+        goto error;
+    }
+
+    if (_map_dump_format(writer, "ArrayNode(id=%p):\n", node)) {
+        goto error;
+    }
+
+    for (i = 0; i < HAMT_ARRAY_NODE_SIZE; i++) {
+        if (node->a_array[i] == NULL) {
+            continue;
+        }
+
+        if (_map_dump_ident(writer, level + 2)) {
+            goto error;
+        }
+
+        if (_map_dump_format(writer, "%zd::\n", i)) {
+            goto error;
+        }
+
+        if (map_node_dump(node->a_array[i], writer, level + 1)) {
+            goto error;
+        }
+
+        if (_map_dump_format(writer, "\n")) {
+            goto error;
+        }
+    }
+
+    return 0;
+error:
+    return -1;
+}
+
+
+/////////////////////////////////// Node Dispatch
+
+
+static MapNode *
+map_node_assoc(MapNode *node,
+               uint32_t shift, int32_t hash,
+               PyObject *key, PyObject *val, int* added_leaf,
+               uint64_t mutid)
+{
+    /* Set key/value to the 'node' starting with the given shift/hash.
+       Return a new node, or the same node if key/value already
+       set.
+
+       added_leaf will be set to 1 if key/value wasn't in the
+       tree before.
+
+       This method automatically dispatches to the suitable
+       map_node_{nodetype}_assoc method.
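+
+       (All top-level callers -- map_assoc(), mapmut_set() and the
+       update helpers -- start the recursion with shift=0; each tree
+       level adds 5, consuming the hash five bits at a time, which
+       is what bounds the recursion depth at _Py_HAMT_MAX_TREE_DEPTH
+       levels.)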
+ */ + + *added_leaf = 0; + + if (IS_BITMAP_NODE(node)) { + return map_node_bitmap_assoc( + (MapNode_Bitmap *)node, + shift, hash, key, val, added_leaf, mutid); + } + else if (IS_ARRAY_NODE(node)) { + return map_node_array_assoc( + (MapNode_Array *)node, + shift, hash, key, val, added_leaf, mutid); + } + else { + assert(IS_COLLISION_NODE(node)); + return map_node_collision_assoc( + (MapNode_Collision *)node, + shift, hash, key, val, added_leaf, mutid); + } +} + +static map_without_t +map_node_without(MapNode *node, + uint32_t shift, int32_t hash, + PyObject *key, + MapNode **new_node, + uint64_t mutid) +{ + if (IS_BITMAP_NODE(node)) { + return map_node_bitmap_without( + (MapNode_Bitmap *)node, + shift, hash, key, + new_node, + mutid); + } + else if (IS_ARRAY_NODE(node)) { + return map_node_array_without( + (MapNode_Array *)node, + shift, hash, key, + new_node, + mutid); + } + else { + assert(IS_COLLISION_NODE(node)); + return map_node_collision_without( + (MapNode_Collision *)node, + shift, hash, key, + new_node, + mutid); + } +} + +static map_find_t +map_node_find(MapNode *node, + uint32_t shift, int32_t hash, + PyObject *key, PyObject **val) +{ + /* Find the key in the node starting with the given shift/hash. + + If a value is found, the result will be set to F_FOUND, and + *val will point to the found value object. + + If a value wasn't found, the result will be set to F_NOT_FOUND. + + If an exception occurs during the call, the result will be F_ERROR. + + This method automatically dispatches to the suitable + map_node_{nodetype}_find method. + */ + + if (IS_BITMAP_NODE(node)) { + return map_node_bitmap_find( + (MapNode_Bitmap *)node, + shift, hash, key, val); + + } + else if (IS_ARRAY_NODE(node)) { + return map_node_array_find( + (MapNode_Array *)node, + shift, hash, key, val); + } + else { + assert(IS_COLLISION_NODE(node)); + return map_node_collision_find( + (MapNode_Collision *)node, + shift, hash, key, val); + } +} + +static int +map_node_dump(MapNode *node, + _PyUnicodeWriter *writer, int level) +{ + /* Debug build: __dump__() method implementation for a node. + + This method automatically dispatches to the suitable + map_node_{nodetype})_dump method. + */ + + if (IS_BITMAP_NODE(node)) { + return map_node_bitmap_dump( + (MapNode_Bitmap *)node, writer, level); + } + else if (IS_ARRAY_NODE(node)) { + return map_node_array_dump( + (MapNode_Array *)node, writer, level); + } + else { + assert(IS_COLLISION_NODE(node)); + return map_node_collision_dump( + (MapNode_Collision *)node, writer, level); + } +} + + +/////////////////////////////////// Iterators: Machinery + + +static map_iter_t +map_iterator_next(MapIteratorState *iter, PyObject **key, PyObject **val); + + +static void +map_iterator_init(MapIteratorState *iter, MapNode *root) +{ + for (uint32_t i = 0; i < _Py_HAMT_MAX_TREE_DEPTH; i++) { + iter->i_nodes[i] = NULL; + iter->i_pos[i] = 0; + } + + iter->i_level = 0; + + /* Note: we don't incref/decref nodes in i_nodes. 
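+
+       Borrowing is safe as long as the iterator state does not
+       outlive the tree it walks; every MapIteratorState user in
+       this file either holds a strong reference to the owning Map
+       (the view iterators) or keeps it alive on the C stack for
+       the duration of the loop (map_eq, repr, hash, __reduce__,
+       and the update helpers).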
+    */
+    iter->i_nodes[0] = root;
+}
+
+static map_iter_t
+map_iterator_bitmap_next(MapIteratorState *iter,
+                         PyObject **key, PyObject **val)
+{
+    int8_t level = iter->i_level;
+
+    MapNode_Bitmap *node = (MapNode_Bitmap *)(iter->i_nodes[level]);
+    Py_ssize_t pos = iter->i_pos[level];
+
+    if (pos + 1 >= Py_SIZE(node)) {
+#ifndef NDEBUG
+        assert(iter->i_level >= 0);
+        iter->i_nodes[iter->i_level] = NULL;
+#endif
+        iter->i_level--;
+        return map_iterator_next(iter, key, val);
+    }
+
+    if (node->b_array[pos] == NULL) {
+        iter->i_pos[level] = pos + 2;
+
+        assert(level + 1 < _Py_HAMT_MAX_TREE_DEPTH);
+        int8_t next_level = (int8_t)(level + 1);
+        iter->i_level = next_level;
+        iter->i_pos[next_level] = 0;
+        iter->i_nodes[next_level] = (MapNode *)
+            node->b_array[pos + 1];
+
+        return map_iterator_next(iter, key, val);
+    }
+
+    *key = node->b_array[pos];
+    *val = node->b_array[pos + 1];
+    iter->i_pos[level] = pos + 2;
+    return I_ITEM;
+}
+
+static map_iter_t
+map_iterator_collision_next(MapIteratorState *iter,
+                            PyObject **key, PyObject **val)
+{
+    int8_t level = iter->i_level;
+
+    MapNode_Collision *node = (MapNode_Collision *)(iter->i_nodes[level]);
+    Py_ssize_t pos = iter->i_pos[level];
+
+    if (pos + 1 >= Py_SIZE(node)) {
+#ifndef NDEBUG
+        assert(iter->i_level >= 0);
+        iter->i_nodes[iter->i_level] = NULL;
+#endif
+        iter->i_level--;
+        return map_iterator_next(iter, key, val);
+    }
+
+    *key = node->c_array[pos];
+    *val = node->c_array[pos + 1];
+    iter->i_pos[level] = pos + 2;
+    return I_ITEM;
+}
+
+static map_iter_t
+map_iterator_array_next(MapIteratorState *iter,
+                        PyObject **key, PyObject **val)
+{
+    int8_t level = iter->i_level;
+
+    MapNode_Array *node = (MapNode_Array *)(iter->i_nodes[level]);
+    Py_ssize_t pos = iter->i_pos[level];
+
+    if (pos >= HAMT_ARRAY_NODE_SIZE) {
+#ifndef NDEBUG
+        assert(iter->i_level >= 0);
+        iter->i_nodes[iter->i_level] = NULL;
+#endif
+        iter->i_level--;
+        return map_iterator_next(iter, key, val);
+    }
+
+    for (Py_ssize_t i = pos; i < HAMT_ARRAY_NODE_SIZE; i++) {
+        if (node->a_array[i] != NULL) {
+            iter->i_pos[level] = i + 1;
+
+            assert((level + 1) < _Py_HAMT_MAX_TREE_DEPTH);
+            int8_t next_level = (int8_t)(level + 1);
+            iter->i_pos[next_level] = 0;
+            iter->i_nodes[next_level] = node->a_array[i];
+            iter->i_level = next_level;
+
+            return map_iterator_next(iter, key, val);
+        }
+    }
+
+#ifndef NDEBUG
+    assert(iter->i_level >= 0);
+    iter->i_nodes[iter->i_level] = NULL;
+#endif
+
+    iter->i_level--;
+    return map_iterator_next(iter, key, val);
+}
+
+static map_iter_t
+map_iterator_next(MapIteratorState *iter, PyObject **key, PyObject **val)
+{
+    if (iter->i_level < 0) {
+        return I_END;
+    }
+
+    assert(iter->i_level < _Py_HAMT_MAX_TREE_DEPTH);
+
+    MapNode *current = iter->i_nodes[iter->i_level];
+
+    if (IS_BITMAP_NODE(current)) {
+        return map_iterator_bitmap_next(iter, key, val);
+    }
+    else if (IS_ARRAY_NODE(current)) {
+        return map_iterator_array_next(iter, key, val);
+    }
+    else {
+        assert(IS_COLLISION_NODE(current));
+        return map_iterator_collision_next(iter, key, val);
+    }
+}
+
+
+/////////////////////////////////// HAMT high-level functions
+
+
+static MapObject *
+map_assoc(MapObject *o, PyObject *key, PyObject *val)
+{
+    int32_t key_hash;
+    int added_leaf = 0;
+    MapNode *new_root;
+    MapObject *new_o;
+
+    key_hash = map_hash(key);
+    if (key_hash == -1) {
+        return NULL;
+    }
+
+    new_root = map_node_assoc(
+        (MapNode *)(o->h_root),
+        0, key_hash, key, val, &added_leaf,
+        0);
+    if (new_root == NULL) {
+        return NULL;
+    }
+
+    if (new_root == o->h_root) {
+        Py_DECREF(new_root);
+        Py_INCREF(o);
return o; + } + + new_o = map_alloc(); + if (new_o == NULL) { + Py_DECREF(new_root); + return NULL; + } + + new_o->h_root = new_root; /* borrow */ + new_o->h_count = added_leaf ? o->h_count + 1 : o->h_count; + + return new_o; +} + +static MapObject * +map_without(MapObject *o, PyObject *key) +{ + int32_t key_hash = map_hash(key); + if (key_hash == -1) { + return NULL; + } + + MapNode *new_root = NULL; + + map_without_t res = map_node_without( + (MapNode *)(o->h_root), + 0, key_hash, key, + &new_root, + 0); + + switch (res) { + case W_ERROR: + return NULL; + case W_EMPTY: + return map_new(); + case W_NOT_FOUND: + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + case W_NEWNODE: { + assert(new_root != NULL); + + MapObject *new_o = map_alloc(); + if (new_o == NULL) { + Py_DECREF(new_root); + return NULL; + } + + new_o->h_root = new_root; /* borrow */ + new_o->h_count = o->h_count - 1; + assert(new_o->h_count >= 0); + return new_o; + } + default: + abort(); + } +} + +static map_find_t +map_find(BaseMapObject *o, PyObject *key, PyObject **val) +{ + if (o->b_count == 0) { + return F_NOT_FOUND; + } + + int32_t key_hash = map_hash(key); + if (key_hash == -1) { + return F_ERROR; + } + + return map_node_find(o->b_root, 0, key_hash, key, val); +} + +static int +map_eq(BaseMapObject *v, BaseMapObject *w) +{ + if (v == w) { + return 1; + } + + if (v->b_count != w->b_count) { + return 0; + } + + MapIteratorState iter; + map_iter_t iter_res; + map_find_t find_res; + PyObject *v_key; + PyObject *v_val; + PyObject *w_val; + + map_iterator_init(&iter, v->b_root); + + do { + iter_res = map_iterator_next(&iter, &v_key, &v_val); + if (iter_res == I_ITEM) { + find_res = map_find(w, v_key, &w_val); + switch (find_res) { + case F_ERROR: + return -1; + + case F_NOT_FOUND: + return 0; + + case F_FOUND: { + int cmp = PyObject_RichCompareBool(v_val, w_val, Py_EQ); + if (cmp < 0) { + return -1; + } + if (cmp == 0) { + return 0; + } + } + } + } + } while (iter_res != I_END); + + return 1; +} + +static Py_ssize_t +map_len(BaseMapObject *o) +{ + return o->b_count; +} + +static MapObject * +map_alloc(void) +{ + MapObject *o; + o = PyObject_GC_New(MapObject, &_Map_Type); + if (o == NULL) { + return NULL; + } + o->h_weakreflist = NULL; + o->h_hash = -1; + o->h_count = 0; + o->h_root = NULL; + PyObject_GC_Track(o); + return o; +} + +static MapObject * +map_new(void) +{ + MapObject *o = map_alloc(); + if (o == NULL) { + return NULL; + } + + o->h_root = map_node_bitmap_new(0, 0); + if (o->h_root == NULL) { + Py_DECREF(o); + return NULL; + } + + return o; +} + +static PyObject * +map_dump(MapObject *self) +{ + _PyUnicodeWriter writer; + + _PyUnicodeWriter_Init(&writer); + + if (_map_dump_format(&writer, "HAMT(len=%zd):\n", self->h_count)) { + goto error; + } + + if (map_node_dump(self->h_root, &writer, 0)) { + goto error; + } + + return _PyUnicodeWriter_Finish(&writer); + +error: + _PyUnicodeWriter_Dealloc(&writer); + return NULL; +} + + +/////////////////////////////////// Iterators: Shared Iterator Implementation + + +static int +map_baseiter_tp_clear(MapIterator *it) +{ + Py_CLEAR(it->mi_obj); + return 0; +} + +static void +map_baseiter_tp_dealloc(MapIterator *it) +{ + PyObject_GC_UnTrack(it); + (void)map_baseiter_tp_clear(it); + PyObject_GC_Del(it); +} + +static int +map_baseiter_tp_traverse(MapIterator *it, visitproc visit, void *arg) +{ + Py_VISIT(it->mi_obj); + return 0; +} + +static PyObject * +map_baseiter_tp_iternext(MapIterator *it) +{ + PyObject *key; + PyObject *val; + map_iter_t res = 
map_iterator_next(&it->mi_iter, &key, &val); + + switch (res) { + case I_END: + PyErr_SetNone(PyExc_StopIteration); + return NULL; + + case I_ITEM: { + return (*(it->mi_yield))(key, val); + } + + default: { + abort(); + } + } +} + +static int +map_baseview_tp_clear(MapView *view) +{ + Py_CLEAR(view->mv_obj); + Py_CLEAR(view->mv_itertype); + return 0; +} + +static void +map_baseview_tp_dealloc(MapView *view) +{ + PyObject_GC_UnTrack(view); + (void)map_baseview_tp_clear(view); + PyObject_GC_Del(view); +} + +static int +map_baseview_tp_traverse(MapView *view, visitproc visit, void *arg) +{ + Py_VISIT(view->mv_obj); + return 0; +} + +static Py_ssize_t +map_baseview_tp_len(MapView *view) +{ + return view->mv_obj->h_count; +} + +static PyMappingMethods MapView_as_mapping = { + (lenfunc)map_baseview_tp_len, +}; + +static PyObject * +map_baseview_newiter(PyTypeObject *type, binaryfunc yield, MapObject *map) +{ + MapIterator *iter = PyObject_GC_New(MapIterator, type); + if (iter == NULL) { + return NULL; + } + + Py_INCREF(map); + iter->mi_obj = map; + iter->mi_yield = yield; + map_iterator_init(&iter->mi_iter, map->h_root); + + PyObject_GC_Track(iter); + return (PyObject *)iter; +} + +static PyObject * +map_baseview_iter(MapView *view) +{ + return map_baseview_newiter( + view->mv_itertype, view->mv_yield, view->mv_obj); +} + +static PyObject * +map_baseview_new(PyTypeObject *type, binaryfunc yield, + MapObject *o, PyTypeObject *itertype) +{ + MapView *view = PyObject_GC_New(MapView, type); + if (view == NULL) { + return NULL; + } + + Py_INCREF(o); + view->mv_obj = o; + view->mv_yield = yield; + + Py_INCREF(itertype); + view->mv_itertype = itertype; + + PyObject_GC_Track(view); + return (PyObject *)view; +} + +#define ITERATOR_TYPE_SHARED_SLOTS \ + .tp_basicsize = sizeof(MapIterator), \ + .tp_itemsize = 0, \ + .tp_dealloc = (destructor)map_baseiter_tp_dealloc, \ + .tp_getattro = PyObject_GenericGetAttr, \ + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, \ + .tp_traverse = (traverseproc)map_baseiter_tp_traverse, \ + .tp_clear = (inquiry)map_baseiter_tp_clear, \ + .tp_iter = PyObject_SelfIter, \ + .tp_iternext = (iternextfunc)map_baseiter_tp_iternext, + + +#define VIEW_TYPE_SHARED_SLOTS \ + .tp_basicsize = sizeof(MapView), \ + .tp_itemsize = 0, \ + .tp_as_mapping = &MapView_as_mapping, \ + .tp_dealloc = (destructor)map_baseview_tp_dealloc, \ + .tp_getattro = PyObject_GenericGetAttr, \ + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, \ + .tp_traverse = (traverseproc)map_baseview_tp_traverse, \ + .tp_clear = (inquiry)map_baseview_tp_clear, \ + .tp_iter = (getiterfunc)map_baseview_iter, \ + + +/////////////////////////////////// _MapItems_Type + + +PyTypeObject _MapItems_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "items", + VIEW_TYPE_SHARED_SLOTS +}; + +PyTypeObject _MapItemsIter_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "items_iterator", + ITERATOR_TYPE_SHARED_SLOTS +}; + +static PyObject * +map_iter_yield_items(PyObject *key, PyObject *val) +{ + return PyTuple_Pack(2, key, val); +} + +static PyObject * +map_new_items_view(MapObject *o) +{ + return map_baseview_new( + &_MapItems_Type, map_iter_yield_items, o, + &_MapItemsIter_Type); +} + + +/////////////////////////////////// _MapKeys_Type + + +PyTypeObject _MapKeys_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "keys", + VIEW_TYPE_SHARED_SLOTS +}; + +PyTypeObject _MapKeysIter_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "keys_iterator", + ITERATOR_TYPE_SHARED_SLOTS +}; + +static PyObject * +map_iter_yield_keys(PyObject *key, PyObject *val) 
+{ + Py_INCREF(key); + return key; +} + +static PyObject * +map_new_keys_iter(MapObject *o) +{ + return map_baseview_newiter( + &_MapKeysIter_Type, map_iter_yield_keys, o); +} + +static PyObject * +map_new_keys_view(MapObject *o) +{ + return map_baseview_new( + &_MapKeys_Type, map_iter_yield_keys, o, + &_MapKeysIter_Type); +} + +/////////////////////////////////// _MapValues_Type + + +PyTypeObject _MapValues_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "values", + VIEW_TYPE_SHARED_SLOTS +}; + +PyTypeObject _MapValuesIter_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "values_iterator", + ITERATOR_TYPE_SHARED_SLOTS +}; + +static PyObject * +map_iter_yield_values(PyObject *key, PyObject *val) +{ + Py_INCREF(val); + return val; +} + +static PyObject * +map_new_values_view(MapObject *o) +{ + return map_baseview_new( + &_MapValues_Type, map_iter_yield_values, o, + &_MapValuesIter_Type); +} + + +/////////////////////////////////// _Map_Type + + +static PyObject * +map_dump(MapObject *self); + + +static PyObject * +map_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + return (PyObject*)map_new(); +} + + +static int +map_tp_init(MapObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *arg = NULL; + uint64_t mutid = 0; + + if (!PyArg_UnpackTuple(args, "immutables.Map", 0, 1, &arg)) { + return -1; + } + + if (arg != NULL) { + if (Map_Check(arg)) { + MapObject *other = (MapObject *)arg; + + Py_INCREF(other->h_root); + Py_SETREF(self->h_root, other->h_root); + + self->h_count = other->h_count; + self->h_hash = other->h_hash; + } + else if (MapMutation_Check(arg)) { + PyErr_Format( + PyExc_TypeError, + "cannot create Maps from MapMutations"); + return -1; + } + else { + mutid = mutid_counter++; + if (map_update_inplace(mutid, (BaseMapObject *)self, arg)) { + return -1; + } + } + } + + if (kwds != NULL) { + if (!PyArg_ValidateKeywordArguments(kwds)) { + return -1; + } + + if (!mutid) { + mutid = mutid_counter++; + } + + if (map_update_inplace(mutid, (BaseMapObject *)self, kwds)) { + return -1; + } + } + + return 0; +} + + +static int +map_tp_clear(BaseMapObject *self) +{ + Py_CLEAR(self->b_root); + return 0; +} + + +static int +map_tp_traverse(BaseMapObject *self, visitproc visit, void *arg) +{ + Py_VISIT(self->b_root); + return 0; +} + +static void +map_tp_dealloc(BaseMapObject *self) +{ + PyObject_GC_UnTrack(self); + if (self->b_weakreflist != NULL) { + PyObject_ClearWeakRefs((PyObject*)self); + } + (void)map_tp_clear(self); + Py_TYPE(self)->tp_free(self); +} + + +static PyObject * +map_tp_richcompare(PyObject *v, PyObject *w, int op) +{ + if (!Map_Check(v) || !Map_Check(w) || (op != Py_EQ && op != Py_NE)) { + Py_RETURN_NOTIMPLEMENTED; + } + + int res = map_eq((BaseMapObject *)v, (BaseMapObject *)w); + if (res < 0) { + return NULL; + } + + if (op == Py_NE) { + res = !res; + } + + if (res) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + +static int +map_tp_contains(BaseMapObject *self, PyObject *key) +{ + PyObject *val; + map_find_t res = map_find(self, key, &val); + switch (res) { + case F_ERROR: + return -1; + case F_NOT_FOUND: + return 0; + case F_FOUND: + return 1; + default: + abort(); + } +} + +static PyObject * +map_tp_subscript(BaseMapObject *self, PyObject *key) +{ + PyObject *val; + map_find_t res = map_find(self, key, &val); + switch (res) { + case F_ERROR: + return NULL; + case F_FOUND: + Py_INCREF(val); + return val; + case F_NOT_FOUND: + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + default: + abort(); + } +} + +static Py_ssize_t 
+map_tp_len(BaseMapObject *self) +{ + return map_len(self); +} + +static PyObject * +map_tp_iter(MapObject *self) +{ + return map_new_keys_iter(self); +} + +static PyObject * +map_py_set(MapObject *self, PyObject *args) +{ + PyObject *key; + PyObject *val; + + if (!PyArg_UnpackTuple(args, "set", 2, 2, &key, &val)) { + return NULL; + } + + return (PyObject *)map_assoc(self, key, val); +} + +static PyObject * +map_py_get(BaseMapObject *self, PyObject *args) +{ + PyObject *key; + PyObject *def = NULL; + + if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &def)) { + return NULL; + } + + PyObject *val = NULL; + map_find_t res = map_find(self, key, &val); + switch (res) { + case F_ERROR: + return NULL; + case F_FOUND: + Py_INCREF(val); + return val; + case F_NOT_FOUND: + if (def == NULL) { + Py_RETURN_NONE; + } + Py_INCREF(def); + return def; + default: + abort(); + } +} + +static PyObject * +map_py_delete(MapObject *self, PyObject *key) +{ + return (PyObject *)map_without(self, key); +} + +static PyObject * +map_py_mutate(MapObject *self, PyObject *args) +{ + + MapMutationObject *o; + o = PyObject_GC_New(MapMutationObject, &_MapMutation_Type); + if (o == NULL) { + return NULL; + } + o->m_weakreflist = NULL; + o->m_count = self->h_count; + + Py_INCREF(self->h_root); + o->m_root = self->h_root; + + o->m_mutid = mutid_counter++; + + PyObject_GC_Track(o); + return (PyObject *)o; +} + +static PyObject * +map_py_update(MapObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *arg = NULL; + MapObject *new = NULL; + uint64_t mutid = 0; + + if (!PyArg_UnpackTuple(args, "update", 0, 1, &arg)) { + return NULL; + } + + if (arg != NULL) { + mutid = mutid_counter++; + new = map_update(mutid, self, arg); + if (new == NULL) { + return NULL; + } + } + else { + Py_INCREF(self); + new = self; + } + + if (kwds != NULL) { + if (!PyArg_ValidateKeywordArguments(kwds)) { + Py_DECREF(new); + return NULL; + } + + if (!mutid) { + mutid = mutid_counter++; + } + + MapObject *new2 = map_update(mutid, new, kwds); + Py_DECREF(new); + if (new2 == NULL) { + return NULL; + } + new = new2; + } + + return (PyObject *)new; +} + +static PyObject * +map_py_items(MapObject *self, PyObject *args) +{ + return map_new_items_view(self); +} + +static PyObject * +map_py_values(MapObject *self, PyObject *args) +{ + return map_new_values_view(self); +} + +static PyObject * +map_py_keys(MapObject *self, PyObject *args) +{ + return map_new_keys_view(self); +} + +static PyObject * +map_py_dump(MapObject *self, PyObject *args) +{ + return map_dump(self); +} + + +static PyObject * +map_py_repr(BaseMapObject *m) +{ + Py_ssize_t i; + _PyUnicodeWriter writer; + + + i = Py_ReprEnter((PyObject *)m); + if (i != 0) { + return i > 0 ? 
PyUnicode_FromString("{...}") : NULL; + } + + _PyUnicodeWriter_Init(&writer); + + if (MapMutation_Check(m)) { + if (_PyUnicodeWriter_WriteASCIIString( + &writer, "b_root); + int second = 0; + do { + PyObject *v_key; + PyObject *v_val; + + iter_res = map_iterator_next(&iter, &v_key, &v_val); + if (iter_res == I_ITEM) { + if (second) { + if (_PyUnicodeWriter_WriteASCIIString(&writer, ", ", 2) < 0) { + goto error; + } + } + + PyObject *s = PyObject_Repr(v_key); + if (s == NULL) { + goto error; + } + if (_PyUnicodeWriter_WriteStr(&writer, s) < 0) { + Py_DECREF(s); + goto error; + } + Py_DECREF(s); + + if (_PyUnicodeWriter_WriteASCIIString(&writer, ": ", 2) < 0) { + goto error; + } + + s = PyObject_Repr(v_val); + if (s == NULL) { + goto error; + } + if (_PyUnicodeWriter_WriteStr(&writer, s) < 0) { + Py_DECREF(s); + goto error; + } + Py_DECREF(s); + } + + second = 1; + } while (iter_res != I_END); + + if (_PyUnicodeWriter_WriteASCIIString(&writer, "})", 2) < 0) { + goto error; + } + + PyObject *addr = PyUnicode_FromFormat(" at %p>", m); + if (addr == NULL) { + goto error; + } + if (_PyUnicodeWriter_WriteStr(&writer, addr) < 0) { + Py_DECREF(addr); + goto error; + } + Py_DECREF(addr); + + Py_ReprLeave((PyObject *)m); + return _PyUnicodeWriter_Finish(&writer); + +error: + _PyUnicodeWriter_Dealloc(&writer); + Py_ReprLeave((PyObject *)m); + return NULL; +} + + +static Py_uhash_t +_shuffle_bits(Py_uhash_t h) +{ + return ((h ^ 89869747UL) ^ (h << 16)) * 3644798167UL; +} + + +static Py_hash_t +map_py_hash(MapObject *self) +{ + /* Adapted version of frozenset.__hash__: it's important + that Map.__hash__ is independant of key/values order. + + Optimization idea: compute and memoize intermediate + hash values for HAMT nodes. + */ + + if (self->h_hash != -1) { + return self->h_hash; + } + + Py_uhash_t hash = 0; + + MapIteratorState iter; + map_iter_t iter_res; + map_iterator_init(&iter, self->h_root); + do { + PyObject *v_key; + PyObject *v_val; + + iter_res = map_iterator_next(&iter, &v_key, &v_val); + if (iter_res == I_ITEM) { + Py_hash_t vh = PyObject_Hash(v_key); + if (vh == -1) { + return -1; + } + hash ^= _shuffle_bits((Py_uhash_t)vh); + + vh = PyObject_Hash(v_val); + if (vh == -1) { + return -1; + } + hash ^= _shuffle_bits((Py_uhash_t)vh); + } + } while (iter_res != I_END); + + hash ^= ((Py_uhash_t)self->h_count * 2 + 1) * 1927868237UL; + + hash ^= (hash >> 11) ^ (hash >> 25); + hash = hash * 69069U + 907133923UL; + + self->h_hash = (Py_hash_t)hash; + if (self->h_hash == -1) { + self->h_hash = 1; + } + return self->h_hash; +} + +static PyObject * +map_reduce(MapObject *self) +{ + MapIteratorState iter; + map_iter_t iter_res; + + PyObject *dict = PyDict_New(); + if (dict == NULL) { + return NULL; + } + + map_iterator_init(&iter, self->h_root); + do { + PyObject *key; + PyObject *val; + + iter_res = map_iterator_next(&iter, &key, &val); + if (iter_res == I_ITEM) { + if (PyDict_SetItem(dict, key, val) < 0) { + Py_DECREF(dict); + return NULL; + } + } + } while (iter_res != I_END); + + PyObject *args = PyTuple_Pack(1, dict); + Py_DECREF(dict); + if (args == NULL) { + return NULL; + } + + PyObject *tup = PyTuple_Pack(2, Py_TYPE(self), args); + Py_DECREF(args); + return tup; +} + + +static PyMethodDef Map_methods[] = { + {"set", (PyCFunction)map_py_set, METH_VARARGS, NULL}, + {"get", (PyCFunction)map_py_get, METH_VARARGS, NULL}, + {"delete", (PyCFunction)map_py_delete, METH_O, NULL}, + {"mutate", (PyCFunction)map_py_mutate, METH_NOARGS, NULL}, + {"items", (PyCFunction)map_py_items, METH_NOARGS, NULL}, + 
{"keys", (PyCFunction)map_py_keys, METH_NOARGS, NULL}, + {"values", (PyCFunction)map_py_values, METH_NOARGS, NULL}, + {"update", (PyCFunction)map_py_update, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__reduce__", (PyCFunction)map_reduce, METH_NOARGS, NULL}, + {"__dump__", (PyCFunction)map_py_dump, METH_NOARGS, NULL}, + {NULL, NULL} +}; + +static PySequenceMethods Map_as_sequence = { + 0, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + 0, /* sq_item */ + 0, /* sq_slice */ + 0, /* sq_ass_item */ + 0, /* sq_ass_slice */ + (objobjproc)map_tp_contains, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ +}; + +static PyMappingMethods Map_as_mapping = { + (lenfunc)map_tp_len, /* mp_length */ + (binaryfunc)map_tp_subscript, /* mp_subscript */ +}; + +PyTypeObject _Map_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "immutables._map.Map", + sizeof(MapObject), + .tp_methods = Map_methods, + .tp_as_mapping = &Map_as_mapping, + .tp_as_sequence = &Map_as_sequence, + .tp_iter = (getiterfunc)map_tp_iter, + .tp_dealloc = (destructor)map_tp_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_richcompare = map_tp_richcompare, + .tp_traverse = (traverseproc)map_tp_traverse, + .tp_clear = (inquiry)map_tp_clear, + .tp_new = map_tp_new, + .tp_init = (initproc)map_tp_init, + .tp_weaklistoffset = offsetof(MapObject, h_weakreflist), + .tp_hash = (hashfunc)map_py_hash, + .tp_repr = (reprfunc)map_py_repr, +}; + + +/////////////////////////////////// MapMutation + + +static int +map_node_update_from_map(uint64_t mutid, + MapObject *map, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count) +{ + assert(Map_Check(map)); + + MapIteratorState iter; + map_iter_t iter_res; + + MapNode *last_root; + Py_ssize_t last_count; + + Py_INCREF(root); + last_root = root; + last_count = count; + + map_iterator_init(&iter, map->h_root); + do { + PyObject *key; + PyObject *val; + int32_t key_hash; + int added_leaf; + + iter_res = map_iterator_next(&iter, &key, &val); + if (iter_res == I_ITEM) { + key_hash = map_hash(key); + if (key_hash == -1) { + goto err; + } + + MapNode *iter_root = map_node_assoc( + last_root, + 0, key_hash, key, val, &added_leaf, + mutid); + + if (iter_root == NULL) { + goto err; + } + + if (added_leaf) { + last_count++; + } + + Py_SETREF(last_root, iter_root); + } + } while (iter_res != I_END); + + *new_root = last_root; + *new_count = last_count; + + return 0; + +err: + Py_DECREF(last_root); + return -1; +} + + +static int +map_node_update_from_dict(uint64_t mutid, + PyObject *dct, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count) +{ + assert(PyDict_Check(dct)); + + PyObject *it = PyObject_GetIter(dct); + if (it == NULL) { + return -1; + } + + MapNode *last_root; + Py_ssize_t last_count; + + Py_INCREF(root); + last_root = root; + last_count = count; + + PyObject *key; + + while ((key = PyIter_Next(it))) { + PyObject *val; + int added_leaf; + int32_t key_hash; + + key_hash = map_hash(key); + if (key_hash == -1) { + Py_DECREF(key); + goto err; + } + + val = PyDict_GetItemWithError(dct, key); + if (val == NULL) { + Py_DECREF(key); + goto err; + } + + MapNode *iter_root = map_node_assoc( + last_root, + 0, key_hash, key, val, &added_leaf, + mutid); + + Py_DECREF(key); + + if (iter_root == NULL) { + goto err; + } + + if (added_leaf) { + last_count++; + } + + Py_SETREF(last_root, iter_root); + } + + if (key == NULL && PyErr_Occurred()) { + goto err; + } + + Py_DECREF(it); + + 
*new_root = last_root; + *new_count = last_count; + + return 0; + +err: + Py_DECREF(it); + Py_DECREF(last_root); + return -1; +} + + +static int +map_node_update_from_seq(uint64_t mutid, + PyObject *seq, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count) +{ + PyObject *it; + Py_ssize_t i; + PyObject *item = NULL; + PyObject *fast = NULL; + + MapNode *last_root; + Py_ssize_t last_count; + + it = PyObject_GetIter(seq); + if (it == NULL) { + return -1; + } + + Py_INCREF(root); + last_root = root; + last_count = count; + + for (i = 0; ; i++) { + PyObject *key, *val; + Py_ssize_t n; + int32_t key_hash; + int added_leaf; + + item = PyIter_Next(it); + if (item == NULL) { + if (PyErr_Occurred()) { + goto err; + } + break; + } + + fast = PySequence_Fast(item, ""); + if (fast == NULL) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) + PyErr_Format(PyExc_TypeError, + "cannot convert map update " + "sequence element #%zd to a sequence", + i); + goto err; + } + + n = PySequence_Fast_GET_SIZE(fast); + if (n != 2) { + PyErr_Format(PyExc_ValueError, + "map update sequence element #%zd " + "has length %zd; 2 is required", + i, n); + goto err; + } + + key = PySequence_Fast_GET_ITEM(fast, 0); + val = PySequence_Fast_GET_ITEM(fast, 1); + Py_INCREF(key); + Py_INCREF(val); + + key_hash = map_hash(key); + if (key_hash == -1) { + Py_DECREF(key); + Py_DECREF(val); + goto err; + } + + MapNode *iter_root = map_node_assoc( + last_root, + 0, key_hash, key, val, &added_leaf, + mutid); + + Py_DECREF(key); + Py_DECREF(val); + + if (iter_root == NULL) { + goto err; + } + + if (added_leaf) { + last_count++; + } + + Py_SETREF(last_root, iter_root); + + Py_DECREF(fast); + Py_DECREF(item); + } + + Py_DECREF(it); + + *new_root = last_root; + *new_count = last_count; + + return 0; + +err: + Py_DECREF(last_root); + Py_XDECREF(item); + Py_XDECREF(fast); + Py_DECREF(it); + return -1; +} + + +static int +map_node_update(uint64_t mutid, + PyObject *src, + MapNode *root, Py_ssize_t count, + MapNode **new_root, Py_ssize_t *new_count) +{ + if (Map_Check(src)) { + return map_node_update_from_map( + mutid, (MapObject *)src, root, count, new_root, new_count); + } + else if (PyDict_Check(src)) { + return map_node_update_from_dict( + mutid, src, root, count, new_root, new_count); + } + else { + return map_node_update_from_seq( + mutid, src, root, count, new_root, new_count); + } +} + + +static int +map_update_inplace(uint64_t mutid, BaseMapObject *o, PyObject *src) +{ + MapNode *new_root = NULL; + Py_ssize_t new_count; + + int ret = map_node_update( + mutid, src, + o->b_root, o->b_count, + &new_root, &new_count); + + if (ret) { + return -1; + } + + assert(new_root); + + Py_SETREF(o->b_root, new_root); + o->b_count = new_count; + + return 0; +} + + +static MapObject * +map_update(uint64_t mutid, MapObject *o, PyObject *src) +{ + MapNode *new_root = NULL; + Py_ssize_t new_count; + + int ret = map_node_update( + mutid, src, + o->h_root, o->h_count, + &new_root, &new_count); + + if (ret) { + return NULL; + } + + assert(new_root); + + MapObject *new = map_alloc(); + if (new == NULL) { + Py_DECREF(new_root); + return NULL; + } + + Py_XSETREF(new->h_root, new_root); + new->h_count = new_count; + + return new; +} + +static int +mapmut_check_finalized(MapMutationObject *o) +{ + if (o->m_mutid == 0) { + PyErr_Format( + PyExc_ValueError, + "mutation %R has been finished", + o, NULL); + return -1; + } + + return 0; +} + +static int +mapmut_delete(MapMutationObject *o, PyObject *key, int32_t key_hash) +{ + MapNode *new_root 
= NULL; + + assert(key_hash != -1); + map_without_t res = map_node_without( + (MapNode *)(o->m_root), + 0, key_hash, key, + &new_root, + o->m_mutid); + + switch (res) { + case W_ERROR: + return -1; + + case W_EMPTY: + new_root = map_node_bitmap_new(0, o->m_mutid); + if (new_root == NULL) { + return -1; + } + Py_SETREF(o->m_root, new_root); + o->m_count = 0; + return 0; + + case W_NOT_FOUND: + PyErr_SetObject(PyExc_KeyError, key); + return -1; + + case W_NEWNODE: { + assert(new_root != NULL); + Py_SETREF(o->m_root, new_root); + o->m_count--; + return 0; + } + + default: + abort(); + } +} + +static int +mapmut_set(MapMutationObject *o, PyObject *key, int32_t key_hash, + PyObject *val) +{ + int added_leaf = 0; + + assert(key_hash != -1); + MapNode *new_root = map_node_assoc( + (MapNode *)(o->m_root), + 0, key_hash, key, val, &added_leaf, + o->m_mutid); + if (new_root == NULL) { + return -1; + } + + if (added_leaf) { + o->m_count++; + } + + if (new_root == o->m_root) { + Py_DECREF(new_root); + return 0; + } + + Py_SETREF(o->m_root, new_root); + return 0; +} + +static int +mapmut_finish(MapMutationObject *o) +{ + o->m_mutid = 0; + return 0; +} + +static PyObject * +mapmut_py_set(MapMutationObject *o, PyObject *args) +{ + PyObject *key; + PyObject *val; + + if (!PyArg_UnpackTuple(args, "set", 2, 2, &key, &val)) { + return NULL; + } + + if (mapmut_check_finalized(o)) { + return NULL; + } + + int32_t key_hash = map_hash(key); + if (key_hash == -1) { + return NULL; + } + + if (mapmut_set(o, key, key_hash, val)) { + return NULL; + } + + Py_RETURN_NONE; +} + +static PyObject * +mapmut_tp_richcompare(PyObject *v, PyObject *w, int op) +{ + if (!MapMutation_Check(v) || !MapMutation_Check(w) || + (op != Py_EQ && op != Py_NE)) + { + Py_RETURN_NOTIMPLEMENTED; + } + + int res = map_eq((BaseMapObject *)v, (BaseMapObject *)w); + if (res < 0) { + return NULL; + } + + if (op == Py_NE) { + res = !res; + } + + if (res) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + +static PyObject * +mapmut_py_update(MapMutationObject *self, PyObject *args, PyObject *kwds) +{ + PyObject *arg = NULL; + + if (!PyArg_UnpackTuple(args, "update", 0, 1, &arg)) { + return NULL; + } + + if (mapmut_check_finalized(self)) { + return NULL; + } + + if (arg != NULL) { + if (map_update_inplace(self->m_mutid, (BaseMapObject *)self, arg)) { + return NULL; + } + } + + if (kwds != NULL) { + if (!PyArg_ValidateKeywordArguments(kwds)) { + return NULL; + } + + if (map_update_inplace(self->m_mutid, (BaseMapObject *)self, kwds)) { + return NULL; + } + } + + Py_RETURN_NONE; +} + + +static PyObject * +mapmut_py_finish(MapMutationObject *self, PyObject *args) +{ + if (mapmut_finish(self)) { + return NULL; + } + + MapObject *o = map_alloc(); + if (o == NULL) { + return NULL; + } + + Py_INCREF(self->m_root); + o->h_root = self->m_root; + o->h_count = self->m_count; + + return (PyObject *)o; +} + +static PyObject * +mapmut_py_enter(MapMutationObject *self, PyObject *args) +{ + Py_INCREF(self); + return (PyObject *)self; +} + +static PyObject * +mapmut_py_exit(MapMutationObject *self, PyObject *args) +{ + if (mapmut_finish(self)) { + return NULL; + } + Py_RETURN_FALSE; +} + +static int +mapmut_tp_ass_sub(MapMutationObject *self, PyObject *key, PyObject *val) +{ + if (mapmut_check_finalized(self)) { + return -1; + } + + int32_t key_hash = map_hash(key); + if (key_hash == -1) { + return -1; + } + + if (val == NULL) { + return mapmut_delete(self, key, key_hash); + } + else { + return mapmut_set(self, key, key_hash, val); + } +} + +static PyObject 
* +mapmut_py_pop(MapMutationObject *self, PyObject *args) +{ + PyObject *key, *deflt = NULL, *val = NULL; + + if(!PyArg_UnpackTuple(args, "pop", 1, 2, &key, &deflt)) { + return NULL; + } + + if (mapmut_check_finalized(self)) { + return NULL; + } + + if (!self->m_count) { + goto not_found; + } + + int32_t key_hash = map_hash(key); + if (key_hash == -1) { + return NULL; + } + + map_find_t find_res = map_node_find(self->m_root, 0, key_hash, key, &val); + + switch (find_res) { + case F_ERROR: + return NULL; + + case F_NOT_FOUND: + goto not_found; + + case F_FOUND: + break; + + default: + abort(); + } + + Py_INCREF(val); + + if (mapmut_delete(self, key, key_hash)) { + Py_DECREF(val); + return NULL; + } + + return val; + +not_found: + if (deflt) { + Py_INCREF(deflt); + return deflt; + } + + PyErr_SetObject(PyExc_KeyError, key); + return NULL; +} + + +static PyMethodDef MapMutation_methods[] = { + {"set", (PyCFunction)mapmut_py_set, METH_VARARGS, NULL}, + {"get", (PyCFunction)map_py_get, METH_VARARGS, NULL}, + {"pop", (PyCFunction)mapmut_py_pop, METH_VARARGS, NULL}, + {"finish", (PyCFunction)mapmut_py_finish, METH_NOARGS, NULL}, + {"update", (PyCFunction)mapmut_py_update, + METH_VARARGS | METH_KEYWORDS, NULL}, + {"__enter__", (PyCFunction)mapmut_py_enter, METH_NOARGS, NULL}, + {"__exit__", (PyCFunction)mapmut_py_exit, METH_VARARGS, NULL}, + {NULL, NULL} +}; + +static PySequenceMethods MapMutation_as_sequence = { + 0, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + 0, /* sq_item */ + 0, /* sq_slice */ + 0, /* sq_ass_item */ + 0, /* sq_ass_slice */ + (objobjproc)map_tp_contains, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ +}; + +static PyMappingMethods MapMutation_as_mapping = { + (lenfunc)map_tp_len, /* mp_length */ + (binaryfunc)map_tp_subscript, /* mp_subscript */ + (objobjargproc)mapmut_tp_ass_sub, /* mp_subscript */ +}; + +PyTypeObject _MapMutation_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "immutables._map.MapMutation", + sizeof(MapMutationObject), + .tp_methods = MapMutation_methods, + .tp_as_mapping = &MapMutation_as_mapping, + .tp_as_sequence = &MapMutation_as_sequence, + .tp_dealloc = (destructor)map_tp_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)map_tp_traverse, + .tp_richcompare = mapmut_tp_richcompare, + .tp_clear = (inquiry)map_tp_clear, + .tp_weaklistoffset = offsetof(MapMutationObject, m_weakreflist), + .tp_repr = (reprfunc)map_py_repr, + .tp_hash = PyObject_HashNotImplemented, +}; + + +/////////////////////////////////// Tree Node Types + + +PyTypeObject _Map_ArrayNode_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "map_array_node", + sizeof(MapNode_Array), + 0, + .tp_dealloc = (destructor)map_node_array_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)map_node_array_traverse, + .tp_free = PyObject_GC_Del, + .tp_hash = PyObject_HashNotImplemented, +}; + +PyTypeObject _Map_BitmapNode_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "map_bitmap_node", + sizeof(MapNode_Bitmap) - sizeof(PyObject *), + sizeof(PyObject *), + .tp_dealloc = (destructor)map_node_bitmap_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)map_node_bitmap_traverse, + .tp_free = PyObject_GC_Del, + .tp_hash = PyObject_HashNotImplemented, +}; + +PyTypeObject _Map_CollisionNode_Type = { + PyVarObject_HEAD_INIT(NULL, 
0) + "map_collision_node", + sizeof(MapNode_Collision) - sizeof(PyObject *), + sizeof(PyObject *), + .tp_dealloc = (destructor)map_node_collision_dealloc, + .tp_getattro = PyObject_GenericGetAttr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = (traverseproc)map_node_collision_traverse, + .tp_free = PyObject_GC_Del, + .tp_hash = PyObject_HashNotImplemented, +}; + + +static void +module_free(void *m) +{ + Py_CLEAR(_empty_bitmap_node); +} + + +static struct PyModuleDef _mapmodule = { + PyModuleDef_HEAD_INIT, /* m_base */ + "_map", /* m_name */ + NULL, /* m_doc */ + -1, /* m_size */ + NULL, /* m_methods */ + NULL, /* m_slots */ + NULL, /* m_traverse */ + NULL, /* m_clear */ + module_free, /* m_free */ +}; + + +PyMODINIT_FUNC +PyInit__map(void) +{ + PyObject *m = PyModule_Create(&_mapmodule); + + if ((PyType_Ready(&_Map_Type) < 0) || + (PyType_Ready(&_MapMutation_Type) < 0) || + (PyType_Ready(&_Map_ArrayNode_Type) < 0) || + (PyType_Ready(&_Map_BitmapNode_Type) < 0) || + (PyType_Ready(&_Map_CollisionNode_Type) < 0) || + (PyType_Ready(&_MapKeys_Type) < 0) || + (PyType_Ready(&_MapValues_Type) < 0) || + (PyType_Ready(&_MapItems_Type) < 0) || + (PyType_Ready(&_MapKeysIter_Type) < 0) || + (PyType_Ready(&_MapValuesIter_Type) < 0) || + (PyType_Ready(&_MapItemsIter_Type) < 0)) + { + return 0; + } + + Py_INCREF(&_Map_Type); + if (PyModule_AddObject(m, "Map", (PyObject *)&_Map_Type) < 0) { + Py_DECREF(&_Map_Type); + return NULL; + } + + return m; +} diff --git a/immutables/immutables/_map.h b/immutables/immutables/_map.h new file mode 100644 index 0000000000..dd12af9a0c --- /dev/null +++ b/immutables/immutables/_map.h @@ -0,0 +1,107 @@ +#ifndef IMMUTABLES_MAP_H +#define IMMUTABLES_MAP_H + +#include +#include "Python.h" + +#define _Py_HAMT_MAX_TREE_DEPTH 7 + + +#define Map_Check(o) (Py_TYPE(o) == &_Map_Type) +#define MapMutation_Check(o) (Py_TYPE(o) == &_MapMutation_Type) + + +/* Abstract tree node. */ +typedef struct { + PyObject_HEAD +} MapNode; + + +#define _MapCommonFields(pref) \ + PyObject_HEAD \ + MapNode *pref##_root; \ + PyObject *pref##_weakreflist; \ + Py_ssize_t pref##_count; + + +/* Base mapping struct; used in methods shared between + MapObject and MapMutationObject types. */ +typedef struct { + _MapCommonFields(b) +} BaseMapObject; + + +/* An HAMT immutable mapping collection. */ +typedef struct { + _MapCommonFields(h) + Py_hash_t h_hash; +} MapObject; + + +/* MapMutation object (returned from `map.mutate()`.) */ +typedef struct { + _MapCommonFields(m) + uint64_t m_mutid; +} MapMutationObject; + + +/* A struct to hold the state of depth-first traverse of the tree. + + HAMT is an immutable collection. Iterators will hold a strong reference + to it, and every node in the HAMT has strong references to its children. + + So for iterators, we can implement zero allocations and zero reference + inc/dec depth-first iteration. + + - i_nodes: an array of seven pointers to tree nodes + - i_level: the current node in i_nodes + - i_pos: an array of positions within nodes in i_nodes. +*/ +typedef struct { + MapNode *i_nodes[_Py_HAMT_MAX_TREE_DEPTH]; + Py_ssize_t i_pos[_Py_HAMT_MAX_TREE_DEPTH]; + int8_t i_level; +} MapIteratorState; + + +/* Base iterator object. + + Contains the iteration state, a pointer to the HAMT tree, + and a pointer to the 'yield function'. The latter is a simple + function that returns a key/value tuple for the 'Items' iterator, + just a key for the 'Keys' iterator, and a value for the 'Values' + iterator. 
+*/ + +typedef struct { + PyObject_HEAD + MapObject *mv_obj; + binaryfunc mv_yield; + PyTypeObject *mv_itertype; +} MapView; + +typedef struct { + PyObject_HEAD + MapObject *mi_obj; + binaryfunc mi_yield; + MapIteratorState mi_iter; +} MapIterator; + + +/* PyTypes */ + + +PyTypeObject _Map_Type; +PyTypeObject _MapMutation_Type; +PyTypeObject _Map_ArrayNode_Type; +PyTypeObject _Map_BitmapNode_Type; +PyTypeObject _Map_CollisionNode_Type; +PyTypeObject _MapKeys_Type; +PyTypeObject _MapValues_Type; +PyTypeObject _MapItems_Type; +PyTypeObject _MapKeysIter_Type; +PyTypeObject _MapValuesIter_Type; +PyTypeObject _MapItemsIter_Type; + + +#endif diff --git a/immutables/immutables/map.py b/immutables/immutables/map.py new file mode 100644 index 0000000000..cad251e6a7 --- /dev/null +++ b/immutables/immutables/map.py @@ -0,0 +1,813 @@ +import collections.abc +import itertools +import reprlib +import sys + + +__all__ = ('Map',) + + +# Thread-safe counter. +_mut_id = itertools.count(1).__next__ + + +# Python version of _map.c. The topmost comment there explains +# all datastructures and algorithms. +# The code here follows C code closely on purpose to make +# debugging and testing easier. + + +def map_hash(o): + x = hash(o) + return (x & 0xffffffff) ^ ((x >> 32) & 0xffffffff) + + +def map_mask(hash, shift): + return (hash >> shift) & 0x01f + + +def map_bitpos(hash, shift): + return 1 << map_mask(hash, shift) + + +def map_bitcount(v): + v = v - ((v >> 1) & 0x55555555) + v = (v & 0x33333333) + ((v >> 2) & 0x33333333) + v = (v & 0x0F0F0F0F) + ((v >> 4) & 0x0F0F0F0F) + v = v + (v >> 8) + v = (v + (v >> 16)) & 0x3F + return v + + +def map_bitindex(bitmap, bit): + return map_bitcount(bitmap & (bit - 1)) + + +W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3) +void = object() + + +class BitmapNode: + + def __init__(self, size, bitmap, array, mutid): + self.size = size + self.bitmap = bitmap + assert isinstance(array, list) and len(array) == size + self.array = array + self.mutid = mutid + + def clone(self, mutid): + return BitmapNode(self.size, self.bitmap, self.array.copy(), mutid) + + def assoc(self, shift, hash, key, val, mutid): + bit = map_bitpos(hash, shift) + idx = map_bitindex(self.bitmap, bit) + + if self.bitmap & bit: + key_idx = 2 * idx + val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is None: + sub_node, added = val_or_node.assoc( + shift + 5, hash, key, val, mutid) + if val_or_node is sub_node: + return self, added + + if mutid and mutid == self.mutid: + self.array[val_idx] = sub_node + return self, added + else: + ret = self.clone(mutid) + ret.array[val_idx] = sub_node + return ret, added + + if key == key_or_null: + if val is val_or_node: + return self, False + + if mutid and mutid == self.mutid: + self.array[val_idx] = val + return self, False + else: + ret = self.clone(mutid) + ret.array[val_idx] = val + return ret, False + + existing_key_hash = map_hash(key_or_null) + if existing_key_hash == hash: + sub_node = CollisionNode( + 4, hash, [key_or_null, val_or_node, key, val], mutid) + else: + sub_node = BitmapNode(0, 0, [], mutid) + sub_node, _ = sub_node.assoc( + shift + 5, existing_key_hash, + key_or_null, val_or_node, + mutid) + sub_node, _ = sub_node.assoc( + shift + 5, hash, key, val, + mutid) + + if mutid and mutid == self.mutid: + self.array[key_idx] = None + self.array[val_idx] = sub_node + return self, True + else: + ret = self.clone(mutid) + ret.array[key_idx] = None + ret.array[val_idx] = sub_node + return ret, True + + else: 
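+            # The bit is not set: this node has no entry for these 5 hash
+            # bits yet, so splice the new key/value pair into the packed
+            # array at the position derived from the bitmap popcount.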
+ key_idx = 2 * idx + val_idx = key_idx + 1 + + n = map_bitcount(self.bitmap) + + new_array = self.array[:key_idx] + new_array.append(key) + new_array.append(val) + new_array.extend(self.array[key_idx:]) + + if mutid and mutid == self.mutid: + self.size = 2 * (n + 1) + self.bitmap |= bit + self.array = new_array + return self, True + else: + return BitmapNode( + 2 * (n + 1), self.bitmap | bit, new_array, mutid), True + + def find(self, shift, hash, key): + bit = map_bitpos(hash, shift) + + if not (self.bitmap & bit): + raise KeyError + + idx = map_bitindex(self.bitmap, bit) + key_idx = idx * 2 + val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is None: + return val_or_node.find(shift + 5, hash, key) + + if key == key_or_null: + return val_or_node + + raise KeyError(key) + + def without(self, shift, hash, key, mutid): + bit = map_bitpos(hash, shift) + if not (self.bitmap & bit): + return W_NOT_FOUND, None + + idx = map_bitindex(self.bitmap, bit) + key_idx = 2 * idx + val_idx = key_idx + 1 + + key_or_null = self.array[key_idx] + val_or_node = self.array[val_idx] + + if key_or_null is None: + res, sub_node = val_or_node.without(shift + 5, hash, key, mutid) + + if res is W_EMPTY: + raise RuntimeError('unreachable code') # pragma: no cover + + elif res is W_NEWNODE: + if (type(sub_node) is BitmapNode and + sub_node.size == 2 and + sub_node.array[0] is not None): + + if mutid and mutid == self.mutid: + self.array[key_idx] = sub_node.array[0] + self.array[val_idx] = sub_node.array[1] + return W_NEWNODE, self + else: + clone = self.clone(mutid) + clone.array[key_idx] = sub_node.array[0] + clone.array[val_idx] = sub_node.array[1] + return W_NEWNODE, clone + + if mutid and mutid == self.mutid: + self.array[val_idx] = sub_node + return W_NEWNODE, self + else: + clone = self.clone(mutid) + clone.array[val_idx] = sub_node + return W_NEWNODE, clone + + else: + assert sub_node is None + return res, None + + else: + if key == key_or_null: + if self.size == 2: + return W_EMPTY, None + + new_array = self.array[:key_idx] + new_array.extend(self.array[val_idx + 1:]) + + if mutid and mutid == self.mutid: + self.size -= 2 + self.bitmap &= ~bit + self.array = new_array + return W_NEWNODE, self + else: + new_node = BitmapNode( + self.size - 2, self.bitmap & ~bit, new_array, mutid) + return W_NEWNODE, new_node + + else: + return W_NOT_FOUND, None + + def keys(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + + if key_or_null is None: + val_or_node = self.array[i + 1] + yield from val_or_node.keys() + else: + yield key_or_null + + def values(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + if key_or_null is None: + yield from val_or_node.values() + else: + yield val_or_node + + def items(self): + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + if key_or_null is None: + yield from val_or_node.items() + else: + yield key_or_null, val_or_node + + def dump(self, buf, level): # pragma: no cover + buf.append( + ' ' * (level + 1) + + 'BitmapNode(size={} count={} bitmap={} id={:0x}):'.format( + self.size, self.size / 2, bin(self.bitmap), id(self))) + + for i in range(0, self.size, 2): + key_or_null = self.array[i] + val_or_node = self.array[i + 1] + + pad = ' ' * (level + 2) + + if key_or_null is None: + buf.append(pad + 'None:') + val_or_node.dump(buf, level + 2) + else: + buf.append(pad + '{!r}: {!r}'.format(key_or_null, 
val_or_node)) + + +class CollisionNode: + + def __init__(self, size, hash, array, mutid): + self.size = size + self.hash = hash + self.array = array + self.mutid = mutid + + def find_index(self, key): + for i in range(0, self.size, 2): + if self.array[i] == key: + return i + return -1 + + def find(self, shift, hash, key): + for i in range(0, self.size, 2): + if self.array[i] == key: + return self.array[i + 1] + raise KeyError(key) + + def assoc(self, shift, hash, key, val, mutid): + if hash == self.hash: + key_idx = self.find_index(key) + + if key_idx == -1: + new_array = self.array.copy() + new_array.append(key) + new_array.append(val) + + if mutid and mutid == self.mutid: + self.size += 2 + self.array = new_array + return self, True + else: + new_node = CollisionNode( + self.size + 2, hash, new_array, mutid) + return new_node, True + + val_idx = key_idx + 1 + if self.array[val_idx] is val: + return self, False + + if mutid and mutid == self.mutid: + self.array[val_idx] = val + return self, False + else: + new_array = self.array.copy() + new_array[val_idx] = val + return CollisionNode(self.size, hash, new_array, mutid), False + + else: + new_node = BitmapNode( + 2, map_bitpos(self.hash, shift), [None, self], mutid) + return new_node.assoc(shift, hash, key, val, mutid) + + def without(self, shift, hash, key, mutid): + if hash != self.hash: + return W_NOT_FOUND, None + + key_idx = self.find_index(key) + if key_idx == -1: + return W_NOT_FOUND, None + + new_size = self.size - 2 + if new_size == 0: + # Shouldn't be ever reachable + return W_EMPTY, None # pragma: no cover + + if new_size == 2: + if key_idx == 0: + new_array = [self.array[2], self.array[3]] + else: + assert key_idx == 2 + new_array = [self.array[0], self.array[1]] + + new_node = BitmapNode( + 2, map_bitpos(hash, shift), new_array, mutid) + return W_NEWNODE, new_node + + new_array = self.array[:key_idx] + new_array.extend(self.array[key_idx + 2:]) + if mutid and mutid == self.mutid: + self.array = new_array + self.size -= 2 + return W_NEWNODE, self + else: + new_node = CollisionNode( + self.size - 2, self.hash, new_array, mutid) + return W_NEWNODE, new_node + + def keys(self): + for i in range(0, self.size, 2): + yield self.array[i] + + def values(self): + for i in range(1, self.size, 2): + yield self.array[i] + + def items(self): + for i in range(0, self.size, 2): + yield self.array[i], self.array[i + 1] + + def dump(self, buf, level): # pragma: no cover + pad = ' ' * (level + 1) + buf.append( + pad + 'CollisionNode(size={} id={:0x}):'.format( + self.size, id(self))) + + pad = ' ' * (level + 2) + for i in range(0, self.size, 2): + key = self.array[i] + val = self.array[i + 1] + + buf.append('{}{!r}: {!r}'.format(pad, key, val)) + + +class MapKeys: + + def __init__(self, c, m): + self.__count = c + self.__root = m + + def __len__(self): + return self.__count + + def __iter__(self): + return iter(self.__root.keys()) + + +class MapValues: + + def __init__(self, c, m): + self.__count = c + self.__root = m + + def __len__(self): + return self.__count + + def __iter__(self): + return iter(self.__root.values()) + + +class MapItems: + + def __init__(self, c, m): + self.__count = c + self.__root = m + + def __len__(self): + return self.__count + + def __iter__(self): + return iter(self.__root.items()) + + +class Map: + + def __init__(self, col=None, **kw): + self.__count = 0 + self.__root = BitmapNode(0, 0, [], 0) + self.__hash = -1 + + if isinstance(col, Map): + self.__count = col.__count + self.__root = col.__root + self.__hash = 
col.__hash + col = None + elif isinstance(col, MapMutation): + raise TypeError('cannot create Maps from MapMutations') + + if col or kw: + init = self.update(col, **kw) + self.__count = init.__count + self.__root = init.__root + + @classmethod + def _new(cls, count, root): + m = Map.__new__(Map) + m.__count = count + m.__root = root + m.__hash = -1 + return m + + def __reduce__(self): + return (type(self), (dict(self.items()),)) + + def __len__(self): + return self.__count + + def __eq__(self, other): + if not isinstance(other, Map): + return NotImplemented + + if len(self) != len(other): + return False + + for key, val in self.__root.items(): + try: + oval = other.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + if oval != val: + return False + + return True + + def update(self, col=None, **kw): + it = None + if col is not None: + if hasattr(col, 'items'): + it = iter(col.items()) + else: + it = iter(col) + + if it is not None: + if kw: + it = iter(itertools.chain(it, kw.items())) + else: + if kw: + it = iter(kw.items()) + + if it is None: + + return self + + mutid = _mut_id() + root = self.__root + count = self.__count + + i = 0 + while True: + try: + tup = next(it) + except StopIteration: + break + + try: + tup = tuple(tup) + except TypeError: + raise TypeError( + 'cannot convert map update ' + 'sequence element #{} to a sequence'.format(i)) from None + key, val, *r = tup + if r: + raise ValueError( + 'map update sequence element #{} has length ' + '{}; 2 is required'.format(i, len(r) + 2)) + + root, added = root.assoc(0, map_hash(key), key, val, mutid) + if added: + count += 1 + + i += 1 + + return Map._new(count, root) + + def mutate(self): + return MapMutation(self.__count, self.__root) + + def set(self, key, val): + new_count = self.__count + new_root, added = self.__root.assoc(0, map_hash(key), key, val, 0) + + if new_root is self.__root: + assert not added + return self + + if added: + new_count += 1 + + return Map._new(new_count, new_root) + + def delete(self, key): + res, node = self.__root.without(0, map_hash(key), key, 0) + if res is W_EMPTY: + return Map() + elif res is W_NOT_FOUND: + raise KeyError(key) + else: + return Map._new(self.__count - 1, node) + + def get(self, key, default=None): + try: + return self.__root.find(0, map_hash(key), key) + except KeyError: + return default + + def __getitem__(self, key): + return self.__root.find(0, map_hash(key), key) + + def __contains__(self, key): + try: + self.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + return True + + def __iter__(self): + yield from self.__root.keys() + + def keys(self): + return MapKeys(self.__count, self.__root) + + def values(self): + return MapValues(self.__count, self.__root) + + def items(self): + return MapItems(self.__count, self.__root) + + def __hash__(self): + if self.__hash != -1: + return self.__hash + + MAX = sys.maxsize + MASK = 2 * MAX + 1 + + h = 1927868237 * (self.__count * 2 + 1) + h &= MASK + + for key, value in self.__root.items(): + hx = hash(key) + h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167 + h &= MASK + + hx = hash(value) + h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167 + h &= MASK + + h = h * 69069 + 907133923 + h &= MASK + + if h > MAX: + h -= MASK + 1 # pragma: no cover + if h == -1: + h = 590923713 # pragma: no cover + + self.__hash = h + return h + + @reprlib.recursive_repr("{...}") + def __repr__(self): + items = [] + for key, val in self.items(): + items.append("{!r}: {!r}".format(key, val)) + return ''.format( + 
', '.join(items), id(self)) + + def __dump__(self): # pragma: no cover + buf = [] + self.__root.dump(buf, 0) + return '\n'.join(buf) + + +class MapMutation: + + def __init__(self, count, root): + self.__count = count + self.__root = root + self.__mutid = _mut_id() + + def set(self, key, val): + self[key] = val + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.finish() + return False + + def __iter__(self): + raise TypeError('{} is not iterable'.format(type(self))) + + def __delitem__(self, key): + if self.__mutid == 0: + raise ValueError('mutation {!r} has been finished'.format(self)) + + res, new_root = self.__root.without( + 0, map_hash(key), key, self.__mutid) + if res is W_EMPTY: + self.__count = 0 + self.__root = BitmapNode(0, 0, [], self.__mutid) + elif res is W_NOT_FOUND: + raise KeyError(key) + else: + self.__root = new_root + self.__count -= 1 + + def __setitem__(self, key, val): + if self.__mutid == 0: + raise ValueError('mutation {!r} has been finished'.format(self)) + + self.__root, added = self.__root.assoc( + 0, map_hash(key), key, val, self.__mutid) + + if added: + self.__count += 1 + + def pop(self, key, *args): + if self.__mutid == 0: + raise ValueError('mutation {!r} has been finished'.format(self)) + + if len(args) > 1: + raise TypeError( + 'pop() accepts 1 to 2 positional arguments, ' + 'got {}'.format(len(args) + 1)) + elif len(args) == 1: + default = args[0] + else: + default = void + + val = self.get(key, default) + + try: + del self[key] + except KeyError: + if val is void: + raise + return val + else: + assert val is not void + return val + + def get(self, key, default=None): + try: + return self.__root.find(0, map_hash(key), key) + except KeyError: + return default + + def __getitem__(self, key): + return self.__root.find(0, map_hash(key), key) + + def __contains__(self, key): + try: + self.__root.find(0, map_hash(key), key) + except KeyError: + return False + else: + return True + + def update(self, col=None, **kw): + if self.__mutid == 0: + raise ValueError('mutation {!r} has been finished'.format(self)) + + it = None + if col is not None: + if hasattr(col, 'items'): + it = iter(col.items()) + else: + it = iter(col) + + if it is not None: + if kw: + it = iter(itertools.chain(it, kw.items())) + else: + if kw: + it = iter(kw.items()) + + if it is None: + + return self + + root = self.__root + count = self.__count + + i = 0 + while True: + try: + tup = next(it) + except StopIteration: + break + + try: + tup = tuple(tup) + except TypeError: + raise TypeError( + 'cannot convert map update ' + 'sequence element #{} to a sequence'.format(i)) from None + key, val, *r = tup + if r: + raise ValueError( + 'map update sequence element #{} has length ' + '{}; 2 is required'.format(i, len(r) + 2)) + + root, added = root.assoc(0, map_hash(key), key, val, self.__mutid) + if added: + count += 1 + + i += 1 + + self.__root = root + self.__count = count + + def finish(self): + self.__mutid = 0 + return Map._new(self.__count, self.__root) + + @reprlib.recursive_repr("{...}") + def __repr__(self): + items = [] + for key, val in self.__root.items(): + items.append("{!r}: {!r}".format(key, val)) + return ''.format( + ', '.join(items), id(self)) + + def __len__(self): + return self.__count + + def __reduce__(self): + raise TypeError("can't pickle {} objects".format(type(self).__name__)) + + def __hash__(self): + raise TypeError('unhashable type: {}'.format(type(self).__name__)) + + def __eq__(self, other): + if not isinstance(other, MapMutation): + return 
NotImplemented
+
+        if len(self) != len(other):
+            return False
+
+        for key, val in self.__root.items():
+            try:
+                oval = other.__root.find(0, map_hash(key), key)
+            except KeyError:
+                return False
+            else:
+                if oval != val:
+                    return False
+
+        return True
+
+

diff --git a/immutables/metadata.txt b/immutables/metadata.txt
new file mode 100644
index 0000000000..42937eaeec
--- /dev/null
+++ b/immutables/metadata.txt
@@ -0,0 +1,3 @@
+srctype = micropython-lib
+type = package
+version = 0.1

From 2a103e81f606f902a04721fdc48402c645141e7c Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 24 Feb 2019 22:01:21 +0100
Subject: [PATCH 14/65] contextvars port

---
 contextvars/contextvars/__init__.py | 161 ++++++++++++++++++++++
 contextvars/metadata.txt            |   3 +
 2 files changed, 164 insertions(+)
 create mode 100644 contextvars/contextvars/__init__.py
 create mode 100644 contextvars/metadata.txt

diff --git a/contextvars/contextvars/__init__.py b/contextvars/contextvars/__init__.py
new file mode 100644
index 0000000000..18702dab71
--- /dev/null
+++ b/contextvars/contextvars/__init__.py
@@ -0,0 +1,161 @@
+import collections.abc
+
+import immutables
+
+
+__all__ = ('ContextVar', 'Context', 'Token', 'copy_context')
+
+
+_NO_DEFAULT = object()
+
+
+class Context(collections.abc.Mapping):
+
+    def __init__(self):
+        self._data = immutables.Map()
+        self._prev_context = None
+
+    def run(self, callable, *args, **kwargs):
+        if self._prev_context is not None:
+            raise RuntimeError(
+                'cannot enter context: {} is already entered'.format(self))
+
+        self._prev_context = _get_context()
+        try:
+            _set_context(self)
+            return callable(*args, **kwargs)
+        finally:
+            _set_context(self._prev_context)
+            self._prev_context = None
+
+    def copy(self):
+        new = Context()
+        new._data = self._data
+        return new
+
+    def __getitem__(self, var):
+        if not isinstance(var, ContextVar):
+            raise TypeError(
+                "a ContextVar key was expected, got {!r}".format(var))
+        return self._data[var]
+
+    def __contains__(self, var):
+        if not isinstance(var, ContextVar):
+            raise TypeError(
+                "a ContextVar key was expected, got {!r}".format(var))
+        return var in self._data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __iter__(self):
+        return iter(self._data)
+
+
+class ContextVar:
+
+    def __init__(self, name, *, default=_NO_DEFAULT):
+        if not isinstance(name, str):
+            raise TypeError("context variable name must be a str")
+        self.name = name
+        self.default = default
+
+    def get(self, default=_NO_DEFAULT):
+        ctx = _get_context()
+        try:
+            return ctx[self]
+        except KeyError:
+            pass
+
+        if default is not _NO_DEFAULT:
+            return default
+
+        if self.default is not _NO_DEFAULT:
+            return self.default
+
+        raise LookupError
+
+    def set(self, value):
+        ctx = _get_context()
+        data = ctx._data
+        try:
+            old_value = data[self]
+        except KeyError:
+            old_value = Token.MISSING
+
+        updated_data = data.set(self, value)
+        ctx._data = updated_data
+        return Token(ctx, self, old_value)
+
+    def reset(self, token):
+        if token._used:
+            raise RuntimeError("Token has already been used once")
+
+        if token._var is not self:
+            raise ValueError(
+                "Token was created by a different ContextVar")
+
+        if token._context is not _get_context():
+            raise ValueError(
+                "Token was created in a different Context")
+
+        ctx = token._context
+        if token._old_value is Token.MISSING:
+            ctx._data = ctx._data.delete(token._var)
+        else:
+            ctx._data = ctx._data.set(token._var, token._old_value)
+
+        token._used = True
+
+    def __repr__(self):
+        r = '<ContextVar name=%r' % (self.name,)
+        if self.default is not _NO_DEFAULT:
+            r += ' default=%r' % (self.default,)
+        return r + ' at 0x%x>' % id(self)
+
+
+class Token:
+
+    MISSING = object()
+
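+    # Sentinel stored in a token's old_value when the variable had no
+    # value before the ContextVar.set() call that created the token.
+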
+    def __init__(self, context, var, old_value):
+        self._context = context
+        self._var = var
+        self._old_value = old_value
+        self._used = False
+
+    @property
+    def var(self):
+        return self._var
+
+    @property
+    def old_value(self):
+        return self._old_value
+
+    def __repr__(self):
+        r = '<Token'
+        if self._used:
+            r += ' used'
+        r += ' var=%r at 0x%x>' % (self._var, id(self))
+        return r
+
+
+def copy_context():
+    return _get_context().copy()
+
+
+# This port is single-threaded, so a module-level Context serves as
+# the current context.
+def _get_context():
+    return _context
+
+
+def _set_context(ctx):
+    global _context
+    _context = ctx
+
+
+_context = Context()

diff --git a/contextvars/metadata.txt b/contextvars/metadata.txt
new file mode 100644
--- /dev/null
+++ b/contextvars/metadata.txt
@@ -0,0 +1,3 @@
+srctype = micropython-lib
+type = package
+version = 0.1

From: Matthias Urlichs
Date: Sun, 24 Feb 2019 22:01:42 +0100
Subject: [PATCH 15/65] more stuff in collections

---
 collections/collections/__init__.py | 12 ++++++++++++
 collections/metadata.txt            |  2 +-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/collections/collections/__init__.py b/collections/collections/__init__.py
index afebc295fc..c566a8abd8 100644
--- a/collections/collections/__init__.py
+++ b/collections/collections/__init__.py
@@ -29,3 +29,15 @@ def _make(cls, seq):
     t = type(name, (_T,), {"_make": _make})
 
     return t
+
+class Sequence:
+    pass
+
+class MutableSequence:
+    pass
+
+class Set:
+    pass
+
+class MutableSet:
+    pass

diff --git a/collections/metadata.txt b/collections/metadata.txt
index 0ada490e92..bc968b9a66 100644
--- a/collections/metadata.txt
+++ b/collections/metadata.txt
@@ -1,3 +1,3 @@
 srctype = pycopy-lib
 type = package
-version = 0.2
+version = 0.2.1

From 46b6833ae235962831fbc266235eeee725d766b3 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 24 Feb 2019 22:02:35 +0100
Subject: [PATCH 16/65] Make functools.partial a proper class

required for compatibility
---
 functools/functools.py | 14 +++++++++-----
 functools/metadata.txt |  2 +-
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/functools/functools.py b/functools/functools.py
index 97196bef36..55d53eb8ae 100644
--- a/functools/functools.py
+++ b/functools/functools.py
@@ -1,9 +1,13 @@
-def partial(func, *args, **kwargs):
-    def _partial(*more_args, **more_kwargs):
-        kw = kwargs.copy()
+class partial:
+    def __init__(self, func, *args, **kwargs):
+        self.func = func
+        self.args = args
+        self.kwargs = kwargs
+
+    def __call__(self, *more_args, **more_kwargs):
+        kw = self.kwargs.copy()
         kw.update(more_kwargs)
-        return func(*(args + more_args), **kw)
-    return _partial
+        return self.func(*(self.args + more_args), **kw)
 
 
 def update_wrapper(wrapper, wrapped, assigned=None, updated=None):

diff --git a/functools/metadata.txt b/functools/metadata.txt
index 7bbc045e90..7edca7ea35 100644
--- a/functools/metadata.txt
+++ b/functools/metadata.txt
@@ -1,3 +1,3 @@
 srctype = pycopy-lib
 type = module
-version = 0.0.7
+version = 0.1.0

From d96c01fc39bbdb4d417d55927143a737f87f7384 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 24 Feb 2019 22:04:10 +0100
Subject: [PATCH 17/65] add contextvars tests

---
 contextvars/tests/__init__.py    |  16 ++
 contextvars/tests/test_basics.py | 353 +++++++++++++++++++++++++++++++
 2 files changed, 369 insertions(+)
 create mode 100644 contextvars/tests/__init__.py
 create mode 100644 contextvars/tests/test_basics.py

diff --git a/contextvars/tests/__init__.py b/contextvars/tests/__init__.py
new file mode 100644
index 0000000000..ad0ac756d0
--- /dev/null
+++ b/contextvars/tests/__init__.py
@@ -0,0 +1,16 @@
+import os.path
+import sys
+import unittest
+
+
+def suite():
+    test_loader = unittest.TestLoader()
+    test_suite = test_loader.discover(
+        os.path.dirname(__file__), pattern='test_*.py')
+    return test_suite
+
+
+if __name__ == '__main__':
+    runner = unittest.TestRunner()
+    result = runner.run(suite())
+    sys.exit(not result.wasSuccessful())

diff --git a/contextvars/tests/test_basics.py b/contextvars/tests/test_basics.py
new file mode 100644
index 0000000000..46f9ea9a1c
--- /dev/null
+++ 
b/contextvars/tests/test_basics.py @@ -0,0 +1,353 @@ +# Tests are copied from cpython/Lib/test/test_context.py +# License: PSFL +# Copyright: 2018 Python Software Foundation + + +import concurrent.futures +import functools +import random +import time +import unittest + +import contextvars + + +def isolated_context(func): + """Needed to make reftracking test mode work.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + ctx = contextvars.Context() + return ctx.run(func, *args, **kwargs) + return wrapper + + +class ContextTest(unittest.TestCase): + def test_context_var_new_1(self): + with self.assertRaises(TypeError): + contextvars.ContextVar() + + with self.assertRaisesRegex(TypeError, 'must be a str'): + contextvars.ContextVar(1) + + c = contextvars.ContextVar('a') + self.assertNotEqual(hash(c), hash('a')) + + @isolated_context + def test_context_var_repr_1(self): + c = contextvars.ContextVar('a') + self.assertIn('a', repr(c)) + + c = contextvars.ContextVar('a', default=123) + self.assertIn('123', repr(c)) + + lst = [] + c = contextvars.ContextVar('a', default=lst) + lst.append(c) + self.assertIn('...', repr(c)) + self.assertIn('...', repr(lst)) + + t = c.set(1) + self.assertIn(repr(c), repr(t)) + self.assertNotIn(' used ', repr(t)) + c.reset(t) + self.assertIn(' used ', repr(t)) + + def test_context_subclassing_1(self): + with self.assertRaisesRegex(TypeError, 'not an acceptable base type'): + class MyContextVar(contextvars.ContextVar): + # Potentially we might want ContextVars to be subclassable. + pass + + with self.assertRaisesRegex(TypeError, 'not an acceptable base type'): + class MyContext(contextvars.Context): + pass + + with self.assertRaisesRegex(TypeError, 'not an acceptable base type'): + class MyToken(contextvars.Token): + pass + + def test_context_new_1(self): + with self.assertRaises(TypeError): + contextvars.Context(1) + with self.assertRaises(TypeError): + contextvars.Context(1, a=1) + with self.assertRaises(TypeError): + contextvars.Context(a=1) + contextvars.Context(**{}) + + def test_context_typerrors_1(self): + ctx = contextvars.Context() + + with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'): + ctx[1] + with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'): + 1 in ctx + with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'): + ctx.get(1) + + def test_context_get_context_1(self): + ctx = contextvars.copy_context() + self.assertIsInstance(ctx, contextvars.Context) + + def test_context_run_1(self): + ctx = contextvars.Context() + + with self.assertRaisesRegex(TypeError, 'missing 1 required'): + ctx.run() + + def test_context_run_2(self): + ctx = contextvars.Context() + + def func(*args, **kwargs): + kwargs['spam'] = 'foo' + args += ('bar',) + return args, kwargs + + for f in (func, functools.partial(func)): + # partial doesn't support FASTCALL + + self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'})) + self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'})) + + self.assertEqual( + ctx.run(f, a=2), + (('bar',), {'a': 2, 'spam': 'foo'})) + + self.assertEqual( + ctx.run(f, 11, a=2), + ((11, 'bar'), {'a': 2, 'spam': 'foo'})) + + a = {} + self.assertEqual( + ctx.run(f, 11, **a), + ((11, 'bar'), {'spam': 'foo'})) + self.assertEqual(a, {}) + + def test_context_run_3(self): + ctx = contextvars.Context() + + def func(*args, **kwargs): + 1 / 0 + + with self.assertRaises(ZeroDivisionError): + ctx.run(func) + with self.assertRaises(ZeroDivisionError): + ctx.run(func, 1, 2) + with 
self.assertRaises(ZeroDivisionError): + ctx.run(func, 1, 2, a=123) + + @isolated_context + def test_context_run_4(self): + ctx1 = contextvars.Context() + ctx2 = contextvars.Context() + var = contextvars.ContextVar('var') + + def func2(): + self.assertIsNone(var.get(None)) + + def func1(): + self.assertIsNone(var.get(None)) + var.set('spam') + ctx2.run(func2) + self.assertEqual(var.get(None), 'spam') + + cur = contextvars.copy_context() + self.assertEqual(len(cur), 1) + self.assertEqual(cur[var], 'spam') + return cur + + returned_ctx = ctx1.run(func1) + self.assertEqual(ctx1, returned_ctx) + self.assertEqual(returned_ctx[var], 'spam') + self.assertIn(var, returned_ctx) + + def test_context_run_5(self): + ctx = contextvars.Context() + var = contextvars.ContextVar('var') + + def func(): + self.assertIsNone(var.get(None)) + var.set('spam') + 1 / 0 + + with self.assertRaises(ZeroDivisionError): + ctx.run(func) + + self.assertIsNone(var.get(None)) + + def test_context_run_6(self): + ctx = contextvars.Context() + c = contextvars.ContextVar('a', default=0) + + def fun(): + self.assertEqual(c.get(), 0) + self.assertIsNone(ctx.get(c)) + + c.set(42) + self.assertEqual(c.get(), 42) + self.assertEqual(ctx.get(c), 42) + + ctx.run(fun) + + def test_context_run_7(self): + ctx = contextvars.Context() + + def fun(): + with self.assertRaisesRegex(RuntimeError, 'is already entered'): + ctx.run(fun) + + ctx.run(fun) + + @isolated_context + def test_context_getset_1(self): + c = contextvars.ContextVar('c') + with self.assertRaises(LookupError): + c.get() + + self.assertIsNone(c.get(None)) + + t0 = c.set(42) + self.assertEqual(c.get(), 42) + self.assertEqual(c.get(None), 42) + self.assertIs(t0.old_value, t0.MISSING) + self.assertIs(t0.old_value, contextvars.Token.MISSING) + self.assertIs(t0.var, c) + + t = c.set('spam') + self.assertEqual(c.get(), 'spam') + self.assertEqual(c.get(None), 'spam') + self.assertEqual(t.old_value, 42) + c.reset(t) + + self.assertEqual(c.get(), 42) + self.assertEqual(c.get(None), 42) + + c.set('spam2') + with self.assertRaisesRegex(RuntimeError, 'has already been used'): + c.reset(t) + self.assertEqual(c.get(), 'spam2') + + ctx1 = contextvars.copy_context() + self.assertIn(c, ctx1) + + c.reset(t0) + with self.assertRaisesRegex(RuntimeError, 'has already been used'): + c.reset(t0) + self.assertIsNone(c.get(None)) + + self.assertIn(c, ctx1) + self.assertEqual(ctx1[c], 'spam2') + self.assertEqual(ctx1.get(c, 'aa'), 'spam2') + self.assertEqual(len(ctx1), 1) + self.assertEqual(list(ctx1.items()), [(c, 'spam2')]) + self.assertEqual(list(ctx1.values()), ['spam2']) + self.assertEqual(list(ctx1.keys()), [c]) + self.assertEqual(list(ctx1), [c]) + + ctx2 = contextvars.copy_context() + self.assertNotIn(c, ctx2) + with self.assertRaises(KeyError): + ctx2[c] + self.assertEqual(ctx2.get(c, 'aa'), 'aa') + self.assertEqual(len(ctx2), 0) + self.assertEqual(list(ctx2), []) + + @isolated_context + def test_context_getset_2(self): + v1 = contextvars.ContextVar('v1') + v2 = contextvars.ContextVar('v2') + + t1 = v1.set(42) + with self.assertRaisesRegex(ValueError, 'by a different'): + v2.reset(t1) + + @isolated_context + def test_context_getset_3(self): + c = contextvars.ContextVar('c', default=42) + ctx = contextvars.Context() + + def fun(): + self.assertEqual(c.get(), 42) + with self.assertRaises(KeyError): + ctx[c] + self.assertIsNone(ctx.get(c)) + self.assertEqual(ctx.get(c, 'spam'), 'spam') + self.assertNotIn(c, ctx) + self.assertEqual(list(ctx.keys()), []) + + t = c.set(1) + 
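+            # Setting the var while fun() runs under ctx.run() records
+            # the value in ctx itself, so it becomes visible as a key.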
+            self.assertEqual(list(ctx.keys()), [c])
+            self.assertEqual(ctx[c], 1)
+
+            c.reset(t)
+            self.assertEqual(list(ctx.keys()), [])
+            with self.assertRaises(KeyError):
+                ctx[c]
+
+        ctx.run(fun)
+
+    @isolated_context
+    def test_context_getset_4(self):
+        c = contextvars.ContextVar('c', default=42)
+        ctx = contextvars.Context()
+
+        tok = ctx.run(c.set, 1)
+
+        with self.assertRaisesRegex(ValueError, 'different Context'):
+            c.reset(tok)
+
+    @isolated_context
+    def test_context_getset_5(self):
+        c = contextvars.ContextVar('c', default=42)
+        c.set([])
+
+        def fun():
+            c.set([])
+            c.get().append(42)
+            self.assertEqual(c.get(), [42])
+
+        contextvars.copy_context().run(fun)
+        self.assertEqual(c.get(), [])
+
+    def test_context_copy_1(self):
+        ctx1 = contextvars.Context()
+        c = contextvars.ContextVar('c', default=42)
+
+        def ctx1_fun():
+            c.set(10)
+
+            ctx2 = ctx1.copy()
+            self.assertEqual(ctx2[c], 10)
+
+            c.set(20)
+            self.assertEqual(ctx1[c], 20)
+            self.assertEqual(ctx2[c], 10)
+
+            ctx2.run(ctx2_fun)
+            self.assertEqual(ctx1[c], 20)
+            self.assertEqual(ctx2[c], 30)
+
+        def ctx2_fun():
+            self.assertEqual(c.get(), 10)
+            c.set(30)
+            self.assertEqual(c.get(), 30)
+
+        ctx1.run(ctx1_fun)
+
+    @isolated_context
+    def test_context_threads_1(self):
+        cvar = contextvars.ContextVar('cvar')
+
+        def sub(num):
+            for i in range(10):
+                cvar.set(num + i)
+                time.sleep(random.uniform(0.001, 0.05))
+                self.assertEqual(cvar.get(), num + i)
+            return num
+
+        tp = concurrent.futures.ThreadPoolExecutor(max_workers=10)
+        try:
+            results = list(tp.map(sub, range(10)))
+        finally:
+            tp.shutdown()
+        self.assertEqual(results, list(range(10)))

From 98a7dc7219fb7ce67a0aef292538bd15e153c1cc Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 24 Feb 2019 22:13:34 +0100
Subject: [PATCH 18/65] more operators

---
 operator/operator.py | 107 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 106 insertions(+), 1 deletion(-)

diff --git a/operator/operator.py b/operator/operator.py
index ba59d9eda4..8f5082d5d9 100644
--- a/operator/operator.py
+++ b/operator/operator.py
@@ -20,7 +20,6 @@ def _methodcaller(obj):
         return getattr(obj, name)(*args, **kwargs)
     return _methodcaller
 
-
 def lt(a, b):
     return a < b
 
@@ -50,3 +49,109 @@ def floordiv(a, b):
     return a // b
 
 def getitem(a, b):
     return a[b]
+
+def add(a, b):
+    return a + b
+
+def iadd(a, b):
+    a += b
+    return a
+
+def sub(a, b):
+    return a - b
+
+def isub(a, b):
+    a -= b
+    return a
+
+def mul(a, b):
+    return a * b
+
+def matmul(a, b):
+    return a @ b
+
+def imul(a, b):
+    a *= b
+    return a
+
+def imatmul(a, b):
+    a @= b
+    return a
+
+def div(a, b):
+    return a / b
+
+def idiv(a, b):
+    a /= b
+    return a
+
+truediv = div
+itruediv = idiv
+
+def floordiv(a, b):
+    return a // b
+
+def ifloordiv(a, b):
+    a //= b
+    return a
+
+def mod(a, b):
+    return a % b
+
+def imod(a, b):
+    a %= b
+    return a
+
+def pow(a, b):
+    return a ** b
+
+def ipow(a, b):
+    a **= b
+    return a
+
+def is_(a, b):
+    return a is b
+
+def is_not(a, b):
+    return a is not b
+
+def and_(a, b):
+    return a & b
+
+def iand(a, b):
+    a &= b
+    return a
+
+def or_(a, b):
+    return a | b
+
+def ior(a, b):
+    a |= b
+    return a
+
+def xor(a, b):
+    return a ^ b
+
+def ixor(a, b):
+    a ^= b
+    return a
+
+def invert(a):
+    return ~a
+
+inv = invert
+
+def lshift(a, b):
+    return a << b
+
+def ilshift(a, b):
+    a <<= b
+    return a
+
+def rshift(a, b):
+    return a >> b
+
+def irshift(a, b):
+    a >>= b
+    return a
+

From 38ce9179172ef7f75a3ad214aa6c635e420a1f04 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 24 Feb 2019 22:14:44 +0100
Subject: [PATCH 19/65]
random.shuffle: no-op for <2 elements --- random/metadata.txt | 2 +- random/random.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/random/metadata.txt b/random/metadata.txt index 93d12b4213..3b9b23e35a 100644 --- a/random/metadata.txt +++ b/random/metadata.txt @@ -1,3 +1,3 @@ srctype = pycopy-lib type = module -version = 0.2.2 +version = 0.2.3 diff --git a/random/random.py b/random/random.py index eddb0e680e..b6017f20f5 100644 --- a/random/random.py +++ b/random/random.py @@ -24,6 +24,8 @@ def randint(start, stop): def shuffle(seq): l = len(seq) + if l < 2: + return for i in range(l): j = randrange(l) seq[i], seq[j] = seq[j], seq[i] From 38db5d28353c145a56bcbe52c217b2ee4d9d463b Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 22:16:09 +0100 Subject: [PATCH 20/65] Add "uniform" as a (somewhat-broken) alias of randrange --- random/random.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/random/random.py b/random/random.py index b6017f20f5..68662e6d15 100644 --- a/random/random.py +++ b/random/random.py @@ -22,6 +22,8 @@ def randrange(start, stop=None): def randint(start, stop): return randrange(start, stop + 1) +uniform = randint + def shuffle(seq): l = len(seq) if l < 2: From 58d44171cc7e5766757631b003fccc66e48b7208 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 22:16:59 +0100 Subject: [PATCH 21/65] Add a random.Random() dummy object --- random/metadata.txt | 2 +- random/random.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/random/metadata.txt b/random/metadata.txt index 3b9b23e35a..e664e0316c 100644 --- a/random/metadata.txt +++ b/random/metadata.txt @@ -1,3 +1,3 @@ srctype = pycopy-lib type = module -version = 0.2.3 +version = 0.2.4 diff --git a/random/random.py b/random/random.py index 68662e6d15..8219b168cd 100644 --- a/random/random.py +++ b/random/random.py @@ -36,3 +36,24 @@ def choice(seq): if not seq: raise IndexError return seq[randrange(len(seq))] + + +class Random: + @staticmethod + def randrange(start, stop=None): + return randrange(start, stop) + + @staticmethod + def randint(start, stop): + return randint(start, stop) + + uniform = randint + + @staticmethod + def shuffle(seq): + return shuffle(seq) + + @staticmethod + def choice(seq): + return choice(seq) + From 7b0f06691df791a84c7e72837141b6af147dfbcc Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 22:19:55 +0100 Subject: [PATCH 22/65] reprlib (im)ported --- reprlib/metadata.txt | 4 +- reprlib/reprlib.py | 155 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 2 deletions(-) diff --git a/reprlib/metadata.txt b/reprlib/metadata.txt index dc5f60a661..e26221407a 100644 --- a/reprlib/metadata.txt +++ b/reprlib/metadata.txt @@ -1,3 +1,3 @@ -srctype = dummy +srctype = micropython-lib type = module -version = 0.0.1 +version = 0.1.0 diff --git a/reprlib/reprlib.py b/reprlib/reprlib.py index e69de29bb2..03d938342a 100644 --- a/reprlib/reprlib.py +++ b/reprlib/reprlib.py @@ -0,0 +1,155 @@ +"""Redo the builtin repr() (representation) but with limits on most sizes.""" + +__all__ = ["Repr", "repr", "recursive_repr"] + +import builtins +from itertools import islice +from _thread import get_ident + +def recursive_repr(fillvalue='...'): + 'Decorator to make a repr function return fillvalue for a recursive call' + + def decorating_function(user_function): + repr_running = set() + + def wrapper(self): + key = id(self), get_ident() + if key in repr_running: + return fillvalue 
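+            # First call for this (object, thread) pair: mark it as
+            # in progress so a recursive call returns fillvalue instead.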
+ repr_running.add(key) + try: + result = user_function(self) + finally: + repr_running.discard(key) + return result + + return wrapper + + return decorating_function + +class Repr: + + def __init__(self): + self.maxlevel = 6 + self.maxtuple = 6 + self.maxlist = 6 + self.maxarray = 5 + self.maxdict = 4 + self.maxset = 6 + self.maxfrozenset = 6 + self.maxdeque = 6 + self.maxstring = 30 + self.maxlong = 40 + self.maxother = 30 + + def repr(self, x): + return self.repr1(x, self.maxlevel) + + def repr1(self, x, level): + typename = type(x).__name__ + if ' ' in typename: + parts = typename.split() + typename = '_'.join(parts) + if hasattr(self, 'repr_' + typename): + return getattr(self, 'repr_' + typename)(x, level) + else: + return self.repr_instance(x, level) + + def _repr_iterable(self, x, level, left, right, maxiter, trail=''): + n = len(x) + if level <= 0 and n: + s = '...' + else: + newlevel = level - 1 + repr1 = self.repr1 + pieces = [repr1(elem, newlevel) for elem in islice(x, maxiter)] + if n > maxiter: pieces.append('...') + s = ', '.join(pieces) + if n == 1 and trail: right = trail + right + return '%s%s%s' % (left, s, right) + + def repr_tuple(self, x, level): + return self._repr_iterable(x, level, '(', ')', self.maxtuple, ',') + + def repr_list(self, x, level): + return self._repr_iterable(x, level, '[', ']', self.maxlist) + + def repr_array(self, x, level): + if not x: + return "array('%s')" % x.typecode + header = "array('%s', [" % x.typecode + return self._repr_iterable(x, level, header, '])', self.maxarray) + + def repr_set(self, x, level): + if not x: + return 'set()' + x = _possibly_sorted(x) + return self._repr_iterable(x, level, '{', '}', self.maxset) + + def repr_frozenset(self, x, level): + if not x: + return 'frozenset()' + x = _possibly_sorted(x) + return self._repr_iterable(x, level, 'frozenset({', '})', + self.maxfrozenset) + + def repr_deque(self, x, level): + return self._repr_iterable(x, level, 'deque([', '])', self.maxdeque) + + def repr_dict(self, x, level): + n = len(x) + if n == 0: return '{}' + if level <= 0: return '{...}' + newlevel = level - 1 + repr1 = self.repr1 + pieces = [] + for key in islice(_possibly_sorted(x), self.maxdict): + keyrepr = repr1(key, newlevel) + valrepr = repr1(x[key], newlevel) + pieces.append('%s: %s' % (keyrepr, valrepr)) + if n > self.maxdict: pieces.append('...') + s = ', '.join(pieces) + return '{%s}' % (s,) + + def repr_str(self, x, level): + s = builtins.repr(x[:self.maxstring]) + if len(s) > self.maxstring: + i = max(0, (self.maxstring-3)//2) + j = max(0, self.maxstring-3-i) + s = builtins.repr(x[:i] + x[len(x)-j:]) + s = s[:i] + '...' + s[len(s)-j:] + return s + + def repr_int(self, x, level): + s = builtins.repr(x) # XXX Hope this isn't too slow... + if len(s) > self.maxlong: + i = max(0, (self.maxlong-3)//2) + j = max(0, self.maxlong-3-i) + s = s[:i] + '...' + s[len(s)-j:] + return s + + def repr_instance(self, x, level): + try: + s = builtins.repr(x) + # Bugs in x.__repr__() can cause arbitrary + # exceptions -- then make up something + except Exception: + return '<%s instance at %#x>' % (x.__class__.__name__, id(x)) + if len(s) > self.maxother: + i = max(0, (self.maxother-3)//2) + j = max(0, self.maxother-3-i) + s = s[:i] + '...' + s[len(s)-j:] + return s + + +def _possibly_sorted(x): + # Since not all sequences of items can be sorted and comparison + # functions may raise arbitrary exceptions, return an unsorted + # sequence in that case. 
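+    # For example, sorted({1, 'a'}) raises TypeError on Python 3, so
+    # fall back to the unsorted list() of items.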
+ try: + return sorted(x) + except Exception: + return list(x) + +aRepr = Repr() +repr = aRepr.repr From 83c3124bb4782a5f53df1e0182beefcd1d8311da Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 22:20:51 +0100 Subject: [PATCH 23/65] signal: add a default_int_handler --- signal/signal.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/signal/signal.py b/signal/signal.py index f5e1ffc0cd..2048cf180d 100644 --- a/signal/signal.py +++ b/signal/signal.py @@ -12,6 +12,8 @@ SIGCHLD = 17 SIGWINCH = 28 +default_int_handler = SIG_IGN + libc = ffilib.libc() signal_i = libc.func("i", "signal", "ii") From 0077ba005e827a6ba65d877e10ac8cc941213f5a Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sun, 24 Feb 2019 22:21:10 +0100 Subject: [PATCH 24/65] signal: implement getsignal --- signal/metadata.txt | 2 +- signal/signal.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/signal/metadata.txt b/signal/metadata.txt index 5179c8ae4e..974698ab5b 100644 --- a/signal/metadata.txt +++ b/signal/metadata.txt @@ -1,5 +1,5 @@ srctype = pycopy-lib type = module -version = 0.4.2 +version = 0.4.3 author = Paul Sokolovsky depends = ffilib diff --git a/signal/signal.py b/signal/signal.py index 2048cf180d..669eb141bd 100644 --- a/signal/signal.py +++ b/signal/signal.py @@ -22,7 +22,6 @@ _hmap = {} - def signal(n, handler): if isinstance(handler, int): # We don't try to remove callback from _hmap here, as we return old @@ -35,3 +34,7 @@ def signal(n, handler): _hmap[n] = cb siginterrupt(n, True) return signal_p(n, cb) + +def getsignal(n): + return _sigs.get(n, SIG_DFL) + From 762222f4652884ab69904e3c1f7cc9b55e24432d Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 25 Feb 2019 20:32:00 +0100 Subject: [PATCH 25/65] comment out non-working (so far) contextvars tests --- contextvars/tests/test_basics.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/contextvars/tests/test_basics.py b/contextvars/tests/test_basics.py index 46f9ea9a1c..3381436449 100644 --- a/contextvars/tests/test_basics.py +++ b/contextvars/tests/test_basics.py @@ -40,11 +40,13 @@ def test_context_var_repr_1(self): c = contextvars.ContextVar('a', default=123) self.assertIn('123', repr(c)) - lst = [] - c = contextvars.ContextVar('a', default=lst) - lst.append(c) - self.assertIn('...', repr(c)) - self.assertIn('...', repr(lst)) + if False: # requires recursion-safe repr() + lst = [] + c = contextvars.ContextVar('a', default=lst) + lst.append(c) + self.assertIs(c.default,lst) + self.assertIn('...', repr(c)) + self.assertIn('...', repr(lst)) t = c.set(1) self.assertIn(repr(c), repr(t)) @@ -52,7 +54,7 @@ def test_context_var_repr_1(self): c.reset(t) self.assertIn(' used ', repr(t)) - def test_context_subclassing_1(self): + def _skip_test_context_subclassing_1(self): with self.assertRaisesRegex(TypeError, 'not an acceptable base type'): class MyContextVar(contextvars.ContextVar): # Potentially we might want ContextVars to be subclassable. 
@@ -335,7 +337,7 @@ def ctx2_fun(): ctx1.run(ctx1_fun) @isolated_context - def test_context_threads_1(self): + def _skip_test_context_threads_1(self): cvar = contextvars.ContextVar('cvar') def sub(num): From ff961c04f815ab8897482db6c0229d222bde3c04 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 25 Feb 2019 20:33:48 +0100 Subject: [PATCH 26/65] There is no __new__ --- immutables/immutables/map.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/immutables/immutables/map.py b/immutables/immutables/map.py index cad251e6a7..cfd837daca 100644 --- a/immutables/immutables/map.py +++ b/immutables/immutables/map.py @@ -453,7 +453,8 @@ def __init__(self, col=None, **kw): @classmethod def _new(cls, count, root): - m = Map.__new__(Map) + # m = Map.__new__(Map) ## no __new__ + m = Map() m.__count = count m.__root = root m.__hash = -1 From 74cae515519ecc3990011ebe43827065e697d1cf Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 25 Feb 2019 20:34:22 +0100 Subject: [PATCH 27/65] unittest: add assertNotIn --- unittest/metadata.txt | 2 +- unittest/unittest.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/unittest/metadata.txt b/unittest/metadata.txt index 3358bbff81..6e1b970dfd 100644 --- a/unittest/metadata.txt +++ b/unittest/metadata.txt @@ -1,3 +1,3 @@ srctype = pycopy-lib type = module -version = 0.5.2 +version = 0.5.3 diff --git a/unittest/unittest.py b/unittest/unittest.py index 015b8669ce..cb7e4fec7b 100644 --- a/unittest/unittest.py +++ b/unittest/unittest.py @@ -146,6 +146,11 @@ def assertIn(self, x, y, msg=''): msg = "Expected %r to be in %r" % (x, y) assert x in y, msg + def assertNotIn(self, x, y, msg=''): + if not msg: + msg = "Expected %r to be in %r" % (x, y) + assert x not in y, msg + def assertIsInstance(self, x, y, msg=''): assert isinstance(x, y), msg From a1068602b0c385503f9b49dbebaa52612c8f485b Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Fri, 31 May 2019 21:57:18 +0200 Subject: [PATCH 28/65] fix sortedcontainers --- sortedcontainers/sortedcontainers/sorteddict.py | 3 +++ sortedcontainers/sortedcontainers/sortedlist.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sortedcontainers/sortedcontainers/sorteddict.py b/sortedcontainers/sortedcontainers/sorteddict.py index d78f1099e0..4d693724a9 100644 --- a/sortedcontainers/sortedcontainers/sorteddict.py +++ b/sortedcontainers/sortedcontainers/sorteddict.py @@ -181,6 +181,9 @@ def __reversed__(self): """ return reversed(self._list) + def __getitem__(self, key): + return self._dict[key] + def __setitem__(self, key, value): """Set `d[key]` to *value*.""" if key not in self: diff --git a/sortedcontainers/sortedcontainers/sortedlist.py b/sortedcontainers/sortedcontainers/sortedlist.py index 3a87fc75f3..660a3ed29f 100644 --- a/sortedcontainers/sortedcontainers/sortedlist.py +++ b/sortedcontainers/sortedcontainers/sortedlist.py @@ -790,7 +790,7 @@ def __iter__(self): Iterating the Sequence while adding or deleting values may raise a `RuntimeError` or fail to iterate over all entries. 
""" - return chain(self._lists) + return chain(*self._lists) def __reversed__(self): """ From 399dc544ff3bf6c41d69042d72d05d1add4c3374 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Tue, 21 Jan 2020 22:12:42 +0100 Subject: [PATCH 29/65] no "const" on non-integers --- os/os/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/os/os/__init__.py b/os/os/__init__.py index af424410df..0dda8efd70 100644 --- a/os/os/__init__.py +++ b/os/os/__init__.py @@ -41,13 +41,13 @@ P_WAIT = 0 P_NOWAIT = 1 -error = const(OSError) +error = OSError name = "posix" -sep = const("/") -curdir = const(".") -pardir = const("..") -devnull = const("/dev/null") -linesep = const("\n") +sep = "/" +curdir = "." +pardir = ".." +devnull = "/dev/null" +linesep = "\n" if libc: From 43cd3f1f4f6dd3649cd28ba8b5e35e10a78e4df5 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:41:56 +0200 Subject: [PATCH 30/65] Add dummy AbstractContextManager class --- ucontextlib/metadata.txt | 2 +- ucontextlib/ucontextlib.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ucontextlib/metadata.txt b/ucontextlib/metadata.txt index 0a7b7526b9..4d62ebda10 100644 --- a/ucontextlib/metadata.txt +++ b/ucontextlib/metadata.txt @@ -1,5 +1,5 @@ srctype = pycopy-lib type = module -version = 0.1.2 +version = 0.1.3 license = Python long_desc = Minimal subset of contextlib for Pycopy (https://github.com/pfalcon/pycopy) low-memory ports. diff --git a/ucontextlib/ucontextlib.py b/ucontextlib/ucontextlib.py index 29445a028f..e7a02d8c1f 100644 --- a/ucontextlib/ucontextlib.py +++ b/ucontextlib/ucontextlib.py @@ -9,7 +9,11 @@ - supress """ -class ContextDecorator(object): +class AbstractContextManager(object): + "Compatibility" + pass + +class ContextDecorator(AbstractContextManager): "A base class or mixin that enables context managers to work as decorators." def _recreate_cm(self): From 178ef854bda88f79a6862fa7e03ce1fb4b8d86a7 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:19:58 +0200 Subject: [PATCH 31/65] contextlib: add async versions --- contextlib/contextlib.py | 370 ++++++++++++++++++++++++++++++++++----- contextlib/metadata.txt | 2 +- 2 files changed, 329 insertions(+), 43 deletions(-) diff --git a/contextlib/contextlib.py b/contextlib/contextlib.py index aca58d711b..e33d20fc31 100644 --- a/contextlib/contextlib.py +++ b/contextlib/contextlib.py @@ -12,6 +12,11 @@ from ucontextlib import * +class AbstractAsyncContextManager(object): + "compabilibty" + pass + + class closing(object): """Context to automatically close something at the end of a block. 
@@ -66,101 +71,382 @@ def __exit__(self, exctype, excinst, exctb):
         # See http://bugs.python.org/issue12029 for more details
         return exctype is not None and issubclass(exctype, self._exceptions)
 
-# Inspired by discussions on http://bugs.python.org/issue13585
-class ExitStack(object):
-    """Context manager for dynamic management of a stack of exit callbacks
 
-    For example:
+class _AsyncGeneratorContextManager(AbstractAsyncContextManager):
+    def __init__(self, func, args, kwds):
+        self.gen = func(*args, **kwds)
+        self.func, self.args, self.kwds = func, args, kwds
 
+    async def __aenter__(self):
+        try:
+            return await self.gen.__anext__()
+        except StopAsyncIteration:
+            raise RuntimeError("generator didn't yield") from None
 
-        with ExitStack() as stack:
-            files = [stack.enter_context(open(fname)) for fname in filenames]
-            # All opened files will automatically be closed at the end of
-            # the with statement, even if attempts to open files later
-            # in the list raise an exception
+    async def __aexit__(self, typ, value, traceback):
+        if typ is None:
+            try:
+                await self.gen.__anext__()
+            except StopAsyncIteration:
+                return
+            else:
+                raise RuntimeError("generator didn't stop")
+        else:
+            if value is None:
+                value = typ()
+            # See _GeneratorContextManager.__exit__ for comments on subtleties
+            # in this implementation
+            try:
+                await self.gen.athrow(typ, value, traceback)
+                raise RuntimeError("generator didn't stop after athrow()")
+            except StopAsyncIteration as exc:
+                return exc is not value
+            except RuntimeError as exc:
+                if exc is value:
+                    return False
+                # Avoid suppressing if a StopIteration exception
+                # was passed to throw() and later wrapped into a RuntimeError
+                # (see PEP 479 for sync generators; async generators also
+                # have this behavior). But do this only if the exception wrapped
+                # by the RuntimeError is actually Stop(Async)Iteration (see
+                # issue29692).
+                if isinstance(value, (StopIteration, StopAsyncIteration)):
+                    if exc.__cause__ is value:
+                        return False
+                raise
+            except BaseException as exc:
+                if exc is not value:
+                    raise
+
+
+def asynccontextmanager(func):
+    """@asynccontextmanager decorator.
+
+    Typical usage:
+
+        @asynccontextmanager
+        async def some_async_generator(<arguments>):
+            <setup>
+            try:
+                yield <value>
+            finally:
+                <cleanup>
+
+    This makes this:
+
+        async with some_async_generator(<arguments>) as <variable>:
+            <body>
+
+    equivalent to this:
+
+        <setup>
+        try:
+            <variable> = <value>
+            <body>
+        finally:
+            <cleanup>
+    """
+    @wraps(func)
+    def helper(*args, **kwds):
+        return _AsyncGeneratorContextManager(func, args, kwds)
+    return helper
+
+
+class _BaseExitStack:
+    """A base class for ExitStack and AsyncExitStack."""
+
+    @staticmethod
+    def _create_exit_wrapper(cm, cm_exit):
+        return MethodType(cm_exit, cm)
+
+    @staticmethod
+    def _create_cb_wrapper(callback, *args, **kwds):
+        def _exit_wrapper(exc_type, exc, tb):
+            callback(*args, **kwds)
+        return _exit_wrapper
+
     def __init__(self):
         self._exit_callbacks = deque()
 
     def pop_all(self):
-        """Preserve the context stack by transferring it to a new instance"""
+        """Preserve the context stack by transferring it to a new instance."""
         new_stack = type(self)()
         new_stack._exit_callbacks = self._exit_callbacks
         self._exit_callbacks = deque()
         return new_stack
 
-    def _push_cm_exit(self, cm, cm_exit):
-        """Helper to correctly register callbacks to __exit__ methods"""
-        def _exit_wrapper(*exc_details):
-            return cm_exit(cm, *exc_details)
-        self.push(_exit_wrapper)
-
     def push(self, exit):
-        """Registers a callback with the standard __exit__ method signature
-
-        Can suppress exceptions the same way __exit__ methods can.
+ """Registers a callback with the standard __exit__ method signature. + Can suppress exceptions the same way __exit__ method can. Also accepts any object with an __exit__ method (registering a call - to the method instead of the object itself) + to the method instead of the object itself). """ # We use an unbound method rather than a bound method to follow - # the standard lookup behaviour for special methods + # the standard lookup behaviour for special methods. _cb_type = type(exit) + try: exit_method = _cb_type.__exit__ except AttributeError: - # Not a context manager, so assume its a callable - self._exit_callbacks.append(exit) + # Not a context manager, so assume it's a callable. + self._push_exit_callback(exit) else: self._push_cm_exit(exit, exit_method) - return exit # Allow use as a decorator - - def callback(self, callback, *args, **kwds): - """Registers an arbitrary callback and arguments. - - Cannot suppress exceptions. - """ - def _exit_wrapper(exc_type, exc, tb): - callback(*args, **kwds) - self.push(_exit_wrapper) - return callback # Allow use as a decorator + return exit # Allow use as a decorator. def enter_context(self, cm): - """Enters the supplied context manager + """Enters the supplied context manager. If successful, also pushes its __exit__ method as a callback and returns the result of the __enter__ method. """ - # We look up the special methods on the type to match the with statement + # We look up the special methods on the type to match the with + # statement. _cm_type = type(cm) _exit = _cm_type.__exit__ result = _cm_type.__enter__(cm) self._push_cm_exit(cm, _exit) return result - def close(self): - """Immediately unwind the context stack""" - self.__exit__(None, None, None) + def callback(self, callback, *args, **kwds): + """Registers an arbitrary callback and arguments. + + Cannot suppress exceptions. + """ + _exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds) + + # We changed the signature, so using @wraps is not appropriate, but + # setting __wrapped__ may still help with introspection. + _exit_wrapper.__wrapped__ = callback + self._push_exit_callback(_exit_wrapper) + return callback # Allow use as a decorator + + def _push_cm_exit(self, cm, cm_exit): + """Helper to correctly register callbacks to __exit__ methods.""" + _exit_wrapper = self._create_exit_wrapper(cm, cm_exit) + self._push_exit_callback(_exit_wrapper, True) + + def _push_exit_callback(self, callback, is_sync=True): + self._exit_callbacks.append((is_sync, callback)) + +# Inspired by discussions on http://bugs.python.org/issue13585 +class ExitStack(_BaseExitStack, AbstractContextManager): + """Context manager for dynamic management of a stack of exit callbacks. + + For example: + with ExitStack() as stack: + files = [stack.enter_context(open(fname)) for fname in filenames] + # All opened files will automatically be closed at the end of + # the with statement, even if attempts to open files later + # in the list raise an exception. 
+ """ def __enter__(self): return self def __exit__(self, *exc_details): received_exc = exc_details[0] is not None + + # We manipulate the exception state so it behaves as though + # we were actually nesting multiple with statements + frame_exc = sys.exc_info()[1] + def _fix_exception_context(new_exc, old_exc): + # Context may not be correct, so find the end of the chain + while 1: + exc_context = new_exc.__context__ + if exc_context is old_exc: + # Context is already set correctly (see issue 20317) + return + if exc_context is None or exc_context is frame_exc: + break + new_exc = exc_context + # Change the end of the chain to point to the exception + # we expect it to reference + new_exc.__context__ = old_exc + # Callbacks are invoked in LIFO order to match the behaviour of # nested context managers suppressed_exc = False pending_raise = False while self._exit_callbacks: - cb = self._exit_callbacks.pop() + is_sync, cb = self._exit_callbacks.pop() + assert is_sync try: if cb(*exc_details): suppressed_exc = True pending_raise = False exc_details = (None, None, None) except: - exc_details = sys.exc_info() + new_exc_details = sys.exc_info() + # simulate the stack of exceptions by setting the context + _fix_exception_context(new_exc_details[1], exc_details[1]) pending_raise = True + exc_details = new_exc_details if pending_raise: - raise exc_details[1] + try: + # bare "raise exc_details[1]" replaces our carefully + # set-up context + fixed_ctx = exc_details[1].__context__ + raise exc_details[1] + except BaseException: + exc_details[1].__context__ = fixed_ctx + raise return received_exc and suppressed_exc + + def close(self): + """Immediately unwind the context stack.""" + self.__exit__(None, None, None) + +# Inspired by discussions on https://bugs.python.org/issue29302 +class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager): + """Async context manager for dynamic management of a stack of exit + callbacks. + + For example: + async with AsyncExitStack() as stack: + connections = [await stack.enter_async_context(get_connection()) + for i in range(5)] + # All opened connections will automatically be released at the + # end of the async with statement, even if attempts to open a + # connection later in the list raise an exception. + """ + + @staticmethod + def _create_async_exit_wrapper(cm, cm_exit): + return MethodType(cm_exit, cm) + + @staticmethod + def _create_async_cb_wrapper(callback, *args, **kwds): + async def _exit_wrapper(exc_type, exc, tb): + await callback(*args, **kwds) + return _exit_wrapper + + async def enter_async_context(self, cm): + """Enters the supplied async context manager. + + If successful, also pushes its __aexit__ method as a callback and + returns the result of the __aenter__ method. + """ + _cm_type = type(cm) + _exit = _cm_type.__aexit__ + result = await _cm_type.__aenter__(cm) + self._push_async_cm_exit(cm, _exit) + return result + + def push_async_exit(self, exit): + """Registers a coroutine function with the standard __aexit__ method + signature. + + Can suppress exceptions the same way __aexit__ method can. + Also accepts any object with an __aexit__ method (registering a call + to the method instead of the object itself). 
+ """ + _cb_type = type(exit) + try: + exit_method = _cb_type.__aexit__ + except AttributeError: + # Not an async context manager, so assume it's a coroutine function + self._push_exit_callback(exit, False) + else: + self._push_async_cm_exit(exit, exit_method) + return exit # Allow use as a decorator + + def push_async_callback(self, callback, *args, **kwds): + """Registers an arbitrary coroutine function and arguments. + + Cannot suppress exceptions. + """ + _exit_wrapper = self._create_async_cb_wrapper(callback, *args, **kwds) + + # We changed the signature, so using @wraps is not appropriate, but + # setting __wrapped__ may still help with introspection. + _exit_wrapper.__wrapped__ = callback + self._push_exit_callback(_exit_wrapper, False) + return callback # Allow use as a decorator + + async def aclose(self): + """Immediately unwind the context stack.""" + await self.__aexit__(None, None, None) + + def _push_async_cm_exit(self, cm, cm_exit): + """Helper to correctly register coroutine function to __aexit__ + method.""" + _exit_wrapper = self._create_async_exit_wrapper(cm, cm_exit) + self._push_exit_callback(_exit_wrapper, False) + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc_details): + received_exc = exc_details[0] is not None + + # We manipulate the exception state so it behaves as though + # we were actually nesting multiple with statements + frame_exc = sys.exc_info()[1] + def _fix_exception_context(new_exc, old_exc): + # Context may not be correct, so find the end of the chain + while 1: + exc_context = new_exc.__context__ + if exc_context is old_exc: + # Context is already set correctly (see issue 20317) + return + if exc_context is None or exc_context is frame_exc: + break + new_exc = exc_context + # Change the end of the chain to point to the exception + # we expect it to reference + new_exc.__context__ = old_exc + + # Callbacks are invoked in LIFO order to match the behaviour of + # nested context managers + suppressed_exc = False + pending_raise = False + while self._exit_callbacks: + is_sync, cb = self._exit_callbacks.pop() + try: + if is_sync: + cb_suppress = cb(*exc_details) + else: + cb_suppress = await cb(*exc_details) + + if cb_suppress: + suppressed_exc = True + pending_raise = False + exc_details = (None, None, None) + except: + new_exc_details = sys.exc_info() + # simulate the stack of exceptions by setting the context + _fix_exception_context(new_exc_details[1], exc_details[1]) + pending_raise = True + exc_details = new_exc_details + if pending_raise: + try: + # bare "raise exc_details[1]" replaces our carefully + # set-up context + fixed_ctx = exc_details[1].__context__ + raise exc_details[1] + except BaseException: + exc_details[1].__context__ = fixed_ctx + raise + return received_exc and suppressed_exc + +class nullcontext(AbstractContextManager): + """Context manager that does no additional processing. 
+ + Used as a stand-in for a normal context manager, when a particular + block of code is only sometimes used with a normal context manager: + + cm = optional_cm if condition else nullcontext() + with cm: + # Perform operation, using optional_cm if condition is True + """ + + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass + diff --git a/contextlib/metadata.txt b/contextlib/metadata.txt index 5f67e34825..7650e84981 100644 --- a/contextlib/metadata.txt +++ b/contextlib/metadata.txt @@ -1,4 +1,4 @@ srctype = cpython type = module -version = 3.4.2-5 +version = 3.5.0 depends = ucontextlib, collections From 4e555ccac41a01be5c05460464d514d6a924218c Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:17:04 +0200 Subject: [PATCH 32/65] Add sniffio --- sniffio/metadata.txt | 3 ++ sniffio/sniffio.py | 97 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 sniffio/metadata.txt create mode 100644 sniffio/sniffio.py diff --git a/sniffio/metadata.txt b/sniffio/metadata.txt new file mode 100644 index 0000000000..7a340ef67d --- /dev/null +++ b/sniffio/metadata.txt @@ -0,0 +1,3 @@ +srctype=async-ext +type=module +version = 0.0.2 diff --git a/sniffio/sniffio.py b/sniffio/sniffio.py new file mode 100644 index 0000000000..f036bb5818 --- /dev/null +++ b/sniffio/sniffio.py @@ -0,0 +1,97 @@ +import sys +from contextvars import ContextVar + +current_async_library_cvar = ContextVar( + "current_async_library_cvar", default=None +) + + +class AsyncLibraryNotFoundError(RuntimeError): + pass + + +def current_async_library(): + """Detect which async library is currently running. + + The following libraries are currently supported: + + ================ =========== ============================ + Library Requires Magic string + ================ =========== ============================ + **Trio** Trio v0.6+ ``"trio"`` + **Curio** - ``"curio"`` + **asyncio** ``"asyncio"`` + **uasyncio** MicroPython ``"uasyncio"`` + **Trio-asyncio** v0.8.2+ ``"trio"`` or ``"asyncio"``, + depending on current mode + ================ =========== ============================ + + Returns: + A string like ``"trio"``. + + Raises: + AsyncLibraryNotFoundError: if called from synchronous context, + or if the current async library was not recognized. + + Examples: + + .. code-block:: python3 + + from sniffio import current_async_library + + async def generic_sleep(seconds): + library = current_async_library() + if library == "trio": + import trio + await trio.sleep(seconds) + elif library == "asyncio": + import asyncio + await asyncio.sleep(seconds) + # ... and so on ... 
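+            elif library == "uasyncio":
+                # illustrative sketch: assumes the running port's uasyncio
+                # provides an asyncio-compatible sleep()
+                import uasyncio
+                await uasyncio.sleep(seconds)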
+            else:
+                raise RuntimeError(f"Unsupported library {library!r}")
+
+    """
+    value = current_async_library_cvar.get()
+    if value is not None:
+        return value
+
+    # Sniff for curio (for now)
+    if 'curio' in sys.modules:
+        from curio.meta import curio_running
+        if curio_running():
+            return 'curio'
+
+    # Need to sniff for asyncio
+    if "asyncio" in sys.modules:
+        import asyncio
+        try:
+            current_task = asyncio.current_task
+        except AttributeError:
+            current_task = asyncio.Task.current_task
+        try:
+            if current_task() is not None:
+                if (3, 7) <= sys.version_info:
+                    # asyncio has contextvars support, and we're in a task, so
+                    # we can safely cache the sniffed value
+                    current_async_library_cvar.set("asyncio")
+                return "asyncio"
+        except RuntimeError:
+            pass
+
+    # Need to sniff for uasyncio
+    if "uasyncio" in sys.modules:
+        import uasyncio
+        current_task = uasyncio.get_event_loop().cur_task
+        if current_task is not None:
+            return "uasyncio"
+
+    raise AsyncLibraryNotFoundError(
+        "unknown async library, or not in async context"
+    )
+
+__all__ = [
+    "current_async_library", "AsyncLibraryNotFoundError",
+    "current_async_library_cvar"
+]

From 8210ce5cb04cdb3398eae8e16bfb3a9f6e97f156 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 10 May 2021 17:18:01 +0200
Subject: [PATCH 33/65] add default_factory to attrs stub

---
 attrs/attr.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/attrs/attr.py b/attrs/attr.py
index 34d845b0c5..4d08c1caff 100644
--- a/attrs/attr.py
+++ b/attrs/attr.py
@@ -39,6 +39,8 @@ def _init(self, **kw):
         if isinstance(v,attrib):
             if hasattr(v,'factory'):
                 v = v.factory()
+            elif hasattr(v,'default_factory'):
+                v = v.default_factory()
             else:
                 v = v.default
         if isinstance(v,Factory):

From ef19f94b7a4660f9d90e68bef5a04c493cebe0e7 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 10 May 2021 17:18:15 +0200
Subject: [PATCH 34/65] Add a stub "dataclasses" which simply uses attrs, for now

---
 dataclasses/dataclasses.py | 4 ++++
 dataclasses/metadata.txt   | 3 +++
 2 files changed, 7 insertions(+)
 create mode 100644 dataclasses/dataclasses.py
 create mode 100644 dataclasses/metadata.txt

diff --git a/dataclasses/dataclasses.py b/dataclasses/dataclasses.py
new file mode 100644
index 0000000000..8faeaee165
--- /dev/null
+++ b/dataclasses/dataclasses.py
@@ -0,0 +1,4 @@
+import attr
+
+dataclass = attr.s
+field = attr.ib

diff --git a/dataclasses/metadata.txt b/dataclasses/metadata.txt
new file mode 100644
index 0000000000..fda992a9c0
--- /dev/null
+++ b/dataclasses/metadata.txt
@@ -0,0 +1,3 @@
+srctype=dummy
+type=module
+version = 0.0.2

From 5e14a371b79d4117b1f26b3b1aa3e025eca47612 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 10 May 2021 17:19:18 +0200
Subject: [PATCH 35/65] defaultdict: add len() and bool()

---
 collections.defaultdict/collections/defaultdict.py | 6 ++++++
 collections.defaultdict/metadata.txt                | 2 +-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/collections.defaultdict/collections/defaultdict.py b/collections.defaultdict/collections/defaultdict.py
index a6ab3e2c08..12dab1e4cd 100644
--- a/collections.defaultdict/collections/defaultdict.py
+++ b/collections.defaultdict/collections/defaultdict.py
@@ -35,6 +35,12 @@ def __missing__(self, key):
             raise KeyError(key)
         return self.default_factory()
 
+    def __len__(self):
+        return len(self.d)
+
+    def __bool__(self):
+        return bool(self.d)
+
     def keys(self):
         return self.d.keys()
 
diff --git a/collections.defaultdict/metadata.txt b/collections.defaultdict/metadata.txt
index
a7f9ce672b..dc3f3df418 100644 --- a/collections.defaultdict/metadata.txt +++ b/collections.defaultdict/metadata.txt @@ -1,4 +1,4 @@ srctype = pycopy-lib type = package -version = 0.3.2 +version = 0.3.3 author = Paul Sokolovsky From e0aaf128915ca43684e55ea07237705fee61ab0d Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:20:43 +0200 Subject: [PATCH 36/65] Workaround: attributes might not be modifiable --- cpython-pycopy/metadata.txt | 2 +- cpython-pycopy/pycopy.py | 7 ++- cpython-pycopy/pycopy_imphook.py | 87 +++++++++++++++++--------------- 3 files changed, 52 insertions(+), 44 deletions(-) diff --git a/cpython-pycopy/metadata.txt b/cpython-pycopy/metadata.txt index 09e751fff3..259fc719c0 100644 --- a/cpython-pycopy/metadata.txt +++ b/cpython-pycopy/metadata.txt @@ -1,5 +1,5 @@ srctype = cpython-backport type = module -version = 0.4.1 +version = 0.4.2 depends = cpython-uio extra_modules = pycopy_imphook diff --git a/cpython-pycopy/pycopy.py b/cpython-pycopy/pycopy.py index b00082f9b2..36c32ab170 100644 --- a/cpython-pycopy/pycopy.py +++ b/cpython-pycopy/pycopy.py @@ -24,8 +24,11 @@ def mem_free(): def mem_alloc(): return 1000000 -gc.mem_free = mem_free -gc.mem_alloc = mem_alloc +try: + gc.mem_free = mem_free + gc.mem_alloc = mem_alloc +except AttributeError: + pass import pycopy_imphook diff --git a/cpython-pycopy/pycopy_imphook.py b/cpython-pycopy/pycopy_imphook.py index c673dbce84..14af5ed066 100644 --- a/cpython-pycopy/pycopy_imphook.py +++ b/cpython-pycopy/pycopy_imphook.py @@ -8,46 +8,51 @@ _import_exts = [] -class ImphookFileLoader(importlib._bootstrap_external.FileLoader): +try: + class ImphookFileLoader(importlib._bootstrap_external.FileLoader): - def create_module(self, spec): + def create_module(self, spec): + global _import_hook + #print("create_module", spec) + basename = spec.origin.rsplit(".", 1)[0] + m = _import_hook(spec.name, basename) + return m + + def exec_module(self, mod): + # Module is fully populated in create_module + pass + +except AttributeError: + pass + +else: + + def setimphook(hook, exts): global _import_hook - #print("create_module", spec) - basename = spec.origin.rsplit(".", 1)[0] - m = _import_hook(spec.name, basename) - return m - - def exec_module(self, mod): - # Module is fully populated in create_module - pass - - -def setimphook(hook, exts): - global _import_hook - old_hook = _import_hook - _import_hook = hook - _import_exts.extend(exts) - - for i, el in enumerate(sys.path_hooks): - if not isinstance(el, type): - # Assume it's a type wrapped in a closure, - # as is the case for FileFinder. - el = type(el(".")) - if el is importlib._bootstrap_external.FileFinder: - sys.path_hooks.pop(i) - insert_pos = i - break - else: - warnings.warn("Could not find existing FileFinder to replace, installing ours as the first to use") - insert_pos = 0 - - # Mirrors what's done by importlib._bootstrap_external._install(importlib._bootstrap) - loaders = [(ImphookFileLoader, _import_exts)] + importlib._bootstrap_external._get_supported_file_loaders() - # path_hook closure captures supported_loaders in itself, all instances - # of FileFinder class will be created with it. 
-        sys.path_hooks.insert(insert_pos, importlib._bootstrap_external.FileFinder.path_hook(*loaders))
-    sys.path_importer_cache.clear()
-    return old_hook
-
-
-sys.setimphook = setimphook
+        old_hook = _import_hook
+        _import_hook = hook
+        _import_exts.extend(exts)
+
+        for i, el in enumerate(sys.path_hooks):
+            if not isinstance(el, type):
+                # Assume it's a type wrapped in a closure,
+                # as is the case for FileFinder.
+                el = type(el("."))
+            if el is importlib._bootstrap_external.FileFinder:
+                sys.path_hooks.pop(i)
+                insert_pos = i
+                break
+        else:
+            warnings.warn("Could not find existing FileFinder to replace, installing ours as the first to use")
+            insert_pos = 0
+
+        # Mirrors what's done by importlib._bootstrap_external._install(importlib._bootstrap)
+        loaders = [(ImphookFileLoader, _import_exts)] + importlib._bootstrap_external._get_supported_file_loaders()
+        # path_hook closure captures supported_loaders in itself, all instances
+        # of FileFinder class will be created with it.
+        sys.path_hooks.insert(insert_pos, importlib._bootstrap_external.FileFinder.path_hook(*loaders))
+        sys.path_importer_cache.clear()
+        return old_hook
+
+
+    sys.setimphook = setimphook

From 9a9e5294fd590dc04597d402b76908ef6df28391 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 10 May 2021 17:27:08 +0200
Subject: [PATCH 37/65] add ELOOP for compatibility

Never raised, since MicroPython doesn't have symlinks.
---
 errno/errno.py     | 1 +
 errno/metadata.txt | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/errno/errno.py b/errno/errno.py
index 7b7935ef8b..fa2bbca187 100644
--- a/errno/errno.py
+++ b/errno/errno.py
@@ -32,6 +32,7 @@
 EPIPE = 32 # Broken pipe
 EDOM = 33 # Math argument out of domain of func
 ERANGE = 34 # Math result not representable
+ELOOP = 40 # Too many symbolic links encountered
 EAFNOSUPPORT = 97 # Address family not supported by protocol
 ECONNRESET = 104 # Connection timed out
 ETIMEDOUT = 110 # Connection timed out

diff --git a/errno/metadata.txt b/errno/metadata.txt
index 5fe74e5004..b158d00268 100644
--- a/errno/metadata.txt
+++ b/errno/metadata.txt
@@ -1,3 +1,3 @@
 srctype = pycopy-lib
 type = module
-version = 0.1.4
+version = 0.1.5

From afaae75e3ff1561631fcd9c5c345d18f77f3b07c Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 10 May 2021 17:28:26 +0200
Subject: [PATCH 38/65] minimal implementation of "import_module"

---
 importlib/importlib.py | 30 ++++++++++++++++++++++++++++++
 importlib/metadata.txt |  2 +-
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/importlib/importlib.py b/importlib/importlib.py
index e69de29bb2..c7330f422b 100644
--- a/importlib/importlib.py
+++ b/importlib/importlib.py
@@ -0,0 +1,30 @@
+import imp
+
+def import_module(name, package=None):
+    """Import a module.
+
+    The 'package' argument is required when performing a relative import. It
+    specifies the package to use as the anchor point from which to resolve the
+    relative import to an absolute import.
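+
+    For example (illustrative; assumes a 'collections.abc' package is
+    importable):
+
+        import_module('collections.abc')      # absolute import
+        import_module('.abc', 'collections')  # relative import, same result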
+ + """ + level = 0 + if name.startswith('.'): + if not package: + msg = ("the 'package' argument is required to perform a relative " + "import for {!r}") + raise TypeError(msg.format(name)) + for character in name: + if character != '.': + break + level += 1 + if level > 1: + name = ".".join(package.split('.')[:1-level]) + name[level-1:] + else: + name = package + name + + res = imp.org_import(name) + for n in name.split(".")[1:]: + res = getattr(res,n) + return res + diff --git a/importlib/metadata.txt b/importlib/metadata.txt index 976088c8a2..80dfba7d94 100644 --- a/importlib/metadata.txt +++ b/importlib/metadata.txt @@ -1,3 +1,3 @@ srctype = dummy type = module -version = 0.0.0 +version = 0.1.0 From 1fdb4ea8f15fa51e39f6a0a7a1cb1ea98685da4c Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:29:02 +0200 Subject: [PATCH 39/65] Alias iscoroutine[function] to isgenerator* --- inspect/inspect.py | 3 +++ inspect/metadata.txt | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/inspect/inspect.py b/inspect/inspect.py index b4a3ab8b5e..cced23288f 100644 --- a/inspect/inspect.py +++ b/inspect/inspect.py @@ -19,6 +19,9 @@ def isgeneratorfunction(obj): def isgenerator(obj): return isinstance(obj, type(lambda:(yield)())) +iscoroutinefunction = isgeneratorfunction +iscoroutine = isgenerator + class _Class: def meth(): pass _Instance = _Class() diff --git a/inspect/metadata.txt b/inspect/metadata.txt index e664e0316c..3bbc290d63 100644 --- a/inspect/metadata.txt +++ b/inspect/metadata.txt @@ -1,3 +1,3 @@ srctype = pycopy-lib type = module -version = 0.2.4 +version = 0.2.5 From 17c980aacb7f83f707219972db51ff356895649e Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:30:53 +0200 Subject: [PATCH 40/65] use ipaddress module from cpython --- ipaddress/ipaddress.py | 2314 ++++++++++++++++++++++++++++++++++++++++ ipaddress/metadata.txt | 4 +- 2 files changed, 2316 insertions(+), 2 deletions(-) diff --git a/ipaddress/ipaddress.py b/ipaddress/ipaddress.py index e69de29bb2..0aee7af2a4 100644 --- a/ipaddress/ipaddress.py +++ b/ipaddress/ipaddress.py @@ -0,0 +1,2314 @@ +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +__version__ = '1.0' + + +import functools + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def ip_address(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the *address* passed isn't either a v4 or a v6 + address + + """ + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 address' % + address) + + +def ip_network(address, strict=True): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP network. 
Either IPv4 or + IPv6 networks may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if the network has host bits set. + + """ + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 network' % + address) + + +def ip_interface(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Interface or IPv6Interface object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + Notes: + The IPv?Interface classes describe an Address on a particular + Network, so they're basically a combination of both the Address + and Network classes. + + """ + try: + return IPv4Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 interface' % + address) + + +def v4_int_to_packed(address): + """Represent an address as 4 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The integer address packed as 4 bytes in network (big-endian) order. + + Raises: + ValueError: If the integer is negative or too large to be an + IPv4 IP address. + + """ + try: + return address.to_bytes(4, 'big') + except OverflowError: + raise ValueError("Address negative or too large for IPv4") + + +def v6_int_to_packed(address): + """Represent an address as 16 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv6 IP address. + + Returns: + The integer address packed as 16 bytes in network (big-endian) order. + + """ + try: + return address.to_bytes(16, 'big') + except OverflowError: + raise ValueError("Address negative or too large for IPv6") + + +def _split_optional_netmask(address): + """Helper to split the netmask and raise AddressValueError if needed""" + addr = str(address).split('/') + if len(addr) > 2: + raise AddressValueError("Only one '/' permitted in %r" % address) + return addr + + +def _find_address_range(addresses): + """Find a sequence of sorted deduplicated IPv#Address. + + Args: + addresses: a list of IPv#Address objects. + + Yields: + A tuple containing the first and last IP addresses in the sequence. + + """ + it = iter(addresses) + first = last = next(it) + for ip in it: + if ip._ip != last._ip + 1: + yield first, last + first = ip + last = ip + yield first, last + + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + return min(bits, (~number & (number-1)).bit_length()) + + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. 
+ + Example: + >>> list(summarize_address_range(IPv4Address('192.0.2.0'), + ... IPv4Address('192.0.2.130'))) + ... #doctest: +NORMALIZE_WHITESPACE + [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), + IPv4Network('192.0.2.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + An iterator of the summarized IPv(4|6) network objects. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version of the first address is not 4 or 6. + + """ + if (not (isinstance(first, _BaseAddress) and + isinstance(last, _BaseAddress))): + raise TypeError('first and last must be IP addresses, not networks') + if first.version != last.version: + raise TypeError("%s and %s are not of the same version" % ( + first, last)) + if first > last: + raise ValueError('last IP address must be greater than first') + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError('unknown IP version') + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = min(_count_righthand_zero_bits(first_int, ip_bits), + (last_int - first_int + 1).bit_length() - 1) + net = ip((first_int, ip_bits - nbits)) + yield net + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: + break + + +def _collapse_addresses_internal(addresses): + """Loops through the addresses, collapsing concurrent netblocks. + + Example: + + ip1 = IPv4Network('192.0.2.0/26') + ip2 = IPv4Network('192.0.2.64/26') + ip3 = IPv4Network('192.0.2.128/26') + ip4 = IPv4Network('192.0.2.192/26') + + _collapse_addresses_internal([ip1, ip2, ip3, ip4]) -> + [IPv4Network('192.0.2.0/24')] + + This shouldn't be called directly; it is called via + collapse_addresses([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + # First merge + to_merge = list(addresses) + subnets = {} + while to_merge: + net = to_merge.pop() + supernet = net.supernet() + existing = subnets.get(supernet) + if existing is None: + subnets[supernet] = net + elif existing != net: + # Merge consecutive subnets + del subnets[supernet] + to_merge.append(supernet) + # Then iterate over resulting networks, skipping subsumed subnets + last = None + for net in sorted(subnets.values()): + if last is not None: + # Since they are sorted, last.network_address <= net.network_address + # is a given. + if last.broadcast_address >= net.broadcast_address: + continue + yield net + last = net + + +def collapse_addresses(addresses): + """Collapse a list of IP objects. + + Example: + collapse_addresses([IPv4Network('192.0.2.0/25'), + IPv4Network('192.0.2.128/25')]) -> + [IPv4Network('192.0.2.0/24')] + + Args: + addresses: An iterator of IPv4Network or IPv6Network objects. + + Returns: + An iterator of the collapsed IPv(4|6)Network objects. + + Raises: + TypeError: If passed a list of mixed version objects. 
+ + """ + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseAddress): + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + try: + ips.append(ip.ip) + except AttributeError: + ips.append(ip.network_address) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, nets[-1])) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + + # find consecutive address ranges in the sorted sequence and summarize them + if ips: + for first, last in _find_address_range(ips): + addrs.extend(summarize_address_range(first, last)) + + return _collapse_addresses_internal(addrs + nets) + + +def get_mixed_type_key(obj): + """Return a key suitable for sorting between networks and addresses. + + Address and Network objects are not sortable by default; they're + fundamentally different so the expression + + IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') + + doesn't make any sense. There are some times however, where you may wish + to have ipaddress sort these for you anyway. If you need to do this, you + can use this function as the key= argument to sorted(). + + Args: + obj: either a Network or Address object. + Returns: + appropriate key. + + """ + if isinstance(obj, _BaseNetwork): + return obj._get_networks_key() + elif isinstance(obj, _BaseAddress): + return obj._get_address_key() + return NotImplemented + + +class _IPAddressBase: + + """The mother class.""" + + __slots__ = () + + @property + def exploded(self): + """Return the longhand version of the IP address as a string.""" + return self._explode_shorthand_ip_string() + + @property + def compressed(self): + """Return the shorthand version of the IP address as a string.""" + return str(self) + + @property + def reverse_pointer(self): + """The name of the reverse DNS pointer for the IP address, e.g.: + >>> ipaddress.ip_address("127.0.0.1").reverse_pointer + '1.0.0.127.in-addr.arpa' + >>> ipaddress.ip_address("2001:db8::1").reverse_pointer + '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa' + + """ + return self._reverse_pointer() + + @property + def version(self): + msg = '%200s has no version specified' % (type(self),) + raise NotImplementedError(msg) + + def _check_int_address(self, address): + if address < 0: + msg = "%d (< 0) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._version)) + if address > self._ALL_ONES: + msg = "%d (>= 2**%d) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._max_prefixlen, + self._version)) + + def _check_packed_address(self, address, expected_len): + address_len = len(address) + if address_len != expected_len: + msg = "%r (len %d != %d) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, address_len, + expected_len, self._version)) + + @classmethod + def _ip_int_from_prefix(cls, prefixlen): + """Turn the prefix length into a bitwise netmask + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. + + """ + return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen) + + @classmethod + def _prefix_from_ip_int(cls, ip_int): + """Return prefix length from the bitwise netmask. 
+ + Args: + ip_int: An integer, the netmask in expanded bitwise format + + Returns: + An integer, the prefix length. + + Raises: + ValueError: If the input intermingles zeroes & ones + """ + trailing_zeroes = _count_righthand_zero_bits(ip_int, + cls._max_prefixlen) + prefixlen = cls._max_prefixlen - trailing_zeroes + leading_ones = ip_int >> trailing_zeroes + all_ones = (1 << prefixlen) - 1 + if leading_ones != all_ones: + byteslen = cls._max_prefixlen // 8 + details = ip_int.to_bytes(byteslen, 'big') + msg = 'Netmask pattern %r mixes zeroes & ones' + raise ValueError(msg % details) + return prefixlen + + @classmethod + def _report_invalid_netmask(cls, netmask_str): + msg = '%r is not a valid netmask' % netmask_str + raise NetmaskValueError(msg) from None + + @classmethod + def _prefix_from_prefix_string(cls, prefixlen_str): + """Return prefix length from a numeric string + + Args: + prefixlen_str: The string to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask + """ + # int allows a leading +/- as well as surrounding whitespace, + # so we ensure that isn't the case + for c in prefixlen_str: + if not ("0" <= c <= "9"): + cls._report_invalid_netmask(prefixlen_str) + try: + prefixlen = int(prefixlen_str) + except ValueError: + cls._report_invalid_netmask(prefixlen_str) + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen_str) + return prefixlen + + @classmethod + def _prefix_from_ip_string(cls, ip_str): + """Turn a netmask/hostmask string into a prefix length + + Args: + ip_str: The netmask/hostmask to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask/hostmask + """ + # Parse the netmask/hostmask like an IP address. + try: + ip_int = cls._ip_int_from_string(ip_str) + except AddressValueError: + cls._report_invalid_netmask(ip_str) + + # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). + # Note that the two ambiguous cases (all-ones and all-zeroes) are + # treated as netmasks. + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + pass + + # Invert the bits, and try matching a /0+1+/ hostmask instead. + ip_int ^= cls._ALL_ONES + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + cls._report_invalid_netmask(ip_str) + + @classmethod + def _split_addr_prefix(cls, address): + """Helper function to parse address of Network/Interface. + + Arg: + address: Argument of Network/Interface. + + Returns: + (addr, prefix) tuple. + """ + # a packed address or integer + if isinstance(address, (bytes, int)): + return address, cls._max_prefixlen + + if not isinstance(address, tuple): + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + address = _split_optional_netmask(address) + + # Constructing from a tuple (addr, [mask]) + if len(address) > 1: + return address + return address[0], cls._max_prefixlen + + def __reduce__(self): + return self.__class__, (str(self),) + + +_address_fmt_re = None + +class _BaseAddress(_IPAddressBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by single IP addresses. 
+ """ + + __slots__ = () + + def __int__(self): + return self._ip + + def __eq__(self, other): + return (self._ip == other._ip + and self._version == other._version) + + def __ne__(self, other): + return (self._ip != other._ip + or self._version != other._version) + + def __lt__(self, other): + if self._ip != other._ip: + return self._ip < other._ip + return False + + def __gt__(self, other): + if self._ip != other._ip: + return self._ip > other._ip + return False + + def __le__(self, other): + if self._ip == other._ip: + return True + return self._ip < other._ip + + def __ge__(self, other): + if self._ip == other._ip: + return True + return self._ip > other._ip + + + # Shorthand for Integer addition and subtraction. This is not + # meant to ever support addition/subtraction of addresses. + def __add__(self, other): + if not isinstance(other, int): + return NotImplemented + return self.__class__(int(self) + other) + + def __sub__(self, other): + if not isinstance(other, int): + return NotImplemented + return self.__class__(int(self) - other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, str(self)) + + def __str__(self): + return str(self._string_from_ip_int(self._ip)) + + def __hash__(self): + return hash(hex(int(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + def __reduce__(self): + return self.__class__, (self._ip,) + + def __format__(self, fmt): + """Returns an IP address as a formatted string. + + Supported presentation types are: + 's': returns the IP address as a string (default) + 'b': converts to binary and returns a zero-padded string + 'X' or 'x': converts to upper- or lower-case hex and returns a zero-padded string + 'n': the same as 'b' for IPv4 and 'x' for IPv6 + + For binary and hex presentation types, the alternate form specifier + '#' and the grouping option '_' are supported. + """ + + # Support string formatting + if not fmt or fmt[-1] == 's': + return format(str(self), fmt) + + # From here on down, support for 'bnXx' + global _address_fmt_re + if _address_fmt_re is None: + import re + _address_fmt_re = re.compile('(#?)(_?)([xbnX])') + + m = _address_fmt_re.fullmatch(fmt) + if not m: + return super().__format__(fmt) + + alternate, grouping, fmt_base = m.groups() + + # Set some defaults + if fmt_base == 'n': + if self._version == 4: + fmt_base = 'b' # Binary is default for ipv4 + else: + fmt_base = 'x' # Hex is default for ipv6 + + if fmt_base == 'b': + padlen = self._max_prefixlen + else: + padlen = self._max_prefixlen // 4 + + if grouping: + padlen += padlen // 4 - 1 + + if alternate: + padlen += 2 # 0b or 0x + + return format(int(self), "%s0%s%s%s" % (alternate, padlen, grouping, fmt_base)) + + +class _BaseNetwork(_IPAddressBase): + """A generic IP network object. + + This IP class contains the version independent methods which are + used by networks. + """ + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, str(self)) + + def __str__(self): + return '%s/%d' % (self.network_address, self.prefixlen) + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. 
+ + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in range(network + 1, broadcast): + yield self._address_class(x) + + def __iter__(self): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in range(network, broadcast + 1): + yield self._address_class(x) + + def __getitem__(self, n): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + if n >= 0: + if network + n > broadcast: + raise IndexError('address out of range') + return self._address_class(network + n) + else: + n += 1 + if broadcast + n < network: + raise IndexError('address out of range') + return self._address_class(broadcast + n) + + def _vcheck(self, other): + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + + def __lt__(self, other): + self._vcheck(other) + if self.network_address != other.network_address: + return self.network_address < other.network_address + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __gt__(self, other): + self._vcheck(other) + if self.network_address != other.network_address: + return self.network_address > other.network_address + if self.netmask != other.netmask: + return self.netmask > other.netmask + return False + + def __le__(self, other): + self._vcheck(other) + if self.network_address == other.network_address: + if self.netmask == other.netmask: + return True + return self.netmask < other.netmask + return self.network_address < other.network_address + + def __ge__(self, other): + self._vcheck(other) + if self.network_address == other.network_address: + if self.netmask == other.netmask: + return True + return self.netmask > other.netmask + return self.network_address > other.network_address + + def __eq__(self, other): + return (self._version == other._version and + self.network_address == other.network_address and + int(self.netmask) == int(other.netmask)) + + def __ne__(self, other): + return (self._version != other._version or + self.network_address != other.network_address or + int(self.netmask) != int(other.netmask)) + + def __hash__(self): + return hash(int(self.network_address) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. + if self._version != other._version: + return False + # dealing with another network. 
+ if isinstance(other, _BaseNetwork): + return False + # dealing with another address + else: + # address + return other._ip & self.netmask._ip == self.network_address._ip + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network_address in other or ( + self.broadcast_address in other or ( + other.network_address in self or ( + other.broadcast_address in self))) + + @property + def broadcast_address(self): + return self._address_class(int(self.network_address) | + int(self.hostmask)) + + @property + def hostmask(self): + return self._address_class(int(self.netmask) ^ self._ALL_ONES) + + @property + def with_prefixlen(self): + return '%s/%d' % (self.network_address, self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self.network_address, self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self.network_address, self.hostmask) + + @property + def num_addresses(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast_address) - int(self.network_address) + 1 + + @property + def _address_class(self): + # Returning bare address objects (rather than interfaces) allows for + # more consistent behaviour across the network address, broadcast + # address and individual host addresses. + msg = '%200s has no associated address class' % (type(self),) + raise NotImplementedError(msg) + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = ip_network('192.0.2.0/28') + addr2 = ip_network('192.0.2.1/32') + list(addr1.address_exclude(addr2)) = + [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), + IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] + + or IPv6: + + addr1 = ip_network('2001:db8::1/32') + addr2 = ip_network('2001:db8::1/128') + list(addr1.address_exclude(addr2)) = + [ip_network('2001:db8::1/128'), + ip_network('2001:db8::2/127'), + ip_network('2001:db8::4/126'), + ip_network('2001:db8::8/125'), + ... + ip_network('2001:db8:8000::/33')] + + Args: + other: An IPv4Network or IPv6Network object of the same type. + + Returns: + An iterator of the IPv(4|6)Network objects which is self + minus other. + + Raises: + TypeError: If self and other are of differing address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError("%s and %s are not of the same version" % ( + self, other)) + + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + + if not other.subnet_of(self): + raise ValueError('%s not contained in %s' % (other, self)) + if other == self: + return + + # Make sure we're comparing the network of other. + other = other.__class__('%s/%s' % (other.network_address, + other.prefixlen)) + + s1, s2 = self.subnets() + while s1 != other and s2 != other: + if other.subnet_of(s1): + yield s2 + s1, s2 = s1.subnets() + elif other.subnet_of(s2): + yield s1 + s1, s2 = s2.subnets() + else: + # If we got here, there's a bug somewhere. + raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + if s1 == other: + yield s2 + elif s2 == other: + yield s1 + else: + # If we got here, there's a bug somewhere. 
+ raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') + IPv6Network('2001:db8::1000/124') < + IPv6Network('2001:db8::2000/124') + 0 if self == other + eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') + IPv6Network('2001:db8::1000/124') == + IPv6Network('2001:db8::1000/124') + 1 if self > other + eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') + IPv6Network('2001:db8::2000/124') > + IPv6Network('2001:db8::1000/124') + + Raises: + TypeError if the IP versions are different. + + """ + # does this need to raise a ValueError? + if self._version != other._version: + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + # self._version == other._version below here: + if self.network_address < other.network_address: + return -1 + if self.network_address > other.network_address: + return 1 + # self.network_address == other.network_address below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network_address, self.netmask) + + def subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), yield an iterator with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. + This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. 
+ OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError('new prefix must be longer') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError('prefix length diff must be > 0') + new_prefixlen = self._prefixlen + prefixlen_diff + + if new_prefixlen > self._max_prefixlen: + raise ValueError( + 'prefix length diff %d is invalid for netblock %s' % ( + new_prefixlen, self)) + + start = int(self.network_address) + end = int(self.broadcast_address) + 1 + step = (int(self.hostmask) + 1) >> prefixlen_diff + for new_addr in range(start, end, step): + current = self.__class__((new_addr, new_prefixlen)) + yield current + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have + a negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError('new prefix must be shorter') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = self._prefixlen - new_prefix + + new_prefixlen = self.prefixlen - prefixlen_diff + if new_prefixlen < 0: + raise ValueError( + 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % + (self.prefixlen, prefixlen_diff)) + return self.__class__(( + int(self.network_address) & (int(self.netmask) << prefixlen_diff), + new_prefixlen + )) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return (self.network_address.is_multicast and + self.broadcast_address.is_multicast) + + @staticmethod + def _is_subnet_of(a, b): + try: + # Always false if one is v4 and the other is v6. + if a._version != b._version: + raise TypeError("%s and %s are not of the same version" % (a,b)) + return (b.network_address <= a.network_address and + b.broadcast_address >= a.broadcast_address) + except AttributeError: + raise TypeError("Unable to test subnet containment between %s and %s" % (a,b)) + + def subnet_of(self, other): + """Return True if this network is a subnet of other.""" + return self._is_subnet_of(self, other) + + def supernet_of(self, other): + """Return True if this network is a supernet of other.""" + return self._is_subnet_of(other, self) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. 
+ + """ + return (self.network_address.is_reserved and + self.broadcast_address.is_reserved) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return (self.network_address.is_link_local and + self.broadcast_address.is_link_local) + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return (self.network_address.is_private and + self.broadcast_address.is_private) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return (self.network_address.is_unspecified and + self.broadcast_address.is_unspecified) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return (self.network_address.is_loopback and + self.broadcast_address.is_loopback) + +class _BaseV4: + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 4 + # Equivalent to 255.255.255.255 or 32 bits of 1's. + _ALL_ONES = (2**IPV4LENGTH) - 1 + + _max_prefixlen = IPV4LENGTH + # There are only a handful of valid v4 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + def _explode_shorthand_ip_string(self): + return str(self) + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, int): + prefixlen = arg + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen) + else: + try: + # Check for a netmask in prefix length form + prefixlen = cls._prefix_from_prefix_string(arg) + except NetmaskValueError: + # Check for a netmask or hostmask in dotted-quad form. + # This may raise NetmaskValueError. + prefixlen = cls._prefix_from_ip_string(arg) + netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if ip_str isn't a valid IPv4 Address. 
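+
+        For example, _ip_int_from_string('192.0.2.1') returns 3221225985
+        (0xC0000201).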
+ + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + octets = ip_str.split('.') + if len(octets) != 4: + raise AddressValueError("Expected 4 octets in %r" % ip_str) + + try: + return int.from_bytes(bytes(map(cls._parse_octet, octets)), 'big') + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) from None + + @classmethod + def _parse_octet(cls, octet_str): + """Convert a decimal octet into an integer. + + Args: + octet_str: A string, the number to parse. + + Returns: + The octet as an integer. + + Raises: + ValueError: if the octet isn't strictly a decimal from [0..255]. + + """ + if not octet_str: + raise ValueError("Empty octet not permitted") + # Whitelist the characters, since int() allows a lot of bizarre stuff. + for c in octet_str: + if not ("0" <= c <= "9"): + msg = "Only decimal digits permitted in %r" + raise ValueError(msg % octet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(octet_str) > 3: + msg = "At most 3 characters permitted in %r" + raise ValueError(msg % octet_str) + # Convert to integer (we know digits are legal) + octet_int = int(octet_str, 10) + if octet_int > 255: + raise ValueError("Octet %d (> 255) not permitted" % octet_int) + return octet_int + + @classmethod + def _string_from_ip_int(cls, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + return '.'.join(map(str, ip_int.to_bytes(4, 'big'))) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv4 address. + + This implements the method described in RFC1035 3.5. + + """ + reverse_octets = str(self).split('.')[::-1] + return '.'.join(reverse_octets) + '.in-addr.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv4Address(_BaseV4, _BaseAddress): + + """Represent and manipulate single IPv4 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv4Address('192.0.2.1') == IPv4Address(3221225985). + or, more generally + IPv4Address(int(IPv4Address('192.0.2.1'))) == + IPv4Address('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + + """ + # Efficient constructor from integer. + if isinstance(address, int): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 4) + self._ip = int.from_bytes(address, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. 
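+
+        For example, IPv4Address('240.0.0.1').is_reserved is True, since
+        240.0.0.0/4 is the reserved range.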
+
+        """
+        return self in self._constants._reserved_network
+
+    @property
+    def is_private(self):
+        """Test if this address is allocated for private networks.
+
+        Returns:
+            A boolean, True if the address is reserved per
+            iana-ipv4-special-registry.
+
+        """
+        return any(self in net for net in self._constants._private_networks)
+
+    @property
+    def is_global(self):
+        return self not in self._constants._public_network and not self.is_private
+
+    @property
+    def is_multicast(self):
+        """Test if the address is reserved for multicast use.
+
+        Returns:
+            A boolean, True if the address is multicast.
+            See RFC 3171 for details.
+
+        """
+        return self in self._constants._multicast_network
+
+    @property
+    def is_unspecified(self):
+        """Test if the address is unspecified.
+
+        Returns:
+            A boolean, True if this is the unspecified address as defined in
+            RFC 5735 3.
+
+        """
+        return self == self._constants._unspecified_address
+
+    @property
+    def is_loopback(self):
+        """Test if the address is a loopback address.
+
+        Returns:
+            A boolean, True if the address is a loopback per RFC 3330.
+
+        """
+        return self in self._constants._loopback_network
+
+    @property
+    def is_link_local(self):
+        """Test if the address is reserved for link-local.
+
+        Returns:
+            A boolean, True if the address is link-local per RFC 3927.
+
+        """
+        return self in self._constants._linklocal_network
+
+
+class IPv4Interface(IPv4Address):
+
+    def __init__(self, address):
+        addr, mask = self._split_addr_prefix(address)
+
+        IPv4Address.__init__(self, addr)
+        self.network = IPv4Network((addr, mask), strict=False)
+        self.netmask = self.network.netmask
+        self._prefixlen = self.network._prefixlen
+
+    @property
+    def hostmask(self):
+        return self.network.hostmask
+
+    def __str__(self):
+        return '%s/%d' % (self._string_from_ip_int(self._ip),
+                          self._prefixlen)
+
+    def __eq__(self, other):
+        address_equal = IPv4Address.__eq__(self, other)
+        if address_equal is NotImplemented or not address_equal:
+            return address_equal
+        try:
+            return self.network == other.network
+        except AttributeError:
+            # An interface with an associated network is NOT the
+            # same as an unassociated address. That's why the hash
+            # takes the extra info into account.
+            return False
+
+    def __lt__(self, other):
+        address_less = IPv4Address.__lt__(self, other)
+        if address_less is NotImplemented:
+            return NotImplemented
+        try:
+            return (self.network < other.network or
+                    self.network == other.network and address_less)
+        except AttributeError:
+            # We *do* allow addresses and interfaces to be sorted. The
+            # unassociated address is considered less than all interfaces.
+            return False
+
+    def __hash__(self):
+        return hash((self._ip, self._prefixlen, int(self.network.network_address)))
+
+    __reduce__ = _IPAddressBase.__reduce__
+
+    @property
+    def ip(self):
+        return IPv4Address(self._ip)
+
+    @property
+    def with_prefixlen(self):
+        return '%s/%s' % (self._string_from_ip_int(self._ip),
+                          self._prefixlen)
+
+    @property
+    def with_netmask(self):
+        return '%s/%s' % (self._string_from_ip_int(self._ip),
+                          self.netmask)
+
+    @property
+    def with_hostmask(self):
+        return '%s/%s' % (self._string_from_ip_int(self._ip),
+                          self.hostmask)
+
+
+class IPv4Network(_BaseV4, _BaseNetwork):
+
+    """This class represents and manipulates 32-bit IPv4 network
+    addresses.
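+
+    A /27, for example, spans 32 addresses, 30 of which are usable hosts
+    once the network and broadcast addresses are excluded.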
+
+    Attributes: [examples for IPv4Network('192.0.2.0/27')]
+        .network_address: IPv4Address('192.0.2.0')
+        .hostmask: IPv4Address('0.0.0.31')
+        .broadcast_address: IPv4Address('192.0.2.31')
+        .netmask: IPv4Address('255.255.255.224')
+        .prefixlen: 27
+
+    """
+    # Class to use when creating address objects
+    _address_class = IPv4Address
+
+    def __init__(self, address, strict=True):
+        """Instantiate a new IPv4 network object.
+
+        Args:
+            address: A string or integer representing the IP [& network].
+              '192.0.2.0/24'
+              '192.0.2.0/255.255.255.0'
+              '192.0.2.0/0.0.0.255'
+              are all functionally the same in IPv4. Similarly,
+              '192.0.2.1'
+              '192.0.2.1/255.255.255.255'
+              '192.0.2.1/32'
+              are also functionally equivalent. That is to say, failing to
+              provide a subnetmask will create an object with a mask of /32.
+
+              If the mask (portion after the / in the argument) is given in
+              dotted quad form, it is treated as a netmask if it starts with a
+              non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it
+              starts with a zero field (e.g. 0.255.255.255 == /8), with the
+              single exception of an all-zero mask which is treated as a
+              netmask == /0. If no mask is given, a default of /32 is used.
+
+              Additionally, an integer can be passed, so
+              IPv4Network('192.0.2.1') == IPv4Network(3221225985)
+              or, more generally
+              IPv4Network(int(IPv4Network('192.0.2.1'))) ==
+                IPv4Network('192.0.2.1')
+
+        Raises:
+            AddressValueError: If ipaddress isn't a valid IPv4 address.
+            NetmaskValueError: If the netmask isn't valid for
+              an IPv4 address.
+            ValueError: If strict is True and a network address is not
+              supplied.
+        """
+        addr, mask = self._split_addr_prefix(address)
+
+        self.network_address = IPv4Address(addr)
+        self.netmask, self._prefixlen = self._make_netmask(mask)
+        packed = int(self.network_address)
+        if packed & int(self.netmask) != packed:
+            if strict:
+                raise ValueError('%s has host bits set' % self)
+            else:
+                self.network_address = IPv4Address(packed &
+                                                   int(self.netmask))
+
+        if self._prefixlen == (self._max_prefixlen - 1):
+            self.hosts = self.__iter__
+        elif self._prefixlen == (self._max_prefixlen):
+            self.hosts = lambda: [IPv4Address(addr)]
+
+    @property
+    def is_global(self):
+        """Test if this address is allocated for public networks.
+
+        Returns:
+            A boolean, True if the address is not reserved per
+            iana-ipv4-special-registry.
+
+        """
+        return (not (self.network_address in IPv4Network('100.64.0.0/10') and
+                     self.broadcast_address in IPv4Network('100.64.0.0/10')) and
+                not self.is_private)
+
+
+class _IPv4Constants:
+    _linklocal_network = IPv4Network('169.254.0.0/16')
+
+    _loopback_network = IPv4Network('127.0.0.0/8')
+
+    _multicast_network = IPv4Network('224.0.0.0/4')
+
+    _public_network = IPv4Network('100.64.0.0/10')
+
+    _private_networks = [
+        IPv4Network('0.0.0.0/8'),
+        IPv4Network('10.0.0.0/8'),
+        IPv4Network('127.0.0.0/8'),
+        IPv4Network('169.254.0.0/16'),
+        IPv4Network('172.16.0.0/12'),
+        IPv4Network('192.0.0.0/29'),
+        IPv4Network('192.0.0.170/31'),
+        IPv4Network('192.0.2.0/24'),
+        IPv4Network('192.168.0.0/16'),
+        IPv4Network('198.18.0.0/15'),
+        IPv4Network('198.51.100.0/24'),
+        IPv4Network('203.0.113.0/24'),
+        IPv4Network('240.0.0.0/4'),
+        IPv4Network('255.255.255.255/32'),
+    ]
+
+    _reserved_network = IPv4Network('240.0.0.0/4')
+
+    _unspecified_address = IPv4Address('0.0.0.0')
+
+
+IPv4Address._constants = _IPv4Constants
+
+
+class _BaseV6:
+
+    """Base IPv6 object.
+
+    The following methods are used by IPv6 objects in both single IP
+    addresses and networks.
+ + """ + + __slots__ = () + _version = 6 + _ALL_ONES = (2**IPV6LENGTH) - 1 + _HEXTET_COUNT = 8 + _HEX_DIGITS = frozenset('0123456789ABCDEFabcdef') + _max_prefixlen = IPV6LENGTH + + # There are only a bunch of valid v6 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, int): + prefixlen = arg + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen) + else: + prefixlen = cls._prefix_from_prefix_string(arg) + netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn an IPv6 ip_str into an integer. + + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + An int, the IPv6 address + + Raises: + AddressValueError: if ip_str isn't a valid IPv6 Address. + + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + parts = ip_str.split(':') + + # An IPv6 address needs at least 2 colons (3 parts). + _min_parts = 3 + if len(parts) < _min_parts: + msg = "At least %d parts expected in %r" % (_min_parts, ip_str) + raise AddressValueError(msg) + + # If the address has an IPv4-style suffix, convert it to hexadecimal. + if '.' in parts[-1]: + try: + ipv4_int = IPv4Address(parts.pop())._ip + except AddressValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) from None + parts.append('%x' % ((ipv4_int >> 16) & 0xFFFF)) + parts.append('%x' % (ipv4_int & 0xFFFF)) + + # An IPv6 address can't have more than 8 colons (9 parts). + # The extra colon comes from using the "::" notation for a single + # leading or trailing zero part. + _max_parts = cls._HEXTET_COUNT + 1 + if len(parts) > _max_parts: + msg = "At most %d colons permitted in %r" % (_max_parts-1, ip_str) + raise AddressValueError(msg) + + # Disregarding the endpoints, find '::' with nothing in between. + # This indicates that a run of zeroes has been skipped. + skip_index = None + for i in range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i + + # parts_hi is the number of parts to copy from above/before the '::' + # parts_lo is the number of parts to copy from below/after the '::' + if skip_index is not None: + # If we found a '::', then check if it also covers the endpoints. + parts_hi = skip_index + parts_lo = len(parts) - skip_index - 1 + if not parts[0]: + parts_hi -= 1 + if parts_hi: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + parts_lo -= 1 + if parts_lo: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo) + if parts_skipped < 1: + msg = "Expected at most %d other parts with '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT-1, ip_str)) + else: + # Otherwise, allocate the entire address to parts_hi. 
The + # endpoints could still be empty, but _parse_hextet() will check + # for that. + if len(parts) != cls._HEXTET_COUNT: + msg = "Exactly %d parts expected without '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str)) + if not parts[0]: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_hi = len(parts) + parts_lo = 0 + parts_skipped = 0 + + try: + # Now, parse the hextets into a 128-bit integer. + ip_int = 0 + for i in range(parts_hi): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + ip_int <<= 16 * parts_skipped + for i in range(-parts_lo, 0): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + return ip_int + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) from None + + @classmethod + def _parse_hextet(cls, hextet_str): + """Convert an IPv6 hextet string into an integer. + + Args: + hextet_str: A string, the number to parse. + + Returns: + The hextet as an integer. + + Raises: + ValueError: if the input isn't strictly a hex number from + [0..FFFF]. + + """ + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._HEX_DIGITS.issuperset(hextet_str): + raise ValueError("Only hex digits permitted in %r" % hextet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(hextet_str) > 4: + msg = "At most 4 characters permitted in %r" + raise ValueError(msg % hextet_str) + # Length check means we can skip checking the integer value + return int(hextet_str, 16) + + @classmethod + def _compress_hextets(cls, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index, hextet in enumerate(hextets): + if hextet == '0': + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = (best_doublecolon_start + + best_doublecolon_len) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [''] + hextets[best_doublecolon_start:best_doublecolon_end] = [''] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [''] + hextets + + return hextets + + @classmethod + def _string_from_ip_int(cls, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. 
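+
+        For example, an ip_int of 1 becomes '::1' and 0 becomes '::'.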
+ + """ + if ip_int is None: + ip_int = int(cls._ip) + + if ip_int > cls._ALL_ONES: + raise ValueError('IPv6 address is too large') + + hex_str = '%032x' % ip_int + hextets = ['%x' % int(hex_str[x:x+4], 16) for x in range(0, 32, 4)] + + hextets = cls._compress_hextets(hextets) + return ':'.join(hextets) + + def _explode_shorthand_ip_string(self): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if isinstance(self, IPv6Network): + ip_str = str(self.network_address) + elif isinstance(self, IPv6Interface): + ip_str = str(self.ip) + else: + ip_str = str(self) + + ip_int = self._ip_int_from_string(ip_str) + hex_str = '%032x' % ip_int + parts = [hex_str[x:x+4] for x in range(0, 32, 4)] + if isinstance(self, (_BaseNetwork, IPv6Interface)): + return '%s/%d' % (':'.join(parts), self._prefixlen) + return ':'.join(parts) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv6 address. + + This implements the method described in RFC3596 2.5. + + """ + reverse_chars = self.exploded[::-1].replace(':', '') + return '.'.join(reverse_chars) + '.ip6.arpa' + + @staticmethod + def _split_scope_id(ip_str): + """Helper function to parse IPv6 string address with scope id. + + See RFC 4007 for details. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + (addr, scope_id) tuple. + + """ + addr, sep, scope_id = ip_str.partition('%') + if not sep: + scope_id = None + elif not scope_id or '%' in scope_id: + raise AddressValueError('Invalid IPv6 address: "%r"' % ip_str) + return addr, scope_id + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv6Address(_BaseV6, _BaseAddress): + + """Represent and manipulate single IPv6 Addresses.""" + + __slots__ = ('_ip', '_scope_id', '__weakref__') + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:db8::') == + IPv6Address(42540766411282592856903984951653826560) + or, more generally + IPv6Address(int(IPv6Address('2001:db8::'))) == + IPv6Address('2001:db8::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + # Efficient constructor from integer. + if isinstance(address, int): + self._check_int_address(address) + self._ip = address + self._scope_id = None + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 16) + self._ip = int.from_bytes(address, 'big') + self._scope_id = None + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. 
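+        # Unlike IPv4, a trailing "%zone" scope id (RFC 4007) must be
+        # split off before the remainder is parsed as the address proper.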
+ addr_str = str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + addr_str, self._scope_id = self._split_scope_id(addr_str) + + self._ip = self._ip_int_from_string(addr_str) + + def __str__(self): + ip_str = super().__str__() + return ip_str + '%' + self._scope_id if self._scope_id else ip_str + + def __hash__(self): + return hash((self._ip, self._scope_id)) + + def __eq__(self, other): + address_equal = super().__eq__(other) + if address_equal is NotImplemented: + return NotImplemented + if not address_equal: + return False + return self._scope_id == getattr(other, '_scope_id', None) + + @property + def scope_id(self): + """Identifier of a particular zone of the address's scope. + + See RFC 4007 for details. + + Returns: + A string identifying the zone of the address if specified, else None. + + """ + return self._scope_id + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return any(self in x for x in self._constants._reserved_networks) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in self._constants._linklocal_network + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in self._constants._sitelocal_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv6-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, true if the address is not reserved per + iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return self._ip == 0 + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return self._ip == 1 + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + if (self._ip >> 32) != 0xFFFF: + return None + return IPv4Address(self._ip & 0xFFFFFFFF) + + @property + def teredo(self): + """Tuple of embedded teredo IPs. 
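+
+        A Teredo address (RFC 4380) embeds the server address in bits
+        64-95 and the bit-inverted client address in the low 32 bits,
+        which is how the tuple below is recovered.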
+
+        Returns:
+            Tuple of the (server, client) IPs or None if the address
+            doesn't appear to be a teredo address (doesn't start with
+            2001::/32)
+
+        """
+        if (self._ip >> 96) != 0x20010000:
+            return None
+        return (IPv4Address((self._ip >> 64) & 0xFFFFFFFF),
+                IPv4Address(~self._ip & 0xFFFFFFFF))
+
+    @property
+    def sixtofour(self):
+        """Return the IPv4 6to4 embedded address.
+
+        Returns:
+            The IPv4 6to4-embedded address if present or None if the
+            address doesn't appear to contain a 6to4 embedded address.
+
+        """
+        if (self._ip >> 112) != 0x2002:
+            return None
+        return IPv4Address((self._ip >> 80) & 0xFFFFFFFF)
+
+
+class IPv6Interface(IPv6Address):
+
+    def __init__(self, address):
+        addr, mask = self._split_addr_prefix(address)
+
+        IPv6Address.__init__(self, addr)
+        self.network = IPv6Network((addr, mask), strict=False)
+        self.netmask = self.network.netmask
+        self._prefixlen = self.network._prefixlen
+
+    @property
+    def hostmask(self):
+        return self.network.hostmask
+
+    def __str__(self):
+        return '%s/%d' % (super().__str__(),
+                          self._prefixlen)
+
+    def __eq__(self, other):
+        address_equal = IPv6Address.__eq__(self, other)
+        if address_equal is NotImplemented or not address_equal:
+            return address_equal
+        try:
+            return self.network == other.network
+        except AttributeError:
+            # An interface with an associated network is NOT the
+            # same as an unassociated address. That's why the hash
+            # takes the extra info into account.
+            return False
+
+    def __lt__(self, other):
+        address_less = IPv6Address.__lt__(self, other)
+        if address_less is NotImplemented:
+            return address_less
+        try:
+            return (self.network < other.network or
+                    self.network == other.network and address_less)
+        except AttributeError:
+            # We *do* allow addresses and interfaces to be sorted. The
+            # unassociated address is considered less than all interfaces.
+            return False
+
+    def __hash__(self):
+        return hash((self._ip, self._prefixlen, int(self.network.network_address)))
+
+    __reduce__ = _IPAddressBase.__reduce__
+
+    @property
+    def ip(self):
+        return IPv6Address(self._ip)
+
+    @property
+    def with_prefixlen(self):
+        return '%s/%s' % (self._string_from_ip_int(self._ip),
+                          self._prefixlen)
+
+    @property
+    def with_netmask(self):
+        return '%s/%s' % (self._string_from_ip_int(self._ip),
+                          self.netmask)
+
+    @property
+    def with_hostmask(self):
+        return '%s/%s' % (self._string_from_ip_int(self._ip),
+                          self.hostmask)
+
+    @property
+    def is_unspecified(self):
+        return self._ip == 0 and self.network.is_unspecified
+
+    @property
+    def is_loopback(self):
+        return self._ip == 1 and self.network.is_loopback
+
+
+class IPv6Network(_BaseV6, _BaseNetwork):
+
+    """This class represents and manipulates 128-bit IPv6 networks.
+
+    Attributes: [examples for IPv6Network('2001:db8::1000/124')]
+        .network_address: IPv6Address('2001:db8::1000')
+        .hostmask: IPv6Address('::f')
+        .broadcast_address: IPv6Address('2001:db8::100f')
+        .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0')
+        .prefixlen: 124
+
+    """
+
+    # Class to use when creating address objects
+    _address_class = IPv6Address
+
+    def __init__(self, address, strict=True):
+        """Instantiate a new IPv6 Network object.
+
+        Args:
+            address: A string or integer representing the IPv6 network or the
+              IP and prefix/netmask.
+              '2001:db8::/128'
+              '2001:db8:0000:0000:0000:0000:0000:0000/128'
+              '2001:db8::'
+              are all functionally the same in IPv6. That is to say,
+              failing to provide a subnetmask will create an object with
+              a mask of /128.
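+
+              For example, IPv6Network('2001:db8::/32') describes a
+              block of 2**96 addresses.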
+
+              Additionally, an integer can be passed, so
+              IPv6Network('2001:db8::') ==
+                IPv6Network(42540766411282592856903984951653826560)
+              or, more generally
+              IPv6Network(int(IPv6Network('2001:db8::'))) ==
+                IPv6Network('2001:db8::')
+
+            strict: A boolean. If true, ensure that we have been passed
+              a true network address, eg, 2001:db8::1000/124 and not an
+              IP address on a network, eg, 2001:db8::1/124.
+
+        Raises:
+            AddressValueError: If address isn't a valid IPv6 address.
+            NetmaskValueError: If the netmask isn't valid for
+              an IPv6 address.
+            ValueError: If strict was True and a network address was not
+              supplied.
+        """
+        addr, mask = self._split_addr_prefix(address)
+
+        self.network_address = IPv6Address(addr)
+        self.netmask, self._prefixlen = self._make_netmask(mask)
+        packed = int(self.network_address)
+        if packed & int(self.netmask) != packed:
+            if strict:
+                raise ValueError('%s has host bits set' % self)
+            else:
+                self.network_address = IPv6Address(packed &
+                                                   int(self.netmask))
+
+        if self._prefixlen == (self._max_prefixlen - 1):
+            self.hosts = self.__iter__
+        elif self._prefixlen == self._max_prefixlen:
+            self.hosts = lambda: [IPv6Address(addr)]
+
+    def hosts(self):
+        """Generate an iterator over usable hosts in a network.
+
+        This is like __iter__ except it doesn't return the
+        Subnet-Router anycast address.
+
+        """
+        network = int(self.network_address)
+        broadcast = int(self.broadcast_address)
+        for x in range(network + 1, broadcast + 1):
+            yield self._address_class(x)
+
+    @property
+    def is_site_local(self):
+        """Test if the address is reserved for site-local.
+
+        Note that the site-local address space has been deprecated by RFC 3879.
+        Use is_private to test if this address is in the space of unique local
+        addresses as defined by RFC 4193.
+
+        Returns:
+            A boolean, True if the address is reserved per RFC 3513 2.5.6.
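+
+            For example, IPv6Network('fec0::/10').is_site_local is True.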
+ + """ + return (self.network_address.is_site_local and + self.broadcast_address.is_site_local) + + +class _IPv6Constants: + + _linklocal_network = IPv6Network('fe80::/10') + + _multicast_network = IPv6Network('ff00::/8') + + _private_networks = [ + IPv6Network('::1/128'), + IPv6Network('::/128'), + IPv6Network('::ffff:0:0/96'), + IPv6Network('100::/64'), + IPv6Network('2001::/23'), + IPv6Network('2001:2::/48'), + IPv6Network('2001:db8::/32'), + IPv6Network('2001:10::/28'), + IPv6Network('fc00::/7'), + IPv6Network('fe80::/10'), + ] + + _reserved_networks = [ + IPv6Network('::/8'), IPv6Network('100::/8'), + IPv6Network('200::/7'), IPv6Network('400::/6'), + IPv6Network('800::/5'), IPv6Network('1000::/4'), + IPv6Network('4000::/3'), IPv6Network('6000::/3'), + IPv6Network('8000::/3'), IPv6Network('A000::/3'), + IPv6Network('C000::/3'), IPv6Network('E000::/4'), + IPv6Network('F000::/5'), IPv6Network('F800::/6'), + IPv6Network('FE00::/9'), + ] + + _sitelocal_network = IPv6Network('fec0::/10') + + +IPv6Address._constants = _IPv6Constants diff --git a/ipaddress/metadata.txt b/ipaddress/metadata.txt index dc5f60a661..cb21b842f3 100644 --- a/ipaddress/metadata.txt +++ b/ipaddress/metadata.txt @@ -1,3 +1,3 @@ -srctype = dummy +srctype = cpython type = module -version = 0.0.1 +version = 0.1.0 From 8bd5b402940b19fed633bf4b02f501d2d5217489 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:36:07 +0200 Subject: [PATCH 41/65] dummies for GenericAlias and MappingProxyType --- types/metadata.txt | 2 +- types/types.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/types/metadata.txt b/types/metadata.txt index 199c20c598..cf26433b22 100644 --- a/types/metadata.txt +++ b/types/metadata.txt @@ -1,3 +1,3 @@ srctype = cpython type = module -version = 3.3.3-3 +version = 3.3.3-4 diff --git a/types/types.py b/types/types.py index b46f46091f..ddcc64fcbb 100644 --- a/types/types.py +++ b/types/types.py @@ -16,7 +16,9 @@ def _f(): pass except: CodeType = None # TODO: Add better sentinel which can't match anything -MappingProxyType = None # TODO: Add better sentinel which can't match anything +# TODO: Add better sentinel which can't match anything +def MappingProxyType(x): + return x SimpleNamespace = None # TODO: Add better sentinel which can't match anything def _g(): @@ -34,6 +36,8 @@ def _m(self): pass BuiltinMethodType = type([].append) # Same as BuiltinFunctionType ModuleType = type(sys) +class GenericAlias: + pass try: raise TypeError From d2d278027728f70f8d0111a3e06d7227b3e72f36 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:31:53 +0200 Subject: [PATCH 42/65] add PathLike abstract class --- os.path/metadata.txt | 2 +- os.path/os/path.py | 13 +++++++++++++ os/os/__init__.py | 1 + 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/os.path/metadata.txt b/os.path/metadata.txt index a0481bfa8f..825ebb3a1e 100644 --- a/os.path/metadata.txt +++ b/os.path/metadata.txt @@ -1,5 +1,5 @@ srctype = pycopy-lib type = package -version = 0.2.3 +version = 0.2.4 author = Paul Sokolovsky depends = ffilib diff --git a/os.path/os/path.py b/os.path/os/path.py index 06f79970a8..5276420c24 100644 --- a/os.path/os/path.py +++ b/os.path/os/path.py @@ -1,5 +1,6 @@ import os import ffilib +from types import GenericAlias sep = "/" @@ -96,3 +97,15 @@ def expanduser(s): # Sorry folks, follow conventions return "/home/" + s[1:] return s + + +class PathLike: + + """Abstract base class for implementing the file system path protocol.""" + + def 
__fspath__(self): + """Return the file system path representation of the object.""" + raise NotImplementedError + + __class_getitem__ = classmethod(GenericAlias) + diff --git a/os/os/__init__.py b/os/os/__init__.py index 0dda8efd70..43c1de02b3 100644 --- a/os/os/__init__.py +++ b/os/os/__init__.py @@ -37,6 +37,7 @@ from uos2 import * from uos2 import _exit, _libc as libc +PathLike = path.PathLike P_WAIT = 0 P_NOWAIT = 1 From c71dcbb82959cf47773c4f432c90b2f9a887dd80 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:32:31 +0200 Subject: [PATCH 43/65] os.env: don't subclass dict --- os/metadata.txt | 2 +- os/os/__init__.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/os/metadata.txt b/os/metadata.txt index 8071d1c42a..575cd1937a 100644 --- a/os/metadata.txt +++ b/os/metadata.txt @@ -1,5 +1,5 @@ srctype = pycopy-lib type = package -version = 1.1.1 +version = 1.1.2 author = Paul Sokolovsky depends = errno, stat, os.path, uos2 diff --git a/os/os/__init__.py b/os/os/__init__.py index 43c1de02b3..3f972c457a 100644 --- a/os/os/__init__.py +++ b/os/os/__init__.py @@ -260,21 +260,27 @@ def closerange(low, high): } -class _Environ(dict): +class _Environ(object): def __init__(self): - dict.__init__(self) + self._data = dict() env = uctypes.struct(_environ_ptr.get(), _ENV_STRUCT) for i in range(4096): if int(env.arr[i]) == 0: break - s = uctypes.bytes_at(env.arr[i]).decode() + # requires micropython change f20a730 + s = uctypes.bytestring_at(int(env.arr[i])).decode() k, v = s.split("=", 1) - dict.__setitem__(self, k, v) + self._data[k] = v + self.__getitem__ = self._data.__getitem__ def __setitem__(self, k, v): - putenv(k, v) - dict.__setitem__(self, k, v) + try: + uos2.putenv(k.encode(), v.encode()) + except AttributeError: + # XXX is this right? 
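+            # A port whose uos2 lacks putenv only updates the local
+            # copy; the libc environment is left untouched.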
+ pass + self._data[k] = v environ = _Environ() From b2750ecf6100369d6665333fb27e327a4cfa2012 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:33:24 +0200 Subject: [PATCH 44/65] pathlib: copy from cpython 3.9 --- pathlib/metadata.txt | 4 +- pathlib/pathlib.py | 1596 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1598 insertions(+), 2 deletions(-) diff --git a/pathlib/metadata.txt b/pathlib/metadata.txt index abee00b449..e5c7637f4f 100644 --- a/pathlib/metadata.txt +++ b/pathlib/metadata.txt @@ -1,3 +1,3 @@ -srctype=dummy +srctype=cpython type=module -version = 0.0.1 +version = 0.1.0 diff --git a/pathlib/pathlib.py b/pathlib/pathlib.py index e69de29bb2..1f2ccfa9d9 100644 --- a/pathlib/pathlib.py +++ b/pathlib/pathlib.py @@ -0,0 +1,1596 @@ +import fnmatch +import functools +import io +try: + import ntpath +except ImportError: + ntpath = None +import os +import posixpath +import re +import sys +from collections.abc import Sequence +from errno import EINVAL, ENOENT, ENOTDIR, EBADF, ELOOP +from operator import attrgetter +from stat import S_ISDIR, S_ISLNK, S_ISREG, S_ISSOCK, S_ISBLK, S_ISCHR, S_ISFIFO +from urllib.parse import quote_from_bytes as urlquote_from_bytes + + +supports_symlinks = True +if os.name == 'nt': + import nt + if sys.getwindowsversion()[:2] >= (6, 0): + from nt import _getfinalpathname + else: + supports_symlinks = False + _getfinalpathname = None +else: + nt = None + + +__all__ = [ + "PurePath", "PurePosixPath", "PureWindowsPath", + "Path", "PosixPath", "WindowsPath", + ] + +# +# Internals +# + +# EBADF - guard against macOS `stat` throwing EBADF +_IGNORED_ERROS = (ENOENT, ENOTDIR, EBADF, ELOOP) + +_IGNORED_WINERRORS = ( + 21, # ERROR_NOT_READY - drive exists but is not accessible + 1921, # ERROR_CANT_RESOLVE_FILENAME - fix for broken symlink pointing to itself +) + +def _ignore_error(exception): + return (getattr(exception, 'errno', None) in _IGNORED_ERROS or + getattr(exception, 'winerror', None) in _IGNORED_WINERRORS) + + +def _is_wildcard_pattern(pat): + # Whether this pattern needs actual matching using fnmatch, or can + # be looked up directly as a file. + return "*" in pat or "?" in pat or "[" in pat + + +class _Flavour(object): + """A flavour implements a particular (platform-specific) set of path + semantics.""" + + def __init__(self): + self.join = self.sep.join + + def parse_parts(self, parts): + parsed = [] + sep = self.sep + altsep = self.altsep + drv = root = '' + it = reversed(parts) + for part in it: + if not part: + continue + if altsep: + part = part.replace(altsep, sep) + drv, root, rel = self.splitroot(part) + if sep in rel: + for x in reversed(rel.split(sep)): + if x and x != '.': + parsed.append(sys.intern(x)) + else: + if rel and rel != '.': + parsed.append(sys.intern(rel)) + if drv or root: + if not drv: + # If no drive is present, try to find one in the previous + # parts. This makes the result of parsing e.g. + # ("C:", "/", "a") reasonably intuitive. + for part in it: + if not part: + continue + if altsep: + part = part.replace(altsep, sep) + drv = self.splitroot(part)[0] + if drv: + break + break + if drv or root: + parsed.append(drv + root) + parsed.reverse() + return drv, root, parsed + + def join_parsed_parts(self, drv, root, parts, drv2, root2, parts2): + """ + Join the two paths represented by the respective + (drive, root, parts) tuples. Return a new (drive, root, parts) tuple. 
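+
+        For example, joining ('', '/', ['/', 'a']) with ('', '', ['b'])
+        gives ('', '/', ['/', 'a', 'b']).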
+ """ + if root2: + if not drv2 and drv: + return drv, root2, [drv + root2] + parts2[1:] + elif drv2: + if drv2 == drv or self.casefold(drv2) == self.casefold(drv): + # Same drive => second path is relative to the first + return drv, root, parts + parts2[1:] + else: + # Second path is non-anchored (common case) + return drv, root, parts + parts2 + return drv2, root2, parts2 + + +class _WindowsFlavour(_Flavour): + # Reference for Windows paths can be found at + # http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx + + sep = '\\' + altsep = '/' + has_drv = True + pathmod = ntpath + + is_supported = (os.name == 'nt') + + drive_letters = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') + ext_namespace_prefix = '\\\\?\\' + + reserved_names = ( + {'CON', 'PRN', 'AUX', 'NUL'} | + {'COM%d' % i for i in range(1, 10)} | + {'LPT%d' % i for i in range(1, 10)} + ) + + # Interesting findings about extended paths: + # - '\\?\c:\a', '//?/c:\a' and '//?/c:/a' are all supported + # but '\\?\c:/a' is not + # - extended paths are always absolute; "relative" extended paths will + # fail. + + def splitroot(self, part, sep=sep): + first = part[0:1] + second = part[1:2] + if (second == sep and first == sep): + # XXX extended paths should also disable the collapsing of "." + # components (according to MSDN docs). + prefix, part = self._split_extended_path(part) + first = part[0:1] + second = part[1:2] + else: + prefix = '' + third = part[2:3] + if (second == sep and first == sep and third != sep): + # is a UNC path: + # vvvvvvvvvvvvvvvvvvvvv root + # \\machine\mountpoint\directory\etc\... + # directory ^^^^^^^^^^^^^^ + index = part.find(sep, 2) + if index != -1: + index2 = part.find(sep, index + 1) + # a UNC path can't have two slashes in a row + # (after the initial two) + if index2 != index + 1: + if index2 == -1: + index2 = len(part) + if prefix: + return prefix + part[1:index2], sep, part[index2+1:] + else: + return part[:index2], sep, part[index2+1:] + drv = root = '' + if second == ':' and first in self.drive_letters: + drv = part[:2] + part = part[2:] + first = third + if first == sep: + root = first + part = part.lstrip(sep) + return prefix + drv, root, part + + def casefold(self, s): + return s.lower() + + def casefold_parts(self, parts): + return [p.lower() for p in parts] + + def compile_pattern(self, pattern): + return re.compile(fnmatch.translate(pattern), re.IGNORECASE).fullmatch + + def resolve(self, path, strict=False): + s = str(path) + if not s: + return os.getcwd() + previous_s = None + if _getfinalpathname is not None: + if strict: + return self._ext_to_normal(_getfinalpathname(s)) + else: + tail_parts = [] # End of the path after the first one not found + while True: + try: + s = self._ext_to_normal(_getfinalpathname(s)) + except FileNotFoundError: + previous_s = s + s, tail = os.path.split(s) + tail_parts.append(tail) + if previous_s == s: + return path + else: + return os.path.join(s, *reversed(tail_parts)) + # Means fallback on absolute + return None + + def _split_extended_path(self, s, ext_prefix=ext_namespace_prefix): + prefix = '' + if s.startswith(ext_prefix): + prefix = s[:4] + s = s[4:] + if s.startswith('UNC\\'): + prefix += s[:3] + s = '\\' + s[3:] + return prefix, s + + def _ext_to_normal(self, s): + # Turn back an extended path into a normal DOS-like path + return self._split_extended_path(s)[1] + + def is_reserved(self, parts): + # NOTE: the rules for reserved names seem somewhat complicated + # (e.g. r"..\NUL" is reserved but not r"foo\NUL"). 
+ # We err on the side of caution and return True for paths which are + # not considered reserved by Windows. + if not parts: + return False + if parts[0].startswith('\\\\'): + # UNC paths are never reserved + return False + return parts[-1].partition('.')[0].upper() in self.reserved_names + + def make_uri(self, path): + # Under Windows, file URIs use the UTF-8 encoding. + drive = path.drive + if len(drive) == 2 and drive[1] == ':': + # It's a path on a local drive => 'file:///c:/a/b' + rest = path.as_posix()[2:].lstrip('/') + return 'file:///%s/%s' % ( + drive, urlquote_from_bytes(rest.encode('utf-8'))) + else: + # It's a path on a network drive => 'file://host/share/a/b' + return 'file:' + urlquote_from_bytes(path.as_posix().encode('utf-8')) + + def gethomedir(self, username): + if 'USERPROFILE' in os.environ: + userhome = os.environ['USERPROFILE'] + elif 'HOMEPATH' in os.environ: + try: + drv = os.environ['HOMEDRIVE'] + except KeyError: + drv = '' + userhome = drv + os.environ['HOMEPATH'] + else: + raise RuntimeError("Can't determine home directory") + + if username: + # Try to guess user home directory. By default all users + # directories are located in the same place and are named by + # corresponding usernames. If current user home directory points + # to nonstandard place, this guess is likely wrong. + if os.environ['USERNAME'] != username: + drv, root, parts = self.parse_parts((userhome,)) + if parts[-1] != os.environ['USERNAME']: + raise RuntimeError("Can't determine home directory " + "for %r" % username) + parts[-1] = username + if drv or root: + userhome = drv + root + self.join(parts[1:]) + else: + userhome = self.join(parts) + return userhome + +class _PosixFlavour(_Flavour): + sep = '/' + altsep = '' + has_drv = False + pathmod = posixpath + + is_supported = (os.name != 'nt') + + def splitroot(self, part, sep=sep): + if part and part[0] == sep: + stripped_part = part.lstrip(sep) + # According to POSIX path resolution: + # http://pubs.opengroup.org/onlinepubs/009695399/basedefs/xbd_chap04.html#tag_04_11 + # "A pathname that begins with two successive slashes may be + # interpreted in an implementation-defined manner, although more + # than two leading slashes shall be treated as a single slash". + if len(part) - len(stripped_part) == 2: + return '', sep * 2, stripped_part + else: + return '', sep, stripped_part + else: + return '', '', part + + def casefold(self, s): + return s + + def casefold_parts(self, parts): + return parts + + def compile_pattern(self, pattern): + return re.compile(fnmatch.translate(pattern)).fullmatch + + def resolve(self, path, strict=False): + sep = self.sep + accessor = path._accessor + seen = {} + def _resolve(path, rest): + if rest.startswith(sep): + path = '' + + for name in rest.split(sep): + if not name or name == '.': + # current dir + continue + if name == '..': + # parent dir + path, _, _ = path.rpartition(sep) + continue + if path.endswith(sep): + newpath = path + name + else: + newpath = path + sep + name + if newpath in seen: + # Already seen this path + path = seen[newpath] + if path is not None: + # use cached value + continue + # The symlink is not resolved, so we must have a symlink loop. + raise RuntimeError("Symlink loop from %r" % newpath) + # Resolve the symbolic link + try: + target = accessor.readlink(newpath) + except OSError as e: + if e.errno != EINVAL and strict: + raise + # Not a symlink, or non-strict mode. We just leave the path + # untouched. 
+ path = newpath + else: + seen[newpath] = None # not resolved symlink + path = _resolve(path, target) + seen[newpath] = path # resolved symlink + + return path + # NOTE: according to POSIX, getcwd() cannot contain path components + # which are symlinks. + base = '' if path.is_absolute() else os.getcwd() + return _resolve(base, str(path)) or sep + + def is_reserved(self, parts): + return False + + def make_uri(self, path): + # We represent the path using the local filesystem encoding, + # for portability to other applications. + bpath = bytes(path) + return 'file://' + urlquote_from_bytes(bpath) + + def gethomedir(self, username): + if not username: + try: + return os.environ['HOME'] + except KeyError: + import pwd + return pwd.getpwuid(os.getuid()).pw_dir + else: + import pwd + try: + return pwd.getpwnam(username).pw_dir + except KeyError: + raise RuntimeError("Can't determine home directory " + "for %r" % username) + + +_windows_flavour = _WindowsFlavour() +_posix_flavour = _PosixFlavour() + + +class _Accessor: + """An accessor implements a particular (system-specific or not) way of + accessing paths on the filesystem.""" + + +class _NormalAccessor(_Accessor): + + stat = os.stat + + lstat = os.lstat + + open = os.open + + listdir = os.listdir + + if hasattr(os, "scandir"): + scandir = os.scandir + else: + def scandir(self, pathobj): + raise NotImplementedError("scandir() not available on this system") + + if hasattr(os, "lchmod"): + lchmod = os.lchmod + else: + def lchmod(self, pathobj, mode): + raise NotImplementedError("lchmod() not available on this system") + + if hasattr(os, "chmod"): + chmod = os.chmod + else: + def chmod(self, pathobj, mode): + raise NotImplementedError("chmod() not available on this system") + + mkdir = os.mkdir + + unlink = os.unlink + + if hasattr(os, "link"): + link_to = os.link + else: + @staticmethod + def link_to(self, target): + raise NotImplementedError("os.link() not available on this system") + + rmdir = os.rmdir + + rename = os.rename + + if hasattr(os, "replace"): + replace = os.replace + else: + def replace(self, src, dst): + raise NotImplementedError("replace() not available on this system") + + mkdir = os.mkdir + + if nt: + if supports_symlinks: + symlink = os.symlink + else: + def symlink(a, b, target_is_directory): + raise NotImplementedError("symlink() not available on this system") + else: + # Under POSIX, os.symlink() takes two args + @staticmethod + def symlink(a, b, target_is_directory): + return os.symlink(a, b) + + utime = os.utime + + # Helper for resolve() + def readlink(self, path): + return os.readlink(path) + + def owner(self, path): + try: + import pwd + return pwd.getpwuid(self.stat(path).st_uid).pw_name + except ImportError: + raise NotImplementedError("Path.owner() is unsupported on this system") + + def group(self, path): + try: + import grp + return grp.getgrgid(self.stat(path).st_gid).gr_name + except ImportError: + raise NotImplementedError("Path.group() is unsupported on this system") + + +_normal_accessor = _NormalAccessor() + + +# +# Globbing helpers +# + +def _make_selector(pattern_parts, flavour): + pat = pattern_parts[0] + child_parts = pattern_parts[1:] + if pat == '**': + cls = _RecursiveWildcardSelector + elif '**' in pat: + raise ValueError("Invalid pattern: '**' can only be an entire path component") + elif _is_wildcard_pattern(pat): + cls = _WildcardSelector + else: + cls = _PreciseSelector + return cls(pat, child_parts, flavour) + +if hasattr(functools, "lru_cache"): + _make_selector = 
functools.lru_cache()(_make_selector) + + +class _Selector: + """A selector matches a specific glob pattern part against the children + of a given path.""" + + def __init__(self, child_parts, flavour): + self.child_parts = child_parts + if child_parts: + self.successor = _make_selector(child_parts, flavour) + self.dironly = True + else: + self.successor = _TerminatingSelector() + self.dironly = False + + def select_from(self, parent_path): + """Iterate over all child paths of `parent_path` matched by this + selector. This can contain parent_path itself.""" + path_cls = type(parent_path) + is_dir = path_cls.is_dir + exists = path_cls.exists + scandir = parent_path._accessor.scandir + if not is_dir(parent_path): + return iter([]) + return self._select_from(parent_path, is_dir, exists, scandir) + + +class _TerminatingSelector: + + def _select_from(self, parent_path, is_dir, exists, scandir): + yield parent_path + + +class _PreciseSelector(_Selector): + + def __init__(self, name, child_parts, flavour): + self.name = name + _Selector.__init__(self, child_parts, flavour) + + def _select_from(self, parent_path, is_dir, exists, scandir): + try: + path = parent_path._make_child_relpath(self.name) + if (is_dir if self.dironly else exists)(path): + for p in self.successor._select_from(path, is_dir, exists, scandir): + yield p + except PermissionError: + return + + +class _WildcardSelector(_Selector): + + def __init__(self, pat, child_parts, flavour): + self.match = flavour.compile_pattern(pat) + _Selector.__init__(self, child_parts, flavour) + + def _select_from(self, parent_path, is_dir, exists, scandir): + try: + with scandir(parent_path) as scandir_it: + entries = list(scandir_it) + for entry in entries: + if self.dironly: + try: + # "entry.is_dir()" can raise PermissionError + # in some cases (see bpo-38894), which is not + # among the errors ignored by _ignore_error() + if not entry.is_dir(): + continue + except OSError as e: + if not _ignore_error(e): + raise + continue + name = entry.name + if self.match(name): + path = parent_path._make_child_relpath(name) + for p in self.successor._select_from(path, is_dir, exists, scandir): + yield p + except PermissionError: + return + + +class _RecursiveWildcardSelector(_Selector): + + def __init__(self, pat, child_parts, flavour): + _Selector.__init__(self, child_parts, flavour) + + def _iterate_directories(self, parent_path, is_dir, scandir): + yield parent_path + try: + with scandir(parent_path) as scandir_it: + entries = list(scandir_it) + for entry in entries: + entry_is_dir = False + try: + entry_is_dir = entry.is_dir() + except OSError as e: + if not _ignore_error(e): + raise + if entry_is_dir and not entry.is_symlink(): + path = parent_path._make_child_relpath(entry.name) + for p in self._iterate_directories(path, is_dir, scandir): + yield p + except PermissionError: + return + + def _select_from(self, parent_path, is_dir, exists, scandir): + try: + yielded = set() + try: + successor_select = self.successor._select_from + for starting_point in self._iterate_directories(parent_path, is_dir, scandir): + for p in successor_select(starting_point, is_dir, exists, scandir): + if p not in yielded: + yield p + yielded.add(p) + finally: + yielded.clear() + except PermissionError: + return + + +# +# Public API +# + +class _PathParents(Sequence): + """This object provides sequence-like access to the logical ancestors + of a path. 
Don't try to construct it yourself.""" + __slots__ = ('_pathcls', '_drv', '_root', '_parts') + + def __init__(self, path): + # We don't store the instance to avoid reference cycles + self._pathcls = type(path) + self._drv = path._drv + self._root = path._root + self._parts = path._parts + + def __len__(self): + if self._drv or self._root: + return len(self._parts) - 1 + else: + return len(self._parts) + + def __getitem__(self, idx): + if idx < 0 or idx >= len(self): + raise IndexError(idx) + return self._pathcls._from_parsed_parts(self._drv, self._root, + self._parts[:-idx - 1]) + + def __repr__(self): + return "<{}.parents>".format(self._pathcls.__name__) + + +class PurePath(object): + """Base class for manipulating paths without I/O. + + PurePath represents a filesystem path and offers operations which + don't imply any actual filesystem I/O. Depending on your system, + instantiating a PurePath will return either a PurePosixPath or a + PureWindowsPath object. You can also instantiate either of these classes + directly, regardless of your system. + """ + __slots__ = ( + '_drv', '_root', '_parts', + '_str', '_hash', '_pparts', '_cached_cparts', + ) + + def __new__(cls, *args): + """Construct a PurePath from one or several strings and or existing + PurePath objects. The strings and path objects are combined so as + to yield a canonicalized path, which is incorporated into the + new PurePath object. + """ + if cls is PurePath: + cls = PureWindowsPath if os.name == 'nt' else PurePosixPath + return cls._from_parts(args) + + def __reduce__(self): + # Using the parts tuple helps share interned path parts + # when pickling related paths. + return (self.__class__, tuple(self._parts)) + + @classmethod + def _parse_args(cls, args): + # This is useful when you don't want to create an instance, just + # canonicalize some constructor arguments. + parts = [] + for a in args: + if isinstance(a, PurePath): + parts += a._parts + else: + a = os.fspath(a) + if isinstance(a, str): + # Force-cast str subclasses to str (issue #21127) + parts.append(str(a)) + else: + raise TypeError( + "argument should be a str object or an os.PathLike " + "object returning str, not %r" + % type(a)) + return cls._flavour.parse_parts(parts) + + @classmethod + def _from_parts(cls, args, init=True): + # We need to call _parse_args on the instance, so as to get the + # right flavour. + self = object.__new__(cls) + drv, root, parts = self._parse_args(args) + self._drv = drv + self._root = root + self._parts = parts + if init: + self._init() + return self + + @classmethod + def _from_parsed_parts(cls, drv, root, parts, init=True): + self = object.__new__(cls) + self._drv = drv + self._root = root + self._parts = parts + if init: + self._init() + return self + + @classmethod + def _format_parsed_parts(cls, drv, root, parts): + if drv or root: + return drv + root + cls._flavour.join(parts[1:]) + else: + return cls._flavour.join(parts) + + def _init(self): + # Overridden in concrete Path + pass + + def _make_child(self, args): + drv, root, parts = self._parse_args(args) + drv, root, parts = self._flavour.join_parsed_parts( + self._drv, self._root, self._parts, drv, root, parts) + return self._from_parsed_parts(drv, root, parts) + + def __str__(self): + """Return the string representation of the path, suitable for + passing to system calls.""" + try: + return self._str + except AttributeError: + self._str = self._format_parsed_parts(self._drv, self._root, + self._parts) or '.' 
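+            # An empty set of parts formats as '.', the current directory.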
+ return self._str + + def __fspath__(self): + return str(self) + + def as_posix(self): + """Return the string representation of the path with forward (/) + slashes.""" + f = self._flavour + return str(self).replace(f.sep, '/') + + def __bytes__(self): + """Return the bytes representation of the path. This is only + recommended to use under Unix.""" + return os.fsencode(self) + + def __repr__(self): + return "{}({!r})".format(self.__class__.__name__, self.as_posix()) + + def as_uri(self): + """Return the path as a 'file' URI.""" + if not self.is_absolute(): + raise ValueError("relative path can't be expressed as a file URI") + return self._flavour.make_uri(self) + + @property + def _cparts(self): + # Cached casefolded parts, for hashing and comparison + try: + return self._cached_cparts + except AttributeError: + self._cached_cparts = self._flavour.casefold_parts(self._parts) + return self._cached_cparts + + def __eq__(self, other): + if not isinstance(other, PurePath): + return NotImplemented + return self._cparts == other._cparts and self._flavour is other._flavour + + def __hash__(self): + try: + return self._hash + except AttributeError: + self._hash = hash(tuple(self._cparts)) + return self._hash + + def __lt__(self, other): + if not isinstance(other, PurePath) or self._flavour is not other._flavour: + return NotImplemented + return self._cparts < other._cparts + + def __le__(self, other): + if not isinstance(other, PurePath) or self._flavour is not other._flavour: + return NotImplemented + return self._cparts <= other._cparts + + def __gt__(self, other): + if not isinstance(other, PurePath) or self._flavour is not other._flavour: + return NotImplemented + return self._cparts > other._cparts + + def __ge__(self, other): + if not isinstance(other, PurePath) or self._flavour is not other._flavour: + return NotImplemented + return self._cparts >= other._cparts + + def __class_getitem__(cls, type): + return cls + + drive = property(attrgetter('_drv'), + doc="""The drive prefix (letter or UNC path), if any.""") + + root = property(attrgetter('_root'), + doc="""The root of the path, if any.""") + + @property + def anchor(self): + """The concatenation of the drive and root, or ''.""" + anchor = self._drv + self._root + return anchor + + @property + def name(self): + """The final path component, if any.""" + parts = self._parts + if len(parts) == (1 if (self._drv or self._root) else 0): + return '' + return parts[-1] + + @property + def suffix(self): + """ + The final component's last suffix, if any. + + This includes the leading period. For example: '.txt' + """ + name = self.name + i = name.rfind('.') + if 0 < i < len(name) - 1: + return name[i:] + else: + return '' + + @property + def suffixes(self): + """ + A list of the final component's suffixes, if any. + + These include the leading periods. For example: ['.tar', '.gz'] + """ + name = self.name + if name.endswith('.'): + return [] + name = name.lstrip('.') + return ['.' 
+ suffix for suffix in name.split('.')[1:]] + + @property + def stem(self): + """The final path component, minus its last suffix.""" + name = self.name + i = name.rfind('.') + if 0 < i < len(name) - 1: + return name[:i] + else: + return name + + def with_name(self, name): + """Return a new path with the file name changed.""" + if not self.name: + raise ValueError("%r has an empty name" % (self,)) + drv, root, parts = self._flavour.parse_parts((name,)) + if (not name or name[-1] in [self._flavour.sep, self._flavour.altsep] + or drv or root or len(parts) != 1): + raise ValueError("Invalid name %r" % (name)) + return self._from_parsed_parts(self._drv, self._root, + self._parts[:-1] + [name]) + + def with_stem(self, stem): + """Return a new path with the stem changed.""" + return self.with_name(stem + self.suffix) + + def with_suffix(self, suffix): + """Return a new path with the file suffix changed. If the path + has no suffix, add given suffix. If the given suffix is an empty + string, remove the suffix from the path. + """ + f = self._flavour + if f.sep in suffix or f.altsep and f.altsep in suffix: + raise ValueError("Invalid suffix %r" % (suffix,)) + if suffix and not suffix.startswith('.') or suffix == '.': + raise ValueError("Invalid suffix %r" % (suffix)) + name = self.name + if not name: + raise ValueError("%r has an empty name" % (self,)) + old_suffix = self.suffix + if not old_suffix: + name = name + suffix + else: + name = name[:-len(old_suffix)] + suffix + return self._from_parsed_parts(self._drv, self._root, + self._parts[:-1] + [name]) + + def relative_to(self, *other): + """Return the relative path to another path identified by the passed + arguments. If the operation is not possible (because this is not + a subpath of the other path), raise ValueError. + """ + # For the purpose of this method, drive and root are considered + # separate parts, i.e.: + # Path('c:/').relative_to('c:') gives Path('/') + # Path('c:/').relative_to('/') raise ValueError + if not other: + raise TypeError("need at least one argument") + parts = self._parts + drv = self._drv + root = self._root + if root: + abs_parts = [drv, root] + parts[1:] + else: + abs_parts = parts + to_drv, to_root, to_parts = self._parse_args(other) + if to_root: + to_abs_parts = [to_drv, to_root] + to_parts[1:] + else: + to_abs_parts = to_parts + n = len(to_abs_parts) + cf = self._flavour.casefold_parts + if (root or drv) if n == 0 else cf(abs_parts[:n]) != cf(to_abs_parts): + formatted = self._format_parsed_parts(to_drv, to_root, to_parts) + raise ValueError("{!r} is not in the subpath of {!r}" + " OR one path is relative and the other is absolute." + .format(str(self), str(formatted))) + return self._from_parsed_parts('', root if n == 1 else '', + abs_parts[n:]) + + def is_relative_to(self, *other): + """Return True if the path is relative to another path or False. + """ + try: + self.relative_to(*other) + return True + except ValueError: + return False + + @property + def parts(self): + """An object providing sequence-like access to the + components in the filesystem path.""" + # We cache the tuple to avoid building a new one each time .parts + # is accessed. XXX is this necessary? 
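+        # e.g. PurePosixPath('/usr/bin/python3').parts
+        #      == ('/', 'usr', 'bin', 'python3')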
+ try: + return self._pparts + except AttributeError: + self._pparts = tuple(self._parts) + return self._pparts + + def joinpath(self, *args): + """Combine this path with one or several arguments, and return a + new path representing either a subpath (if all arguments are relative + paths) or a totally different path (if one of the arguments is + anchored). + """ + return self._make_child(args) + + def __truediv__(self, key): + try: + return self._make_child((key,)) + except TypeError: + return NotImplemented + + def __rtruediv__(self, key): + try: + return self._from_parts([key] + self._parts) + except TypeError: + return NotImplemented + + @property + def parent(self): + """The logical parent of the path.""" + drv = self._drv + root = self._root + parts = self._parts + if len(parts) == 1 and (drv or root): + return self + return self._from_parsed_parts(drv, root, parts[:-1]) + + @property + def parents(self): + """A sequence of this path's logical parents.""" + return _PathParents(self) + + def is_absolute(self): + """True if the path is absolute (has both a root and, if applicable, + a drive).""" + if not self._root: + return False + return not self._flavour.has_drv or bool(self._drv) + + def is_reserved(self): + """Return True if the path contains one of the special names reserved + by the system, if any.""" + return self._flavour.is_reserved(self._parts) + + def match(self, path_pattern): + """ + Return True if this path matches the given pattern. + """ + cf = self._flavour.casefold + path_pattern = cf(path_pattern) + drv, root, pat_parts = self._flavour.parse_parts((path_pattern,)) + if not pat_parts: + raise ValueError("empty pattern") + if drv and drv != cf(self._drv): + return False + if root and root != cf(self._root): + return False + parts = self._cparts + if drv or root: + if len(pat_parts) != len(parts): + return False + pat_parts = pat_parts[1:] + elif len(pat_parts) > len(parts): + return False + for part, pat in zip(reversed(parts), reversed(pat_parts)): + if not fnmatch.fnmatchcase(part, pat): + return False + return True + +# Can't subclass os.PathLike from PurePath and keep the constructor +# optimizations in PurePath._parse_args(). +try: + os.PathLike.register(PurePath) +except AttributeError: + pass + + +class PurePosixPath(PurePath): + """PurePath subclass for non-Windows systems. + + On a POSIX system, instantiating a PurePath should return this object. + However, you can also instantiate it directly on any system. + """ + _flavour = _posix_flavour + __slots__ = () + + +class PureWindowsPath(PurePath): + """PurePath subclass for Windows systems. + + On a Windows system, instantiating a PurePath should return this object. + However, you can also instantiate it directly on any system. + """ + _flavour = _windows_flavour + __slots__ = () + + +# Filesystem-accessing classes + + +class Path(PurePath): + """PurePath subclass that can make system calls. + + Path represents a filesystem path but unlike PurePath, also offers + methods to do system calls on path objects. Depending on your system, + instantiating a Path will return either a PosixPath or a WindowsPath + object. You can also instantiate a PosixPath or WindowsPath directly, + but cannot instantiate a WindowsPath on a POSIX system or vice versa. 
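+    Attempting to do so raises NotImplementedError.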
+ """ + __slots__ = ( + '_accessor', + ) + + def __new__(cls, *args, **kwargs): + if cls is Path: + cls = WindowsPath if os.name == 'nt' else PosixPath + self = cls._from_parts(args, init=False) + if not self._flavour.is_supported: + raise NotImplementedError("cannot instantiate %r on your system" + % (cls.__name__,)) + self._init() + return self + + def _init(self, + # Private non-constructor arguments + template=None, + ): + if template is not None: + self._accessor = template._accessor + else: + self._accessor = _normal_accessor + + def _make_child_relpath(self, part): + # This is an optimization used for dir walking. `part` must be + # a single part relative to this path. + parts = self._parts + [part] + return self._from_parsed_parts(self._drv, self._root, parts) + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + # https://bugs.python.org/issue39682 + # In previous versions of pathlib, this method marked this path as + # closed; subsequent attempts to perform I/O would raise an IOError. + # This functionality was never documented, and had the effect of + # making Path objects mutable, contrary to PEP 428. In Python 3.9 the + # _closed attribute was removed, and this method made a no-op. + # This method and __enter__()/__exit__() should be deprecated and + # removed in the future. + pass + + def _opener(self, name, flags, mode=0o666): + # A stub for the opener argument to built-in open() + return self._accessor.open(self, flags, mode) + + def _raw_open(self, flags, mode=0o777): + """ + Open the file pointed by this path and return a file descriptor, + as os.open() does. + """ + return self._accessor.open(self, flags, mode) + + # Public API + + @classmethod + def cwd(cls): + """Return a new path pointing to the current working directory + (as returned by os.getcwd()). + """ + return cls(os.getcwd()) + + @classmethod + def home(cls): + """Return a new path pointing to the user's home directory (as + returned by os.path.expanduser('~')). + """ + return cls(cls()._flavour.gethomedir(None)) + + def samefile(self, other_path): + """Return whether other_path is the same or not as this file + (as returned by os.path.samefile()). + """ + st = self.stat() + try: + other_st = other_path.stat() + except AttributeError: + other_st = self._accessor.stat(other_path) + return os.path.samestat(st, other_st) + + def iterdir(self): + """Iterate over the files in this directory. Does not yield any + result for the special paths '.' and '..'. + """ + for name in self._accessor.listdir(self): + if name in {'.', '..'}: + # Yielding a path object for these makes little sense + continue + yield self._make_child_relpath(name) + + def glob(self, pattern): + """Iterate over this subtree and yield all existing files (of any + kind, including directories) matching the given relative pattern. + """ + sys.audit("pathlib.Path.glob", self, pattern) + if not pattern: + raise ValueError("Unacceptable pattern: {!r}".format(pattern)) + drv, root, pattern_parts = self._flavour.parse_parts((pattern,)) + if drv or root: + raise NotImplementedError("Non-relative patterns are unsupported") + selector = _make_selector(tuple(pattern_parts), self._flavour) + for p in selector.select_from(self): + yield p + + def rglob(self, pattern): + """Recursively yield all existing files (of any kind, including + directories) matching the given relative pattern, anywhere in + this subtree. 
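+        Equivalent to calling glob() with '**/' prepended to the given
+        pattern.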
+ """ + sys.audit("pathlib.Path.rglob", self, pattern) + drv, root, pattern_parts = self._flavour.parse_parts((pattern,)) + if drv or root: + raise NotImplementedError("Non-relative patterns are unsupported") + selector = _make_selector(("**",) + tuple(pattern_parts), self._flavour) + for p in selector.select_from(self): + yield p + + def absolute(self): + """Return an absolute version of this path. This function works + even if the path doesn't point to anything. + + No normalization is done, i.e. all '.' and '..' will be kept along. + Use resolve() to get the canonical path to a file. + """ + # XXX untested yet! + if self.is_absolute(): + return self + # FIXME this must defer to the specific flavour (and, under Windows, + # use nt._getfullpathname()) + obj = self._from_parts([os.getcwd()] + self._parts, init=False) + obj._init(template=self) + return obj + + def resolve(self, strict=False): + """ + Make the path absolute, resolving all symlinks on the way and also + normalizing it (for example turning slashes into backslashes under + Windows). + """ + s = self._flavour.resolve(self, strict=strict) + if s is None: + # No symlink resolution => for consistency, raise an error if + # the path doesn't exist or is forbidden + self.stat() + s = str(self.absolute()) + # Now we have no symlinks in the path, it's safe to normalize it. + normed = self._flavour.pathmod.normpath(s) + obj = self._from_parts((normed,), init=False) + obj._init(template=self) + return obj + + def stat(self): + """ + Return the result of the stat() system call on this path, like + os.stat() does. + """ + return self._accessor.stat(self) + + def owner(self): + """ + Return the login name of the file owner. + """ + return self._accessor.owner(self) + + def group(self): + """ + Return the group name of the file gid. + """ + return self._accessor.group(self) + + def open(self, mode='r', buffering=-1, encoding=None, + errors=None, newline=None): + """ + Open the file pointed by this path and return a file object, as + the built-in open() function does. + """ + return io.open(self, mode, buffering, encoding, errors, newline, + opener=self._opener) + + def read_bytes(self): + """ + Open the file in bytes mode, read it, and close the file. + """ + with self.open(mode='rb') as f: + return f.read() + + def read_text(self, encoding=None, errors=None): + """ + Open the file in text mode, read it, and close the file. + """ + with self.open(mode='r', encoding=encoding, errors=errors) as f: + return f.read() + + def write_bytes(self, data): + """ + Open the file in bytes mode, write to it, and close the file. + """ + # type-check for the buffer interface before truncating the file + view = memoryview(data) + with self.open(mode='wb') as f: + return f.write(view) + + def write_text(self, data, encoding=None, errors=None): + """ + Open the file in text mode, write to it, and close the file. + """ + if not isinstance(data, str): + raise TypeError('data must be str, not %s' % + data.__class__.__name__) + with self.open(mode='w', encoding=encoding, errors=errors) as f: + return f.write(data) + + def readlink(self): + """ + Return the path to which the symbolic link points. + """ + path = self._accessor.readlink(self) + obj = self._from_parts((path,), init=False) + obj._init(template=self) + return obj + + def touch(self, mode=0o666, exist_ok=True): + """ + Create this file with the given access mode, if it doesn't exist. 
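+        If it already exists and exist_ok is true, only the modification
+        time is bumped; with exist_ok false, FileExistsError is raised.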
+ """ + if exist_ok: + # First try to bump modification time + # Implementation note: GNU touch uses the UTIME_NOW option of + # the utimensat() / futimens() functions. + try: + self._accessor.utime(self, None) + except OSError: + # Avoid exception chaining + pass + else: + return + flags = os.O_CREAT | os.O_WRONLY + if not exist_ok: + flags |= os.O_EXCL + fd = self._raw_open(flags, mode) + os.close(fd) + + def mkdir(self, mode=0o777, parents=False, exist_ok=False): + """ + Create a new directory at this given path. + """ + try: + self._accessor.mkdir(self, mode) + except FileNotFoundError: + if not parents or self.parent == self: + raise + self.parent.mkdir(parents=True, exist_ok=True) + self.mkdir(mode, parents=False, exist_ok=exist_ok) + except OSError: + # Cannot rely on checking for EEXIST, since the operating system + # could give priority to other errors like EACCES or EROFS + if not exist_ok or not self.is_dir(): + raise + + def chmod(self, mode): + """ + Change the permissions of the path, like os.chmod(). + """ + self._accessor.chmod(self, mode) + + def lchmod(self, mode): + """ + Like chmod(), except if the path points to a symlink, the symlink's + permissions are changed, rather than its target's. + """ + self._accessor.lchmod(self, mode) + + def unlink(self, missing_ok=False): + """ + Remove this file or link. + If the path is a directory, use rmdir() instead. + """ + try: + self._accessor.unlink(self) + except FileNotFoundError: + if not missing_ok: + raise + + def rmdir(self): + """ + Remove this directory. The directory must be empty. + """ + self._accessor.rmdir(self) + + def lstat(self): + """ + Like stat(), except if the path points to a symlink, the symlink's + status information is returned, rather than its target's. + """ + return self._accessor.lstat(self) + + def link_to(self, target): + """ + Create a hard link pointing to a path named target. + """ + self._accessor.link_to(self, target) + + def rename(self, target): + """ + Rename this path to the target path. + + The target path may be absolute or relative. Relative paths are + interpreted relative to the current working directory, *not* the + directory of the Path object. + + Returns the new Path instance pointing to the target path. + """ + self._accessor.rename(self, target) + return self.__class__(target) + + def replace(self, target): + """ + Rename this path to the target path, overwriting if that path exists. + + The target path may be absolute or relative. Relative paths are + interpreted relative to the current working directory, *not* the + directory of the Path object. + + Returns the new Path instance pointing to the target path. + """ + self._accessor.replace(self, target) + return self.__class__(target) + + def symlink_to(self, target, target_is_directory=False): + """ + Make this path a symlink pointing to the given path. + Note the order of arguments (self, target) is the reverse of os.symlink's. + """ + self._accessor.symlink(target, self, target_is_directory) + + # Convenience functions for querying the stat results + + def exists(self): + """ + Whether this path exists. + """ + try: + self.stat() + except OSError as e: + if not _ignore_error(e): + raise + return False + except ValueError: + # Non-encodable path + return False + return True + + def is_dir(self): + """ + Whether this path is a directory. 
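+        Returns False, rather than raising, if the path doesn't exist
+        or is a broken symlink.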
+ """ + try: + return S_ISDIR(self.stat().st_mode) + except OSError as e: + if not _ignore_error(e): + raise + # Path doesn't exist or is a broken symlink + # (see https://bitbucket.org/pitrou/pathlib/issue/12/) + return False + except ValueError: + # Non-encodable path + return False + + def is_file(self): + """ + Whether this path is a regular file (also True for symlinks pointing + to regular files). + """ + try: + return S_ISREG(self.stat().st_mode) + except OSError as e: + if not _ignore_error(e): + raise + # Path doesn't exist or is a broken symlink + # (see https://bitbucket.org/pitrou/pathlib/issue/12/) + return False + except ValueError: + # Non-encodable path + return False + + def is_mount(self): + """ + Check if this path is a POSIX mount point + """ + # Need to exist and be a dir + if not self.exists() or not self.is_dir(): + return False + + try: + parent_dev = self.parent.stat().st_dev + except OSError: + return False + + dev = self.stat().st_dev + if dev != parent_dev: + return True + ino = self.stat().st_ino + parent_ino = self.parent.stat().st_ino + return ino == parent_ino + + def is_symlink(self): + """ + Whether this path is a symbolic link. + """ + try: + return S_ISLNK(self.lstat().st_mode) + except OSError as e: + if not _ignore_error(e): + raise + # Path doesn't exist + return False + except ValueError: + # Non-encodable path + return False + + def is_block_device(self): + """ + Whether this path is a block device. + """ + try: + return S_ISBLK(self.stat().st_mode) + except OSError as e: + if not _ignore_error(e): + raise + # Path doesn't exist or is a broken symlink + # (see https://bitbucket.org/pitrou/pathlib/issue/12/) + return False + except ValueError: + # Non-encodable path + return False + + def is_char_device(self): + """ + Whether this path is a character device. + """ + try: + return S_ISCHR(self.stat().st_mode) + except OSError as e: + if not _ignore_error(e): + raise + # Path doesn't exist or is a broken symlink + # (see https://bitbucket.org/pitrou/pathlib/issue/12/) + return False + except ValueError: + # Non-encodable path + return False + + def is_fifo(self): + """ + Whether this path is a FIFO. + """ + try: + return S_ISFIFO(self.stat().st_mode) + except OSError as e: + if not _ignore_error(e): + raise + # Path doesn't exist or is a broken symlink + # (see https://bitbucket.org/pitrou/pathlib/issue/12/) + return False + except ValueError: + # Non-encodable path + return False + + def is_socket(self): + """ + Whether this path is a socket. + """ + try: + return S_ISSOCK(self.stat().st_mode) + except OSError as e: + if not _ignore_error(e): + raise + # Path doesn't exist or is a broken symlink + # (see https://bitbucket.org/pitrou/pathlib/issue/12/) + return False + except ValueError: + # Non-encodable path + return False + + def expanduser(self): + """ Return a new path with expanded ~ and ~user constructs + (as returned by os.path.expanduser) + """ + if (not (self._drv or self._root) and + self._parts and self._parts[0][:1] == '~'): + homedir = self._flavour.gethomedir(self._parts[0][1:]) + return self._from_parts([homedir] + self._parts[1:]) + + return self + + +class PosixPath(Path, PurePosixPath): + """Path subclass for non-Windows systems. + + On a POSIX system, instantiating a Path should return this object. + """ + __slots__ = () + +class WindowsPath(Path, PureWindowsPath): + """Path subclass for Windows systems. + + On a Windows system, instantiating a Path should return this object. 
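+    It cannot be instantiated on a non-Windows system.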
+ """ + __slots__ = () + + def is_mount(self): + raise NotImplementedError("Path.is_mount() is unsupported on this system") From 83c5065d62398bf1cd32aa89dece5a380ada8215 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:33:51 +0200 Subject: [PATCH 45/65] Add random.random() --- random/metadata.txt | 2 +- random/random.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/random/metadata.txt b/random/metadata.txt index e664e0316c..06e7d08c4e 100644 --- a/random/metadata.txt +++ b/random/metadata.txt @@ -1,3 +1,3 @@ srctype = pycopy-lib type = module -version = 0.2.4 +version = 0.3.0 diff --git a/random/random.py b/random/random.py index 8219b168cd..1eca6121b8 100644 --- a/random/random.py +++ b/random/random.py @@ -24,6 +24,10 @@ def randint(start, stop): uniform = randint +def random(): + # single-precision float mantissa is 23 bits, add one to be safe + return getrandbits(24) / (1<<24) + def shuffle(seq): l = len(seq) if l < 2: @@ -39,6 +43,10 @@ def choice(seq): class Random: + @staticmethod + def random(): + return random() + @staticmethod def randrange(start, stop=None): return randrange(start, stop) From 75eecabd3b02a76f32a9bfcfbf3fe643644d45e1 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:34:34 +0200 Subject: [PATCH 46/65] don't use const() for non-integers micropython doesn't like that --- re-pcre/metadata.txt | 2 +- re-pcre/re.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/re-pcre/metadata.txt b/re-pcre/metadata.txt index 7e0ad21614..ca15c512b7 100644 --- a/re-pcre/metadata.txt +++ b/re-pcre/metadata.txt @@ -1,6 +1,6 @@ name = re srctype = pycopy-lib type = module -version = 0.9.7 +version = 0.9.8 author = Paul Sokolovsky depends = ffilib diff --git a/re-pcre/re.py b/re-pcre/re.py index 13ee12aa94..dc274cd1c5 100644 --- a/re-pcre/re.py +++ b/re-pcre/re.py @@ -67,10 +67,10 @@ PCRE_INFO_CAPTURECOUNT = 2 -_UNESCAPE_DICT = const({ +_UNESCAPE_DICT = { b"\\": b"\\", b"n": b"\n", b"r": b"\r", b"t": b"\t", b"v": b"\v", b"f": b"\f", b"a": b"\a", b"b": b"\b" -}) +} class error(Exception): From 884f80d6276c404d85388a62f94a456a47b7c360 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:35:29 +0200 Subject: [PATCH 47/65] sys: simply import usys --- sys/metadata.txt | 2 +- sys/sys.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sys/metadata.txt b/sys/metadata.txt index 976088c8a2..dc5f60a661 100644 --- a/sys/metadata.txt +++ b/sys/metadata.txt @@ -1,3 +1,3 @@ srctype = dummy type = module -version = 0.0.0 +version = 0.0.1 diff --git a/sys/sys.py b/sys/sys.py index e69de29bb2..f163572787 100644 --- a/sys/sys.py +++ b/sys/sys.py @@ -0,0 +1 @@ +from usys import * From eb55796a3b66ddc35c85a4323e1117bf4c6274ee Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:36:51 +0200 Subject: [PATCH 48/65] typing: "implement" a heap of subscriptable classes also TYPE_CHECKING is False by definition for this to actually work, micropython needs to implement __class_getitem__ cf. 
github.com/smurfix/micropython --- typing/metadata.txt | 2 +- typing/typing.py | 66 ++++++++++++++++++++++++--------------------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/typing/metadata.txt b/typing/metadata.txt index 93d12b4213..06e7d08c4e 100644 --- a/typing/metadata.txt +++ b/typing/metadata.txt @@ -1,3 +1,3 @@ srctype = pycopy-lib type = module -version = 0.2.2 +version = 0.3.0 diff --git a/typing/typing.py b/typing/typing.py index 811ead9621..8db6118cea 100644 --- a/typing/typing.py +++ b/typing/typing.py @@ -1,11 +1,16 @@ +TYPE_CHECKING = False + class _Subscriptable: def __getitem__(self, sub): - return None + return _SubSingleton + + def __class_getitem__(self, sub): + return _Subscriptable _SubSingleton = _Subscriptable() -def TypeVar(new_type, *types): +def TypeVar(new_type, *types, **kw): return None class Any: pass @@ -14,40 +19,41 @@ class NoReturn: pass class ClassVar: pass Union = _SubSingleton Optional = _SubSingleton -Generic = _SubSingleton +Generic = _Subscriptable NamedTuple = _SubSingleton class Hashable: pass class Awaitable: pass class Coroutine: pass class AsyncIterable: pass class AsyncIterator: pass -class Iterable: pass -class Iterator: pass -class Reversible: pass -class Sized: pass -class Container: pass -class Collection: pass -Callable = _SubSingleton -AbstractSet = _SubSingleton -MutableSet = _SubSingleton -Mapping = _SubSingleton -MutableMapping = _SubSingleton -Sequence = _SubSingleton -MutableSequence = _SubSingleton -class ByteString: pass -Tuple = _SubSingleton -List = _SubSingleton -class Deque: pass -Set = _SubSingleton -FrozenSet = _SubSingleton -class MappingView: pass -class KeysView: pass -class ItemsView: pass -class ValuesView: pass -class ContextManager: pass -class AsyncContextManager: pass -Dict = _SubSingleton -DefaultDict = _SubSingleton +Iterable = _Subscriptable +Iterator = _Subscriptable +Literal = _Subscriptable +Reversible = _Subscriptable +Sized = _Subscriptable +Container = _Subscriptable +Collection = _Subscriptable +Callable = _Subscriptable +AbstractSet = _Subscriptable +MutableSet = _Subscriptable +Mapping = _Subscriptable +MutableMapping = _Subscriptable +Sequence = _Subscriptable +MutableSequence = _Subscriptable +ByteString = _Subscriptable +Tuple = _Subscriptable +List = _Subscriptable +Deque = _Subscriptable +Set = _Subscriptable +FrozenSet = _Subscriptable +MappingView = _Subscriptable +KeysView = _Subscriptable +ItemsView = _Subscriptable +ValuesView = _Subscriptable +ContextManager = _Subscriptable +AsyncContextManager = _Subscriptable +Dict = _Subscriptable +DefaultDict = _Subscriptable class Counter: pass class ChainMap: pass class Generator: pass From 101f73561ef9f0c54bc61e63eac922f0aa843b6f Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 10 May 2021 17:39:43 +0200 Subject: [PATCH 49/65] added dummies for Weak*Dictionary and WeakSet useful for quick&dirty getting-code-to-work unfortunately micropython cannot subclass dict, otherwise we'd emit a warning when these are used --- weakref/metadata.txt | 2 +- weakref/weakref.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/weakref/metadata.txt b/weakref/metadata.txt index fda992a9c0..e290347afe 100644 --- a/weakref/metadata.txt +++ b/weakref/metadata.txt @@ -1,3 +1,3 @@ srctype=dummy type=module -version = 0.0.2 +version = 0.0.3 diff --git a/weakref/weakref.py b/weakref/weakref.py index 76aabfa316..411fb56391 100644 --- a/weakref/weakref.py +++ b/weakref/weakref.py @@ -5,3 +5,7 @@ def proxy(obj, cb=None): return obj 
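+# NB: plain dict/set substitutes -- references are strong, so entries
+# are never dropped automatically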
+ +WeakKeyDictionary = dict +WeakValueDictionary = dict +WeakSet = set From b7e8b9951e01980a5d9177feabbadbbf8d381d83 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Tue, 18 May 2021 09:09:05 +0200 Subject: [PATCH 50/65] datetime: Fix TZ object initialization --- datetime/datetime.py | 2 +- datetime/metadata.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datetime/datetime.py b/datetime/datetime.py index 03b4e4cdd1..cfbfdb611a 100644 --- a/datetime/datetime.py +++ b/datetime/datetime.py @@ -1833,7 +1833,7 @@ def __new__(cls, offset, name=_Omitted): @classmethod def _create(cls, offset, name=None): - self = tzinfo.__new__(cls) + self = object.__new__(cls) self._offset = offset self._name = name return self diff --git a/datetime/metadata.txt b/datetime/metadata.txt index 950962aefb..cb5ffa3274 100644 --- a/datetime/metadata.txt +++ b/datetime/metadata.txt @@ -1,3 +1,3 @@ srctype = cpython type = module -version = 3.3.3-1 +version = 3.3.3-2 From 8f46fa350189a214766eeef38d2f01647931bda2 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Tue, 18 May 2021 09:09:51 +0200 Subject: [PATCH 51/65] os: workaround when `bytestring_at` is unavailable --- os/metadata.txt | 2 +- os/os/__init__.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/os/metadata.txt b/os/metadata.txt index 575cd1937a..c710c67244 100644 --- a/os/metadata.txt +++ b/os/metadata.txt @@ -1,5 +1,5 @@ srctype = pycopy-lib type = package -version = 1.1.2 +version = 1.1.3 author = Paul Sokolovsky depends = errno, stat, os.path, uos2 diff --git a/os/os/__init__.py b/os/os/__init__.py index 3f972c457a..2ec104c699 100644 --- a/os/os/__init__.py +++ b/os/os/__init__.py @@ -265,11 +265,25 @@ class _Environ(object): def __init__(self): self._data = dict() env = uctypes.struct(_environ_ptr.get(), _ENV_STRUCT) + try: + uctypes.bytestring_at + except AttributeError: + def getter(addr): + n = 0 + while True: + if addr[n] == 0: + break + n += 1 + return uctypes.bytearray_at(int(addr),n).decode() + else: + def getter(addr): + return uctypes.bytestring_at(int(addr)).decode() + + for i in range(4096): if int(env.arr[i]) == 0: break - # requires micropython change f20a730 - s = uctypes.bytestring_at(int(env.arr[i])).decode() + s = getter(int(env.arr[i])).decode() k, v = s.split("=", 1) self._data[k] = v self.__getitem__ = self._data.__getitem__ From 55d5a172a2ef41f02c294d84646c893241ab138c Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Tue, 18 May 2021 09:12:27 +0200 Subject: [PATCH 52/65] Rename umqtt.simple to just umqtt --- umqtt.robust/umqtt/robust.py | 4 ++-- {umqtt.simple => umqtt}/README.rst | 0 {umqtt.simple => umqtt}/example_pub.py | 0 {umqtt.simple => umqtt}/example_pub_button.py | 0 {umqtt.simple => umqtt}/example_sub.py | 0 {umqtt.simple => umqtt}/example_sub_led.py | 0 {umqtt.simple => umqtt}/metadata.txt | 2 +- umqtt.simple/umqtt/simple.py => umqtt/umqtt/__init__.py | 0 8 files changed, 3 insertions(+), 3 deletions(-) rename {umqtt.simple => umqtt}/README.rst (100%) rename {umqtt.simple => umqtt}/example_pub.py (100%) rename {umqtt.simple => umqtt}/example_pub_button.py (100%) rename {umqtt.simple => umqtt}/example_sub.py (100%) rename {umqtt.simple => umqtt}/example_sub_led.py (100%) rename {umqtt.simple => umqtt}/metadata.txt (88%) rename umqtt.simple/umqtt/simple.py => umqtt/umqtt/__init__.py (100%) diff --git a/umqtt.robust/umqtt/robust.py b/umqtt.robust/umqtt/robust.py index 7ee40e0205..40f4e4b75a 100644 --- a/umqtt.robust/umqtt/robust.py +++ 
b/umqtt.robust/umqtt/robust.py
@@ -1,7 +1,7 @@
 import utime
-from . import simple
+from . import MQTTClient as BaseMQTTClient
 
-class MQTTClient(simple.MQTTClient):
+class MQTTClient(BaseMQTTClient):
     DELAY = 2
     DEBUG = False
 
diff --git a/umqtt.simple/README.rst b/umqtt/README.rst
similarity index 100%
rename from umqtt.simple/README.rst
rename to umqtt/README.rst
diff --git a/umqtt.simple/example_pub.py b/umqtt/example_pub.py
similarity index 100%
rename from umqtt.simple/example_pub.py
rename to umqtt/example_pub.py
diff --git a/umqtt.simple/example_pub_button.py b/umqtt/example_pub_button.py
similarity index 100%
rename from umqtt.simple/example_pub_button.py
rename to umqtt/example_pub_button.py
diff --git a/umqtt.simple/example_sub.py b/umqtt/example_sub.py
similarity index 100%
rename from umqtt.simple/example_sub.py
rename to umqtt/example_sub.py
diff --git a/umqtt.simple/example_sub_led.py b/umqtt/example_sub_led.py
similarity index 100%
rename from umqtt.simple/example_sub_led.py
rename to umqtt/example_sub_led.py
diff --git a/umqtt.simple/metadata.txt b/umqtt/metadata.txt
similarity index 88%
rename from umqtt.simple/metadata.txt
rename to umqtt/metadata.txt
index 00128649c2..580a2b91b8 100644
--- a/umqtt.simple/metadata.txt
+++ b/umqtt/metadata.txt
@@ -1,6 +1,6 @@
 srctype = pycopy-lib
 type = package
-version = 1.3.6
+version = 1.3.7
 author = Paul Sokolovsky
 desc = Lightweight MQTT client for Pycopy
 long_desc = README.rst
diff --git a/umqtt.simple/umqtt/simple.py b/umqtt/umqtt/__init__.py
similarity index 100%
rename from umqtt.simple/umqtt/simple.py
rename to umqtt/umqtt/__init__.py

From e78e4542285bda98fd823834d5ec22e9a5486f98 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 23 May 2021 12:58:25 +0200
Subject: [PATCH 53/65] umqtt.robust: depend on the renamed umqtt package

umqtt.robust builds on umqtt (formerly umqtt.simple), so record that
in its metadata.
---
 umqtt.robust/metadata.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/umqtt.robust/metadata.txt b/umqtt.robust/metadata.txt
index 859f35e268..bbabd74b27 100644
--- a/umqtt.robust/metadata.txt
+++ b/umqtt.robust/metadata.txt
@@ -4,3 +4,4 @@ version = 1.0.2
 author = Paul Sokolovsky
 desc = Lightweight MQTT client for Pycopy ("robust" version).
 long_desc = README.rst
+depends = umqtt

From 6d84ba2ecfe39644b2a4feef787c86b2b91a8f69 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Sun, 23 May 2021 12:58:33 +0200
Subject: [PATCH 54/65] os: fix environ scanning

Pass the raw address to getter() and drop the redundant second decode.
---
 os/os/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/os/os/__init__.py b/os/os/__init__.py
index 2ec104c699..c352dda550 100644
--- a/os/os/__init__.py
+++ b/os/os/__init__.py
@@ -271,7 +271,7 @@ def __init__(self):
             def getter(addr):
                 n = 0
                 while True:
-                    if addr[n] == 0:
+                    if uctypes.bytes_at(int(addr)+n,1)[0] == 0:
                         break
                     n += 1
                 return uctypes.bytearray_at(int(addr),n).decode()
@@ -283,7 +283,7 @@ def getter(addr):
         for i in range(4096):
             if int(env.arr[i]) == 0:
                 break
-            s = getter(int(env.arr[i])).decode()
+            s = getter(int(env.arr[i]))
             k, v = s.split("=", 1)
             self._data[k] = v
         self.__getitem__ = self._data.__getitem__

From 5b613a38615fe375cc102153c9f3e5e739692ca4 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 27 Jun 2022 13:28:09 +0200
Subject: [PATCH 55/65] contextlib: add missing functools.wraps import

---
 contextlib/contextlib.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/contextlib/contextlib.py b/contextlib/contextlib.py
index e33d20fc31..57d00aa724 100644
--- a/contextlib/contextlib.py
+++ b/contextlib/contextlib.py
@@ -10,6 +10,7 @@
 import sys
 from collections import deque
 from ucontextlib import *
+from functools import wraps
 
 
 class AbstractAsyncContextManager(object):

From d47e13341777b61e9ce478b6bfc5bb205b86a960 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 27 Jun 2022 13:28:59 +0200
Subject: [PATCH 56/65] time: add missing ticks_* imports from utime

---
 time/time.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/time/time.py b/time/time.py
index c45968bd95..568be39d91 100644
--- a/time/time.py
+++ b/time/time.py
@@ -8,6 +8,8 @@
 
 libc = ffilib.libc()
 
+from utime import ticks_add, ticks_diff, ticks_ms, ticks_us, ticks_cpu
+
 # struct tm *gmtime(const time_t *timep);
 # struct tm *localtime(const time_t *timep);
 # size_t strftime(char *s, size_t max, const char *format,

From 615050e344c2a56dc204c998c2cd277b779e1ebe Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 27 Jun 2022 13:29:18 +0200
Subject: [PATCH 57/65] datetime: strip docstrings, comments and asserts to
 save space

---
 datetime/datetime.py | 521 ------------------------------------------
 1 file changed, 521 deletions(-)

diff --git a/datetime/datetime.py b/datetime/datetime.py
index cfbfdb611a..41cc42560a 100644
--- a/datetime/datetime.py
+++ b/datetime/datetime.py
@@ -1,9 +1,3 @@
-"""Concrete date/time and related types.
-
-See http://www.iana.org/time-zones/repository/tz-link.html for
-time zone and DST data sources.
-"""
-
 import time as _time
 import math as _math
 
@@ -14,15 +8,6 @@ def _cmp(x, y):
 MAXYEAR = 9999
 _MAXORDINAL = 3652059 # date.max.toordinal()
 
-# Utility functions, adapted from Python's Demo/classes/Dates.py, which
-# also assumes the current Gregorian calendar indefinitely extended in
-# both directions.  Difference:  Dates.py calls January 1 of year 0 day
-# number 1.  The code here calls January 1 of year 1 day number 1.  This is
-# to match the definition of the "proleptic Gregorian" calendar in Dershowitz
-# and Reingold's "Calendrical Calculations", where it's the base calendar
-# for all computations.  See the book for algorithms for converting between
-# proleptic Gregorian ordinals and many other calendar systems.
- _DAYS_IN_MONTH = [None, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] _DAYS_BEFORE_MONTH = [None] @@ -43,21 +28,17 @@ def _days_before_year(year): def _days_in_month(year, month): "year, month -> number of days in that month in that year." - assert 1 <= month <= 12, month if month == 2 and _is_leap(year): return 29 return _DAYS_IN_MONTH[month] def _days_before_month(year, month): "year, month -> number of days in year preceding first day of month." - assert 1 <= month <= 12, 'month must be in 1..12' return _DAYS_BEFORE_MONTH[month] + (month > 2 and _is_leap(year)) def _ymd2ord(year, month, day): "year, month, day -> ordinal, considering 01-Jan-0001 as day 1." - assert 1 <= month <= 12, 'month must be in 1..12' dim = _days_in_month(year, month) - assert 1 <= day <= dim, ('day must be in 1..%d' % dim) return (_days_before_year(year) + _days_before_month(year, month) + day) @@ -66,81 +47,31 @@ def _ymd2ord(year, month, day): _DI100Y = _days_before_year(101) # " " " " 100 " _DI4Y = _days_before_year(5) # " " " " 4 " -# A 4-year cycle has an extra leap day over what we'd get from pasting -# together 4 single years. -assert _DI4Y == 4 * 365 + 1 - -# Similarly, a 400-year cycle has an extra leap day over what we'd get from -# pasting together 4 100-year cycles. -assert _DI400Y == 4 * _DI100Y + 1 - -# OTOH, a 100-year cycle has one fewer leap day than we'd get from -# pasting together 25 4-year cycles. -assert _DI100Y == 25 * _DI4Y - 1 - def _ord2ymd(n): - "ordinal -> (year, month, day), considering 01-Jan-0001 as day 1." - - # n is a 1-based index, starting at 1-Jan-1. The pattern of leap years - # repeats exactly every 400 years. The basic strategy is to find the - # closest 400-year boundary at or before n, then work with the offset - # from that boundary to n. Life is much clearer if we subtract 1 from - # n first -- then the values of n at 400-year boundaries are exactly - # those divisible by _DI400Y: - # - # D M Y n n-1 - # -- --- ---- ---------- ---------------- - # 31 Dec -400 -_DI400Y -_DI400Y -1 - # 1 Jan -399 -_DI400Y +1 -_DI400Y 400-year boundary - # ... - # 30 Dec 000 -1 -2 - # 31 Dec 000 0 -1 - # 1 Jan 001 1 0 400-year boundary - # 2 Jan 001 2 1 - # 3 Jan 001 3 2 - # ... - # 31 Dec 400 _DI400Y _DI400Y -1 - # 1 Jan 401 _DI400Y +1 _DI400Y 400-year boundary n -= 1 n400, n = divmod(n, _DI400Y) year = n400 * 400 + 1 # ..., -399, 1, 401, ... - # Now n is the (non-negative) offset, in days, from January 1 of year, to - # the desired date. Now compute how many 100-year cycles precede n. - # Note that it's possible for n100 to equal 4! In that case 4 full - # 100-year cycles precede the desired day, which implies the desired - # day is December 31 at the end of a 400-year cycle. n100, n = divmod(n, _DI100Y) - # Now compute how many 4-year cycles precede it. n4, n = divmod(n, _DI4Y) - # And now how many single years. Again n1 can be 4, and again meaning - # that the desired day is December 31 at the end of the 4-year cycle. n1, n = divmod(n, 365) year += n100 * 100 + n4 * 4 + n1 if n1 == 4 or n100 == 4: - assert n == 0 return year-1, 12, 31 - # Now the year is correct, and n is the offset from January 1. We find - # the month via an estimate that's either exact or one too large. 
leapyear = n1 == 3 and (n4 != 24 or n100 == 3) - assert leapyear == _is_leap(year) month = (n + 50) >> 5 preceding = _DAYS_BEFORE_MONTH[month] + (month > 2 and leapyear) if preceding > n: # estimate is too large month -= 1 preceding -= _DAYS_IN_MONTH[month] + (month == 2 and leapyear) n -= preceding - assert 0 <= n < _days_in_month(year, month) - # Now the year and month are correct, and n is the offset from the - # start of that month: we're done! return year, month, n+1 -# Month and day names. For localized versions, see the calendar module. _MONTHNAMES = [None, "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"] _DAYNAMES = [None, "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] @@ -152,20 +83,16 @@ def _build_struct_time(y, m, d, hh, mm, ss, dstflag): return _time.struct_time((y, m, d, hh, mm, ss, wday, dnum, dstflag)) def _format_time(hh, mm, ss, us): - # Skip trailing microseconds when us==0. result = "%02d:%02d:%02d" % (hh, mm, ss) if us: result += ".%06d" % us return result -# Correctly substitute for %z and %Z escapes in strftime formats. def _wrap_strftime(object, format, timetuple): - # Don't call utcoffset() or tzname() unless actually needed. freplace = None # the string to use for %f zreplace = None # the string to use for %z Zreplace = None # the string to use for %Z - # Scan format for %z and %Z escapes, replacing as needed. newformat = [] push = newformat.append i, n = 0, len(format) @@ -192,10 +119,8 @@ def _wrap_strftime(object, format, timetuple): offset = -offset sign = '-' h, m = divmod(offset, timedelta(hours=1)) - assert not m % timedelta(minutes=1), "whole minute" m //= timedelta(minutes=1) zreplace = '%c%02d%02d' % (sign, h, m) - assert '%' not in zreplace newformat.append(zreplace) elif ch == 'Z': if Zreplace is None: @@ -227,14 +152,7 @@ def _check_tzname(name): raise TypeError("tzinfo.tzname() must return None or string, " "not '%s'" % type(name)) -# name is the offset-producing method, "utcoffset" or "dst". -# offset is what it returned. -# If offset isn't None or timedelta, raises TypeError. -# If offset is None, returns None. -# Else offset is checked for being in range, and a whole # of minutes. -# If it is, its integer value is returned. Else ValueError is raised. def _check_utc_offset(name, offset): - assert name in ("utcoffset", "dst") if offset is None: return if not isinstance(offset, timedelta): @@ -280,126 +198,59 @@ def _cmperror(x, y): type(x).__name__, type(y).__name__)) class timedelta: - """Represent the difference between two datetime objects. - - Supported operators: - - - add, subtract timedelta - - unary plus, minus, abs - - compare to timedelta - - multiply, divide by int - - In addition, datetime supports subtraction of two datetime objects - returning a timedelta, and addition or subtraction of a datetime - and a timedelta giving a datetime. - - Representation: (days, seconds, microseconds). Why? Because I - felt like it. - """ - __slots__ = '_days', '_seconds', '_microseconds' - def __new__(cls, days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0): - # Doing this efficiently and accurately in C is going to be difficult - # and error-prone, due to ubiquitous overflow possibilities, and that - # C double doesn't have enough bits of precision to represent - # microseconds over 10K years faithfully. 
The code here tries to make - # explicit where go-fast assumptions can be relied on, in order to - # guide the C implementation; it's way more convoluted than speed- - # ignoring auto-overflow-to-long idiomatic Python could be. - - # XXX Check that all inputs are ints or floats. - - # Final values, all integer. - # s and us fit in 32-bit signed ints; d isn't bounded. d = s = us = 0 - # Normalize everything to days, seconds, microseconds. days += weeks*7 seconds += minutes*60 + hours*3600 microseconds += milliseconds*1000 - # Get rid of all fractions, and normalize s and us. - # Take a deep breath . if isinstance(days, float): dayfrac, days = _math.modf(days) daysecondsfrac, daysecondswhole = _math.modf(dayfrac * (24.*3600.)) - assert daysecondswhole == int(daysecondswhole) # can't overflow s = int(daysecondswhole) - assert days == int(days) d = int(days) else: daysecondsfrac = 0.0 d = days - assert isinstance(daysecondsfrac, float) - assert abs(daysecondsfrac) <= 1.0 - assert isinstance(d, int) - assert abs(s) <= 24 * 3600 - # days isn't referenced again before redefinition if isinstance(seconds, float): secondsfrac, seconds = _math.modf(seconds) - assert seconds == int(seconds) seconds = int(seconds) secondsfrac += daysecondsfrac - assert abs(secondsfrac) <= 2.0 else: secondsfrac = daysecondsfrac - # daysecondsfrac isn't referenced again - assert isinstance(secondsfrac, float) - assert abs(secondsfrac) <= 2.0 - assert isinstance(seconds, int) days, seconds = divmod(seconds, 24*3600) d += days s += int(seconds) # can't overflow - assert isinstance(s, int) - assert abs(s) <= 2 * 24 * 3600 - # seconds isn't referenced again before redefinition usdouble = secondsfrac * 1e6 - assert abs(usdouble) < 2.1e6 # exact value not critical - # secondsfrac isn't referenced again if isinstance(microseconds, float): microseconds += usdouble microseconds = round(microseconds, 0) seconds, microseconds = divmod(microseconds, 1e6) - assert microseconds == int(microseconds) - assert seconds == int(seconds) days, seconds = divmod(seconds, 24.*3600.) - assert days == int(days) - assert seconds == int(seconds) d += int(days) s += int(seconds) # can't overflow - assert isinstance(s, int) - assert abs(s) <= 3 * 24 * 3600 else: seconds, microseconds = divmod(microseconds, 1000000) days, seconds = divmod(seconds, 24*3600) d += days s += int(seconds) # can't overflow - assert isinstance(s, int) - assert abs(s) <= 3 * 24 * 3600 microseconds = float(microseconds) microseconds += usdouble microseconds = round(microseconds, 0) - assert abs(s) <= 3 * 24 * 3600 - assert abs(microseconds) < 3.1e6 # Just a little bit of carrying possible for microseconds and seconds. 
- assert isinstance(microseconds, float) - assert int(microseconds) == microseconds us = int(microseconds) seconds, us = divmod(us, 1000000) s += seconds # cant't overflow - assert isinstance(s, int) days, s = divmod(s, 24*3600) d += days - assert isinstance(d, int) - assert isinstance(s, int) and 0 <= s < 24*3600 - assert isinstance(us, int) and 0 <= us < 1000000 self = object.__new__(cls) @@ -591,7 +442,6 @@ def __gt__(self, other): _cmperror(self, other) def _cmp(self, other): - assert isinstance(other, timedelta) return _cmp(self._getstate(), other._getstate()) def __hash__(self): @@ -817,7 +667,6 @@ def __gt__(self, other): return NotImplemented def _cmp(self, other): - assert isinstance(other, date) y, m, d = self._year, self._month, self._day y2, m2, d2 = other._year, other._month, other._day return _cmp((y, m, d), (y2, m2, d2)) @@ -957,59 +806,10 @@ def fromutc(self, dt): "results; cannot convert") return dt + dtdst - # Pickle support. - - def __reduce__(self): - getinitargs = getattr(self, "__getinitargs__", None) - if getinitargs: - args = getinitargs() - else: - args = () - getstate = getattr(self, "__getstate__", None) - if getstate: - state = getstate() - else: - state = getattr(self, "__dict__", None) or None - if state is None: - return (self.__class__, args) - else: - return (self.__class__, args, state) - _tzinfo_class = tzinfo class time: - """Time with time zone. - - Constructors: - - __new__() - - Operators: - - __repr__, __str__ - __cmp__, __hash__ - - Methods: - - strftime() - isoformat() - utcoffset() - tzname() - dst() - - Properties (readonly): - hour, minute, second, microsecond, tzinfo - """ - def __new__(cls, hour=0, minute=0, second=0, microsecond=0, tzinfo=None): - """Constructor. - - Arguments: - - hour, minute (required) - second, microsecond (default to zero) - tzinfo (default to None) - """ self = object.__new__(cls) if isinstance(hour, bytes) and len(hour) == 6: # Pickle support @@ -1027,27 +827,22 @@ def __new__(cls, hour=0, minute=0, second=0, microsecond=0, tzinfo=None): # Read-only field accessors @property def hour(self): - """hour (0-23)""" return self._hour @property def minute(self): - """minute (0-59)""" return self._minute @property def second(self): - """second (0-59)""" return self._second @property def microsecond(self): - """microsecond (0-999999)""" return self._microsecond @property def tzinfo(self): - """timezone info object""" return self._tzinfo # Standard conversions, __hash__ (and helpers) @@ -1091,7 +886,6 @@ def __gt__(self, other): _cmperror(self, other) def _cmp(self, other, allow_mixed=False): - assert isinstance(other, time) mytz = self._tzinfo ottz = other._tzinfo myoff = otoff = None @@ -1119,13 +913,11 @@ def _cmp(self, other, allow_mixed=False): (othhmm, other._second, other._microsecond)) def __hash__(self): - """Hash.""" tzoff = self.utcoffset() if not tzoff: # zero or None return hash(self._getstate()[0]) h, m = divmod(timedelta(hours=self.hour, minutes=self.minute) - tzoff, timedelta(hours=1)) - assert not m % timedelta(minutes=1), "whole minute" m //= timedelta(minutes=1) if 0 <= h < 24: return hash(time(h, m, self.second, self.microsecond)) @@ -1134,7 +926,6 @@ def __hash__(self): # Conversion to string def _tzstr(self, sep=":"): - """Return formatted timezone offset (+xx:xx) or None.""" off = self.utcoffset() if off is not None: if off.days < 0: @@ -1143,14 +934,11 @@ def _tzstr(self, sep=":"): else: sign = "+" hh, mm = divmod(off, timedelta(hours=1)) - assert not mm % timedelta(minutes=1), "whole minute" mm //= 
timedelta(minutes=1) - assert 0 <= hh < 24 off = "%s%02d%s%02d" % (sign, hh, sep, mm) return off def __repr__(self): - """Convert to formal string, for repr().""" if self._microsecond != 0: s = ", %d, %d" % (self._second, self._microsecond) elif self._second != 0: @@ -1160,16 +948,10 @@ def __repr__(self): s= "%s(%d, %d%s)" % ('datetime.' + self.__class__.__name__, self._hour, self._minute, s) if self._tzinfo is not None: - assert s[-1:] == ")" s = s[:-1] + ", tzinfo=%r" % self._tzinfo + ")" return s def isoformat(self): - """Return the time formatted according to ISO. - - This is 'HH:MM:SS.mmmmmm+zz:zz', or 'HH:MM:SS+zz:zz' if - self.microsecond == 0. - """ s = _format_time(self._hour, self._minute, self._second, self._microsecond) tz = self._tzstr() @@ -1180,11 +962,6 @@ def isoformat(self): __str__ = isoformat def strftime(self, fmt): - """Format using strftime(). The date part of the timestamp passed - to underlying strftime should not be used. - """ - # The year must be >= 1000 else Python's strftime implementation - # can raise a bogus exception. timetuple = (1900, 1, 1, self._hour, self._minute, self._second, 0, 1, -1) @@ -1198,8 +975,6 @@ def __format__(self, fmt): # Timezone functions def utcoffset(self): - """Return the timezone offset in minutes east of UTC (negative west of - UTC).""" if self._tzinfo is None: return None offset = self._tzinfo.utcoffset(None) @@ -1207,12 +982,6 @@ def utcoffset(self): return offset def tzname(self): - """Return the timezone name. - - Note that the name is 100% informational -- there's no requirement that - it mean anything in particular. For example, "GMT", "UTC", "-500", - "-5:00", "EDT", "US/Eastern", "America/New York" are all valid replies. - """ if self._tzinfo is None: return None name = self._tzinfo.tzname(None) @@ -1220,14 +989,6 @@ def tzname(self): return name def dst(self): - """Return 0 if DST is not in effect, or the DST offset (in minutes - eastward) if DST is in effect. - - This is purely informational; the DST offset has already been added to - the UTC offset returned by utcoffset() if applicable, so there's no - need to consult dst() unless you're interested in displaying the DST - info. - """ if self._tzinfo is None: return None offset = self._tzinfo.dst(None) @@ -1236,7 +997,6 @@ def dst(self): def replace(self, hour=None, minute=None, second=None, microsecond=None, tzinfo=True): - """Return a new time with new values for the specified fields.""" if hour is None: hour = self.hour if minute is None: @@ -1290,15 +1050,6 @@ def __reduce__(self): time.resolution = timedelta(microseconds=1) class datetime(date): - """datetime(year, month, day[, hour[, minute[, second[, microsecond[,tzinfo]]]]]) - - The year, month and day arguments are required. tzinfo may be None, or an - instance of a tzinfo subclass. The remaining arguments may be ints. 
- """ - - __slots__ = date.__slots__ + ( - '_hour', '_minute', '_second', - '_microsecond', '_tzinfo') def __new__(cls, year, month=None, day=None, hour=0, minute=0, second=0, microsecond=0, tzinfo=None): if isinstance(year, bytes) and len(year) == 10: @@ -1319,36 +1070,26 @@ def __new__(cls, year, month=None, day=None, hour=0, minute=0, second=0, # Read-only field accessors @property def hour(self): - """hour (0-23)""" return self._hour @property def minute(self): - """minute (0-59)""" return self._minute @property def second(self): - """second (0-59)""" return self._second @property def microsecond(self): - """microsecond (0-999999)""" return self._microsecond @property def tzinfo(self): - """timezone info object""" return self._tzinfo @classmethod def fromtimestamp(cls, t, tz=None): - """Construct a datetime from a POSIX timestamp (like time.time()). - - A timezone info object may be passed in as well. - """ - _check_tzinfo_arg(tz) converter = _time.localtime if tz is None else _time.gmtime @@ -1356,10 +1097,6 @@ def fromtimestamp(cls, t, tz=None): t, frac = divmod(t, 1.0) us = int(frac * 1e6) - # If timestamp is less than one microsecond smaller than a - # full second, us can be rounded up to 1000000. In this case, - # roll over to seconds, otherwise, ValueError is raised - # by the constructor. if us == 1000000: t += 1 us = 0 @@ -1461,7 +1198,6 @@ def timetz(self): def replace(self, year=None, month=None, day=None, hour=None, minute=None, second=None, microsecond=None, tzinfo=True): - """Return a new datetime with new values for the specified fields.""" if year is None: year = self.year if month is None: @@ -1541,17 +1277,6 @@ def ctime(self): self._year) def isoformat(self, sep='T'): - """Return the time formatted according to ISO. - - This is 'YYYY-MM-DD HH:MM:SS.mmmmmm', or 'YYYY-MM-DD HH:MM:SS' if - self.microsecond == 0. - - If self.tzinfo is not None, the UTC offset is also attached, giving - 'YYYY-MM-DD HH:MM:SS.mmmmmm+HH:MM' or 'YYYY-MM-DD HH:MM:SS+HH:MM'. - - Optional argument sep specifies the separator between date and - time, default 'T'. - """ s = ("%04d-%02d-%02d%s" % (self._year, self._month, self._day, sep) + _format_time(self._hour, self._minute, self._second, @@ -1564,13 +1289,11 @@ def isoformat(self, sep='T'): else: sign = "+" hh, mm = divmod(off, timedelta(hours=1)) - assert not mm % timedelta(minutes=1), "whole minute" mm //= timedelta(minutes=1) s += "%s%02d:%02d" % (sign, hh, mm) return s def __repr__(self): - """Convert to formal string, for repr().""" L = [self._year, self._month, self._day, # These are never zero self._hour, self._minute, self._second, self._microsecond] if L[-1] == 0: @@ -1580,7 +1303,6 @@ def __repr__(self): s = ", ".join(map(str, L)) s = "%s(%s)" % ('datetime.' + self.__class__.__name__, s) if self._tzinfo is not None: - assert s[-1:] == ")" s = s[:-1] + ", tzinfo=%r" % self._tzinfo + ")" return s @@ -1595,8 +1317,6 @@ def strptime(cls, date_string, format): return _strptime._strptime_datetime(cls, date_string, format) def utcoffset(self): - """Return the timezone offset in minutes east of UTC (negative west of - UTC).""" if self._tzinfo is None: return None offset = self._tzinfo.utcoffset(self) @@ -1604,25 +1324,11 @@ def utcoffset(self): return offset def tzname(self): - """Return the timezone name. - - Note that the name is 100% informational -- there's no requirement that - it mean anything in particular. For example, "GMT", "UTC", "-500", - "-5:00", "EDT", "US/Eastern", "America/New York" are all valid replies. 
- """ name = _call_tzinfo_method(self._tzinfo, "tzname", self) _check_tzname(name) return name def dst(self): - """Return 0 if DST is not in effect, or the DST offset (in minutes - eastward) if DST is in effect. - - This is purely informational; the DST offset has already been added to - the UTC offset returned by utcoffset() if applicable, so there's no - need to consult dst() unless you're interested in displaying the DST - info. - """ if self._tzinfo is None: return None offset = self._tzinfo.dst(self) @@ -1680,7 +1386,6 @@ def __gt__(self, other): _cmperror(self, other) def _cmp(self, other, allow_mixed=False): - assert isinstance(other, datetime) mytz = self._tzinfo ottz = other._tzinfo myoff = otoff = None @@ -1839,7 +1544,6 @@ def _create(cls, offset, name=None): return self def __getinitargs__(self): - """pickle support""" if self._name is None: return (self._offset,) return (self._offset, self._name) @@ -1853,15 +1557,6 @@ def __hash__(self): return hash(self._offset) def __repr__(self): - """Convert to formal string, for repr(). - - >>> tz = timezone.utc - >>> repr(tz) - 'datetime.timezone.utc' - >>> tz = timezone(timedelta(hours=-5), 'EST') - >>> repr(tz) - "datetime.timezone(datetime.timedelta(-1, 68400), 'EST')" - """ if self is self.utc: return 'datetime.timezone.utc' if self._name is None: @@ -1920,219 +1615,3 @@ def _name_from_offset(delta): timezone.min = timezone._create(timezone._minoffset) timezone.max = timezone._create(timezone._maxoffset) _EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc) -""" -Some time zone algebra. For a datetime x, let - x.n = x stripped of its timezone -- its naive time. - x.o = x.utcoffset(), and assuming that doesn't raise an exception or - return None - x.d = x.dst(), and assuming that doesn't raise an exception or - return None - x.s = x's standard offset, x.o - x.d - -Now some derived rules, where k is a duration (timedelta). - -1. x.o = x.s + x.d - This follows from the definition of x.s. - -2. If x and y have the same tzinfo member, x.s = y.s. - This is actually a requirement, an assumption we need to make about - sane tzinfo classes. - -3. The naive UTC time corresponding to x is x.n - x.o. - This is again a requirement for a sane tzinfo class. - -4. (x+k).s = x.s - This follows from #2, and that datimetimetz+timedelta preserves tzinfo. - -5. (x+k).n = x.n + k - Again follows from how arithmetic is defined. - -Now we can explain tz.fromutc(x). Let's assume it's an interesting case -(meaning that the various tzinfo methods exist, and don't blow up or return -None when called). - -The function wants to return a datetime y with timezone tz, equivalent to x. -x is already in UTC. - -By #3, we want - - y.n - y.o = x.n [1] - -The algorithm starts by attaching tz to x.n, and calling that y. So -x.n = y.n at the start. Then it wants to add a duration k to y, so that [1] -becomes true; in effect, we want to solve [2] for k: - - (y+k).n - (y+k).o = x.n [2] - -By #1, this is the same as - - (y+k).n - ((y+k).s + (y+k).d) = x.n [3] - -By #5, (y+k).n = y.n + k, which equals x.n + k because x.n=y.n at the start. -Substituting that into [3], - - x.n + k - (y+k).s - (y+k).d = x.n; the x.n terms cancel, leaving - k - (y+k).s - (y+k).d = 0; rearranging, - k = (y+k).s - (y+k).d; by #4, (y+k).s == y.s, so - k = y.s - (y+k).d - -On the RHS, (y+k).d can't be computed directly, but y.s can be, and we -approximate k by ignoring the (y+k).d term at first. 
Note that k can't be -very large, since all offset-returning methods return a duration of magnitude -less than 24 hours. For that reason, if y is firmly in std time, (y+k).d must -be 0, so ignoring it has no consequence then. - -In any case, the new value is - - z = y + y.s [4] - -It's helpful to step back at look at [4] from a higher level: it's simply -mapping from UTC to tz's standard time. - -At this point, if - - z.n - z.o = x.n [5] - -we have an equivalent time, and are almost done. The insecurity here is -at the start of daylight time. Picture US Eastern for concreteness. The wall -time jumps from 1:59 to 3:00, and wall hours of the form 2:MM don't make good -sense then. The docs ask that an Eastern tzinfo class consider such a time to -be EDT (because it's "after 2"), which is a redundant spelling of 1:MM EST -on the day DST starts. We want to return the 1:MM EST spelling because that's -the only spelling that makes sense on the local wall clock. - -In fact, if [5] holds at this point, we do have the standard-time spelling, -but that takes a bit of proof. We first prove a stronger result. What's the -difference between the LHS and RHS of [5]? Let - - diff = x.n - (z.n - z.o) [6] - -Now - z.n = by [4] - (y + y.s).n = by #5 - y.n + y.s = since y.n = x.n - x.n + y.s = since z and y are have the same tzinfo member, - y.s = z.s by #2 - x.n + z.s - -Plugging that back into [6] gives - - diff = - x.n - ((x.n + z.s) - z.o) = expanding - x.n - x.n - z.s + z.o = cancelling - - z.s + z.o = by #2 - z.d - -So diff = z.d. - -If [5] is true now, diff = 0, so z.d = 0 too, and we have the standard-time -spelling we wanted in the endcase described above. We're done. Contrarily, -if z.d = 0, then we have a UTC equivalent, and are also done. - -If [5] is not true now, diff = z.d != 0, and z.d is the offset we need to -add to z (in effect, z is in tz's standard time, and we need to shift the -local clock into tz's daylight time). - -Let - - z' = z + z.d = z + diff [7] - -and we can again ask whether - - z'.n - z'.o = x.n [8] - -If so, we're done. If not, the tzinfo class is insane, according to the -assumptions we've made. This also requires a bit of proof. As before, let's -compute the difference between the LHS and RHS of [8] (and skipping some of -the justifications for the kinds of substitutions we've done several times -already): - - diff' = x.n - (z'.n - z'.o) = replacing z'.n via [7] - x.n - (z.n + diff - z'.o) = replacing diff via [6] - x.n - (z.n + x.n - (z.n - z.o) - z'.o) = - x.n - z.n - x.n + z.n - z.o + z'.o = cancel x.n - - z.n + z.n - z.o + z'.o = cancel z.n - - z.o + z'.o = #1 twice - -z.s - z.d + z'.s + z'.d = z and z' have same tzinfo - z'.d - z.d - -So z' is UTC-equivalent to x iff z'.d = z.d at this point. If they are equal, -we've found the UTC-equivalent so are done. In fact, we stop with [7] and -return z', not bothering to compute z'.d. - -How could z.d and z'd differ? z' = z + z.d [7], so merely moving z' by -a dst() offset, and starting *from* a time already in DST (we know z.d != 0), -would have to change the result dst() returns: we start in DST, and moving -a little further into it takes us out of DST. - -There isn't a sane case where this can happen. The closest it gets is at -the end of DST, where there's an hour in UTC with no spelling in a hybrid -tzinfo class. In US Eastern, that's 5:MM UTC = 0:MM EST = 1:MM EDT. 
During -that hour, on an Eastern clock 1:MM is taken as being in standard time (6:MM -UTC) because the docs insist on that, but 0:MM is taken as being in daylight -time (4:MM UTC). There is no local time mapping to 5:MM UTC. The local -clock jumps from 1:59 back to 1:00 again, and repeats the 1:MM hour in -standard time. Since that's what the local clock *does*, we want to map both -UTC hours 5:MM and 6:MM to 1:MM Eastern. The result is ambiguous -in local time, but so it goes -- it's the way the local clock works. - -When x = 5:MM UTC is the input to this algorithm, x.o=0, y.o=-5 and y.d=0, -so z=0:MM. z.d=60 (minutes) then, so [5] doesn't hold and we keep going. -z' = z + z.d = 1:MM then, and z'.d=0, and z'.d - z.d = -60 != 0 so [8] -(correctly) concludes that z' is not UTC-equivalent to x. - -Because we know z.d said z was in daylight time (else [5] would have held and -we would have stopped then), and we know z.d != z'.d (else [8] would have held -and we have stopped then), and there are only 2 possible values dst() can -return in Eastern, it follows that z'.d must be 0 (which it is in the example, -but the reasoning doesn't depend on the example -- it depends on there being -two possible dst() outcomes, one zero and the other non-zero). Therefore -z' must be in standard time, and is the spelling we want in this case. - -Note again that z' is not UTC-equivalent as far as the hybrid tzinfo class is -concerned (because it takes z' as being in standard time rather than the -daylight time we intend here), but returning it gives the real-life "local -clock repeats an hour" behavior when mapping the "unspellable" UTC hour into -tz. - -When the input is 6:MM, z=1:MM and z.d=0, and we stop at once, again with -the 1:MM standard time spelling we want. - -So how can this break? One of the assumptions must be violated. Two -possibilities: - -1) [2] effectively says that y.s is invariant across all y belong to a given - time zone. This isn't true if, for political reasons or continental drift, - a region decides to change its base offset from UTC. - -2) There may be versions of "double daylight" time where the tail end of - the analysis gives up a step too early. I haven't thought about that - enough to say. - -In any case, it's clear that the default fromutc() is strong enough to handle -"almost all" time zones: so long as the standard offset is invariant, it -doesn't matter if daylight time transition points change from year to year, or -if daylight time is skipped in some years; it doesn't matter how large or -small dst() may get within its bounds; and it doesn't even matter if some -perverse time zone returns a negative dst()). So a breaking case must be -pretty bizarre, and a tzinfo subclass can override fromutc() if it is. -""" -try: - from _datetime import * -except ImportError: - pass -else: - # Clean up unused names - del (_DAYNAMES, _DAYS_BEFORE_MONTH, _DAYS_IN_MONTH, - _DI100Y, _DI400Y, _DI4Y, _MAXORDINAL, _MONTHNAMES, - _build_struct_time, _call_tzinfo_method, _check_date_fields, - _check_time_fields, _check_tzinfo_arg, _check_tzname, - _check_utc_offset, _cmp, _cmperror, _date_class, _days_before_month, - _days_before_year, _days_in_month, _format_time, _is_leap, - _isoweek1monday, _math, _ord2ymd, _time, _time_class, _tzinfo_class, - _wrap_strftime, _ymd2ord) - # XXX Since import * above excludes names that start with _, - # docstring does not get overwritten. In the future, it may be - # appropriate to maintain a single module level docstring and - # remove the following line. 
- from _datetime import __doc__

From 6e101a9a488fdab9ec7884d29f5dffda87694651 Mon Sep 17 00:00:00 2001
From: Matthias Urlichs
Date: Mon, 27 Jun 2022 13:29:37 +0200
Subject: [PATCH 58/65] sched: add a heap-based one-shot event scheduler

Unlike CPython's sched there is no priority, no runner function and no
locking; expiry is driven by machine.Timer plus micropython.schedule().

---
 sched/sched.py | 106 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 106 insertions(+)

diff --git a/sched/sched.py b/sched/sched.py
index e69de29bb2..85651d7e1e 100644
--- a/sched/sched.py
+++ b/sched/sched.py
@@ -0,0 +1,106 @@
+# A generally useful event scheduler class.
+#
+# Scheduling runs in the background, courtesy of 'micropython.schedule'.
+#
+# Events are specified by tuples (time, action, *argument, **kwargs).
+# Unlike the standard sched class there's no priority, no runner
+# function, and no locking.
+
+from time import ticks_ms, ticks_diff, ticks_add
+import heapq
+from collections import namedtuple
+from machine import Timer as _Timer
+from micropython import schedule as _sched
+
+def _cmp(s,o):
+#    if s.time == o.time:
+#        return s.priority - o.priority
+    return ticks_diff(s.time, o.time)
+
+class Event(namedtuple('Event', 'time action argument kwargs')):
+    def __eq__(s, o): return _cmp(s,o) == 0
+    def __lt__(s, o): return _cmp(s,o) < 0
+    def __le__(s, o): return _cmp(s,o) <= 0
+    def __gt__(s, o): return _cmp(s,o) > 0
+    def __ge__(s, o): return _cmp(s,o) >= 0
+
+class Scheduler:
+    _running = False
+    def __init__(self, timer=-1):
+        # setup. Pass in the timer# to use if not virtual.
+        self._queue = []
+        self._timer = _Timer(timer)
+        self._run_ = self._run
+        self._sched_run_ = self._sched_run
+
+    def enter(self, _d_, _a_, *_a, **_kw):
+        # enter(delay_ms, action, *args, **kw); returns the queued Event.
+        q = self._queue
+        t = ticks_ms()
+        event = Event(ticks_add(t, _d_), _a_, _a, _kw)
+        heapq.heappush(q, event)
+        self._set_timer(t)
+        return event
+
+    def cancel(self, event):
+        # Remove an event from the queue.
+        self._queue.remove(event)
+        heapq.heapify(self._queue)
+
+    def _set_timer(self, t=None):
+        if self._running:
+            return
+        if not self._queue:
+            self._timer.deinit()
+            return
+
+        if t is None:
+            t = ticks_ms()
+        t = ticks_diff(self._queue[0].time, t)
+        if t > 0:
+            self._timer.init(mode=_Timer.ONE_SHOT, period=t, callback=self._sched_run_)
+        else:
+            _sched(self._run, None)
+
+    def _sched_run(self, _):
+        # runs in IRQ; retry via timer if micropython's schedule queue is full.
+        try:
+            _sched(self._run_, None)
+        except Exception:
+            self._timer.init(mode=_Timer.ONE_SHOT, period=1, callback=self._sched_run_)
+
+    def empty(self):
+        # Check whether the queue is empty.
+        return not self._queue
+
+    def _run(self, _):
+        # Execute events until the queue is empty.
+        self._running = True
+        try:
+            q = self._queue
+            pop = heapq.heappop
+            while q:
+                t, action, argument, kwargs = q[0]
+                now = ticks_ms()
+                if ticks_diff(t, now) > 0:
+                    break
+                else:
+                    pop(q)
+                    action(*argument, **kwargs)
+        finally:
+            self._running = False
+            self._set_timer()
+
+    @property
+    def queue(self):
+        # Yield upcoming events in time order.
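+        # This is a generator, not a list: it pops from a shallow copy of
+        # the heap, so iterating never disturbs the live queue.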
+ events = self._queue[:] + while events: + yield heapq.heappop(events) + + def dump(self): + print("now:",ticks_ms()) + for e in self.queue: + print(e) From 2cecf2339053295fb0b33513cdd6d2b16518cf74 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Mon, 18 Jul 2022 13:18:13 +0200 Subject: [PATCH 59/65] Fix "outcome" port --- outcome/outcome/__init__.py | 5 +---- outcome/outcome/_async.py | 13 +------------ outcome/outcome/_sync.py | 25 ------------------------- outcome/outcome/_util.py | 20 -------------------- outcome/outcome/_version.py | 2 +- 5 files changed, 3 insertions(+), 62 deletions(-) diff --git a/outcome/outcome/__init__.py b/outcome/outcome/__init__.py index 6f93754757..e95f9cf0bd 100644 --- a/outcome/outcome/__init__.py +++ b/outcome/outcome/__init__.py @@ -1,10 +1,9 @@ # coding: utf-8 """Top-level package for outcome.""" -from __future__ import absolute_import, division, print_function import sys -from ._util import AlreadyUsedError, fixup_module_metadata +from ._util import AlreadyUsedError from ._version import __version__ if sys.version_info >= (3, 5): @@ -16,5 +15,3 @@ from ._sync import Error, Outcome, Value, capture __all__ = ('Error', 'Outcome', 'Value', 'capture', 'AlreadyUsedError') -fixup_module_metadata(__name__, globals()) -del fixup_module_metadata diff --git a/outcome/outcome/_async.py b/outcome/outcome/_async.py index 74e15586a5..741bf4b82d 100644 --- a/outcome/outcome/_async.py +++ b/outcome/outcome/_async.py @@ -1,4 +1,3 @@ -import abc from ._sync import Error as ErrorBase from ._sync import Outcome as OutcomeBase @@ -35,17 +34,7 @@ async def acapture(async_fn, *args, **kwargs): class Outcome(OutcomeBase): - @abc.abstractmethod - async def asend(self, agen): - """Send or throw the contained value or exception into the given async - generator object. - - Args: - agen: An async generator object supporting ``.asend()`` and - ``.athrow()`` methods. - - """ - + pass class Value(Outcome, ValueBase): async def asend(self, agen): diff --git a/outcome/outcome/_sync.py b/outcome/outcome/_sync.py index f69c4d0eaa..d3500b3dd5 100644 --- a/outcome/outcome/_sync.py +++ b/outcome/outcome/_sync.py @@ -1,7 +1,4 @@ # coding: utf-8 -from __future__ import absolute_import, division, print_function - -import abc import attr @@ -49,28 +46,6 @@ def _set_unwrapped(self): raise AlreadyUsedError self._unwrapped = True - @abc.abstractmethod - def unwrap(self): - """Return or raise the contained value or exception. - - These two lines of code are equivalent:: - - x = fn(*args) - x = outcome.capture(fn, *args).unwrap() - - """ - - @abc.abstractmethod - def send(self, gen): - """Send or throw the contained value or exception into the given - generator object. - - Args: - gen: A generator object supporting ``.send()`` and ``.throw()`` - methods. 
- - """ - @attr.s(frozen=True, repr=False, slots=True, init=False) class Value(Outcome): diff --git a/outcome/outcome/_util.py b/outcome/outcome/_util.py index 91b1d6c091..23741dd59a 100644 --- a/outcome/outcome/_util.py +++ b/outcome/outcome/_util.py @@ -1,27 +1,7 @@ # coding: utf-8 -from __future__ import absolute_import, division, print_function - -import abc -import sys - class AlreadyUsedError(RuntimeError): """An Outcome can only be unwrapped once.""" pass -def fixup_module_metadata(module_name, namespace): - def fix_one(obj): - mod = getattr(obj, "__module__", None) - if mod is not None and mod.startswith("outcome."): - obj.__module__ = module_name - if isinstance(obj, type): - for k in dir(obj): - attr_value = getattr(obj, k) - fix_one(attr_value) - - for objname in namespace["__all__"]: - obj = namespace[objname] - fix_one(obj) - - diff --git a/outcome/outcome/_version.py b/outcome/outcome/_version.py index 1121db00bf..6d7439d5f1 100644 --- a/outcome/outcome/_version.py +++ b/outcome/outcome/_version.py @@ -1,4 +1,4 @@ # coding: utf-8 # This file is imported from __init__.py and exec'd from setup.py -__version__ = "1.0.0" +__version__ = "1.0.1" From 2acb3df02767bd1d46b18038e769e4366337dbb7 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Thu, 16 Mar 2023 16:20:18 +0100 Subject: [PATCH 60/65] Formatting woes. --- collections.deque/collections/deque.py | 1 - collections/collections/__init__.py | 8 ++++++++ contextlib/contextlib.py | 17 +++++++++++++++-- pprint/pprint.py | 1 + uasyncio.queues/uasyncio/queues.py | 2 ++ ucontextlib/ucontextlib.py | 5 +++++ 6 files changed, 31 insertions(+), 3 deletions(-) diff --git a/collections.deque/collections/deque.py b/collections.deque/collections/deque.py index bb9f015c17..0a0bb7a31c 100644 --- a/collections.deque/collections/deque.py +++ b/collections.deque/collections/deque.py @@ -1,5 +1,4 @@ class deque: - def __init__(self, iterable=None): if iterable is None: self.q = [] diff --git a/collections/collections/__init__.py b/collections/collections/__init__.py index c566a8abd8..d0ed3d604f 100644 --- a/collections/collections/__init__.py +++ b/collections/collections/__init__.py @@ -4,6 +4,7 @@ # This is going to be just import-all for other modules in a namespace package import ucollections from ucollections import OrderedDict + try: from .defaultdict import defaultdict except ImportError: @@ -13,12 +14,15 @@ except ImportError: pass + class Mapping: pass + class MutableMapping: pass + def namedtuple(name, fields): _T = ucollections.namedtuple(name, fields) @@ -30,14 +34,18 @@ def _make(cls, seq): return t + class Sequence: pass + class MutableSequence: pass + class Set: pass + class MutableSet: pass diff --git a/contextlib/contextlib.py b/contextlib/contextlib.py index 57d00aa724..751cb5500b 100644 --- a/contextlib/contextlib.py +++ b/contextlib/contextlib.py @@ -9,9 +9,10 @@ import sys from collections import deque -from ucontextlib import * from functools import wraps +from ucontextlib import * + class AbstractAsyncContextManager(object): "compabilibty" @@ -35,10 +36,13 @@ class closing(object): f.close() """ + def __init__(self, thing): self.thing = thing + def __enter__(self): return self.thing + def __exit__(self, *exc_info): self.thing.close() @@ -77,6 +81,7 @@ class _AsyncGeneratorContextManager(AbstractAsyncContextManager): def __init__(self, func, args, kwds): self.gen = func(*args, **kwds) self.func, self.args, self.kwds = func, args, kwds + async def __aenter__(self): try: return await self.gen.__anext__() @@ -146,9 +151,11 @@ 
async def some_async_generator(): finally: """ + @wraps(func) def helper(*args, **kwds): return _AsyncGeneratorContextManager(func, args, kwds) + return helper @@ -163,6 +170,7 @@ def _create_exit_wrapper(cm, cm_exit): def _create_cb_wrapper(callback, *args, **kwds): def _exit_wrapper(exc_type, exc, tb): callback(*args, **kwds) + return _exit_wrapper def __init__(self): @@ -230,6 +238,7 @@ def _push_cm_exit(self, cm, cm_exit): def _push_exit_callback(self, callback, is_sync=True): self._exit_callbacks.append((is_sync, callback)) + # Inspired by discussions on http://bugs.python.org/issue13585 class ExitStack(_BaseExitStack, AbstractContextManager): """Context manager for dynamic management of a stack of exit callbacks. @@ -251,6 +260,7 @@ def __exit__(self, *exc_details): # We manipulate the exception state so it behaves as though # we were actually nesting multiple with statements frame_exc = sys.exc_info()[1] + def _fix_exception_context(new_exc, old_exc): # Context may not be correct, so find the end of the chain while 1: @@ -298,6 +308,7 @@ def close(self): """Immediately unwind the context stack.""" self.__exit__(None, None, None) + # Inspired by discussions on https://bugs.python.org/issue29302 class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager): """Async context manager for dynamic management of a stack of exit @@ -320,6 +331,7 @@ def _create_async_exit_wrapper(cm, cm_exit): def _create_async_cb_wrapper(callback, *args, **kwds): async def _exit_wrapper(exc_type, exc, tb): await callback(*args, **kwds) + return _exit_wrapper async def enter_async_context(self, cm): @@ -384,6 +396,7 @@ async def __aexit__(self, *exc_details): # We manipulate the exception state so it behaves as though # we were actually nesting multiple with statements frame_exc = sys.exc_info()[1] + def _fix_exception_context(new_exc, old_exc): # Context may not be correct, so find the end of the chain while 1: @@ -431,6 +444,7 @@ def _fix_exception_context(new_exc, old_exc): raise return received_exc and suppressed_exc + class nullcontext(AbstractContextManager): """Context manager that does no additional processing. @@ -450,4 +464,3 @@ def __enter__(self): def __exit__(self, *excinfo): pass - diff --git a/pprint/pprint.py b/pprint/pprint.py index e4b09183ef..bfbc84d3ae 100644 --- a/pprint/pprint.py +++ b/pprint/pprint.py @@ -1,4 +1,5 @@ import sys + import uio diff --git a/uasyncio.queues/uasyncio/queues.py b/uasyncio.queues/uasyncio/queues.py index 04918ae5c1..856f6d9882 100644 --- a/uasyncio.queues/uasyncio/queues.py +++ b/uasyncio.queues/uasyncio/queues.py @@ -1,4 +1,5 @@ from collections.deque import deque + from uasyncio.core import sleep @@ -21,6 +22,7 @@ class Queue: with qsize(), since your single-threaded uasyncio application won't be interrupted between calling qsize() and doing an operation on the Queue. """ + _attempt_delay = 0.1 def __init__(self, maxsize=0): diff --git a/ucontextlib/ucontextlib.py b/ucontextlib/ucontextlib.py index e7a02d8c1f..cb932ee29a 100644 --- a/ucontextlib/ucontextlib.py +++ b/ucontextlib/ucontextlib.py @@ -9,10 +9,12 @@ - supress """ + class AbstractContextManager(object): "Compatibility" pass + class ContextDecorator(AbstractContextManager): "A base class or mixin that enables context managers to work as decorators." 
@@ -32,6 +34,7 @@ def __call__(self, func): def inner(*args, **kwds): with self._recreate_cm(): return func(*args, **kwds) + return inner @@ -105,6 +108,8 @@ def some_generator(): """ + def helper(*args, **kwds): return _GeneratorContextManager(func, *args, **kwds) + return helper From df4f856621c8ceba42c9f50f4b633b73ada17109 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Wed, 12 Apr 2023 15:31:27 +0200 Subject: [PATCH 61/65] uasyncio.queues: Use taskqueues instead of busy-waiting --- uasyncio.queues/uasyncio/queues.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/uasyncio.queues/uasyncio/queues.py b/uasyncio.queues/uasyncio/queues.py index 856f6d9882..992466e207 100644 --- a/uasyncio.queues/uasyncio/queues.py +++ b/uasyncio.queues/uasyncio/queues.py @@ -1,6 +1,6 @@ from collections.deque import deque -from uasyncio.core import sleep +from uasyncio import core class QueueEmpty(Exception): @@ -23,14 +23,17 @@ class Queue: interrupted between calling qsize() and doing an operation on the Queue. """ - _attempt_delay = 0.1 - def __init__(self, maxsize=0): self.maxsize = maxsize self._queue = deque() + self._full = core.TaskQueue() + self._empty = core.TaskQueue() def _get(self): - return self._queue.popleft() + res = self._queue.popleft() + if self._full.peek(): + core._task_queue.push(self._full.pop()) + return res def get(self): """Returns generator, which can be used for getting (and removing) @@ -40,8 +43,11 @@ def get(self): item = yield from queue.get() """ - while not self._queue: - yield from sleep(self._attempt_delay) + if not self._queue: + self._empty.push(core.cur_task) + core.cur_task.data = self._empty + yield + return self._get() def get_nowait(self): @@ -54,6 +60,8 @@ def get_nowait(self): return self._get() def _put(self, val): + if self._empty.peek(): + core._task_queue.push(self._empty.pop()) self._queue.append(val) def put(self, val): @@ -63,8 +71,10 @@ def put(self, val): yield from queue.put(item) """ - while self.qsize() >= self.maxsize and self.maxsize: - yield from sleep(self._attempt_delay) + if self.maxsize and self.qsize() >= self.maxsize: + self._full.push(core.cur_task) + core.cur_task.data = self._full + yield self._put(val) def put_nowait(self, val): From f9739fc5ce10160d33dcc34f29a71f1fe74d6599 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Wed, 12 Apr 2023 15:32:02 +0200 Subject: [PATCH 62/65] Don't talk about "yield from" --- uasyncio.queues/uasyncio/queues.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uasyncio.queues/uasyncio/queues.py b/uasyncio.queues/uasyncio/queues.py index 992466e207..7218099d53 100644 --- a/uasyncio.queues/uasyncio/queues.py +++ b/uasyncio.queues/uasyncio/queues.py @@ -41,7 +41,7 @@ def get(self): Usage:: - item = yield from queue.get() + item = await queue.get() """ if not self._queue: self._empty.push(core.cur_task) @@ -69,7 +69,7 @@ def put(self, val): Usage:: - yield from queue.put(item) + await queue.put(item) """ if self.maxsize and self.qsize() >= self.maxsize: self._full.push(core.cur_task) From 33382b68c0053d3fdfb2b40c48d8063d468c79d4 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Wed, 12 Apr 2023 15:32:13 +0200 Subject: [PATCH 63/65] Do the easy test first --- uasyncio.queues/uasyncio/queues.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uasyncio.queues/uasyncio/queues.py b/uasyncio.queues/uasyncio/queues.py index 7218099d53..53c20fa7a6 100644 --- a/uasyncio.queues/uasyncio/queues.py +++ 
b/uasyncio.queues/uasyncio/queues.py @@ -82,7 +82,7 @@ def put_nowait(self, val): If no free slot is immediately available, raise QueueFull. """ - if self.qsize() >= self.maxsize and self.maxsize: + if self.maxsize and self.qsize() >= self.maxsize: raise QueueFull() self._put(val) From a7d0f05731af988ac4457d4d370df3f72a5022ec Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Wed, 12 Apr 2023 16:51:28 +0200 Subject: [PATCH 64/65] simplified pathlib --- pathlib/pathlib.py | 234 +++------------------------------------------ 1 file changed, 15 insertions(+), 219 deletions(-) diff --git a/pathlib/pathlib.py b/pathlib/pathlib.py index 1f2ccfa9d9..888df6a7a6 100644 --- a/pathlib/pathlib.py +++ b/pathlib/pathlib.py @@ -17,20 +17,11 @@ supports_symlinks = True -if os.name == 'nt': - import nt - if sys.getwindowsversion()[:2] >= (6, 0): - from nt import _getfinalpathname - else: - supports_symlinks = False - _getfinalpathname = None -else: - nt = None __all__ = [ - "PurePath", "PurePosixPath", "PureWindowsPath", - "Path", "PosixPath", "WindowsPath", + "PurePath", "PurePosixPath", + "Path", "PosixPath", ] # @@ -119,178 +110,13 @@ def join_parsed_parts(self, drv, root, parts, drv2, root2, parts2): return drv2, root2, parts2 -class _WindowsFlavour(_Flavour): - # Reference for Windows paths can be found at - # http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx - - sep = '\\' - altsep = '/' - has_drv = True - pathmod = ntpath - - is_supported = (os.name == 'nt') - - drive_letters = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') - ext_namespace_prefix = '\\\\?\\' - - reserved_names = ( - {'CON', 'PRN', 'AUX', 'NUL'} | - {'COM%d' % i for i in range(1, 10)} | - {'LPT%d' % i for i in range(1, 10)} - ) - - # Interesting findings about extended paths: - # - '\\?\c:\a', '//?/c:\a' and '//?/c:/a' are all supported - # but '\\?\c:/a' is not - # - extended paths are always absolute; "relative" extended paths will - # fail. - - def splitroot(self, part, sep=sep): - first = part[0:1] - second = part[1:2] - if (second == sep and first == sep): - # XXX extended paths should also disable the collapsing of "." - # components (according to MSDN docs). - prefix, part = self._split_extended_path(part) - first = part[0:1] - second = part[1:2] - else: - prefix = '' - third = part[2:3] - if (second == sep and first == sep and third != sep): - # is a UNC path: - # vvvvvvvvvvvvvvvvvvvvv root - # \\machine\mountpoint\directory\etc\... 
- # directory ^^^^^^^^^^^^^^ - index = part.find(sep, 2) - if index != -1: - index2 = part.find(sep, index + 1) - # a UNC path can't have two slashes in a row - # (after the initial two) - if index2 != index + 1: - if index2 == -1: - index2 = len(part) - if prefix: - return prefix + part[1:index2], sep, part[index2+1:] - else: - return part[:index2], sep, part[index2+1:] - drv = root = '' - if second == ':' and first in self.drive_letters: - drv = part[:2] - part = part[2:] - first = third - if first == sep: - root = first - part = part.lstrip(sep) - return prefix + drv, root, part - - def casefold(self, s): - return s.lower() - - def casefold_parts(self, parts): - return [p.lower() for p in parts] - - def compile_pattern(self, pattern): - return re.compile(fnmatch.translate(pattern), re.IGNORECASE).fullmatch - - def resolve(self, path, strict=False): - s = str(path) - if not s: - return os.getcwd() - previous_s = None - if _getfinalpathname is not None: - if strict: - return self._ext_to_normal(_getfinalpathname(s)) - else: - tail_parts = [] # End of the path after the first one not found - while True: - try: - s = self._ext_to_normal(_getfinalpathname(s)) - except FileNotFoundError: - previous_s = s - s, tail = os.path.split(s) - tail_parts.append(tail) - if previous_s == s: - return path - else: - return os.path.join(s, *reversed(tail_parts)) - # Means fallback on absolute - return None - - def _split_extended_path(self, s, ext_prefix=ext_namespace_prefix): - prefix = '' - if s.startswith(ext_prefix): - prefix = s[:4] - s = s[4:] - if s.startswith('UNC\\'): - prefix += s[:3] - s = '\\' + s[3:] - return prefix, s - - def _ext_to_normal(self, s): - # Turn back an extended path into a normal DOS-like path - return self._split_extended_path(s)[1] - - def is_reserved(self, parts): - # NOTE: the rules for reserved names seem somewhat complicated - # (e.g. r"..\NUL" is reserved but not r"foo\NUL"). - # We err on the side of caution and return True for paths which are - # not considered reserved by Windows. - if not parts: - return False - if parts[0].startswith('\\\\'): - # UNC paths are never reserved - return False - return parts[-1].partition('.')[0].upper() in self.reserved_names - - def make_uri(self, path): - # Under Windows, file URIs use the UTF-8 encoding. - drive = path.drive - if len(drive) == 2 and drive[1] == ':': - # It's a path on a local drive => 'file:///c:/a/b' - rest = path.as_posix()[2:].lstrip('/') - return 'file:///%s/%s' % ( - drive, urlquote_from_bytes(rest.encode('utf-8'))) - else: - # It's a path on a network drive => 'file://host/share/a/b' - return 'file:' + urlquote_from_bytes(path.as_posix().encode('utf-8')) - - def gethomedir(self, username): - if 'USERPROFILE' in os.environ: - userhome = os.environ['USERPROFILE'] - elif 'HOMEPATH' in os.environ: - try: - drv = os.environ['HOMEDRIVE'] - except KeyError: - drv = '' - userhome = drv + os.environ['HOMEPATH'] - else: - raise RuntimeError("Can't determine home directory") - - if username: - # Try to guess user home directory. By default all users - # directories are located in the same place and are named by - # corresponding usernames. If current user home directory points - # to nonstandard place, this guess is likely wrong. 
- if os.environ['USERNAME'] != username: - drv, root, parts = self.parse_parts((userhome,)) - if parts[-1] != os.environ['USERNAME']: - raise RuntimeError("Can't determine home directory " - "for %r" % username) - parts[-1] = username - if drv or root: - userhome = drv + root + self.join(parts[1:]) - else: - userhome = self.join(parts) - return userhome - class _PosixFlavour(_Flavour): sep = '/' altsep = '' has_drv = False pathmod = posixpath - is_supported = (os.name != 'nt') + is_supported = True def splitroot(self, part, sep=sep): if part and part[0] == sep: @@ -449,17 +275,10 @@ def replace(self, src, dst): mkdir = os.mkdir - if nt: - if supports_symlinks: - symlink = os.symlink - else: - def symlink(a, b, target_is_directory): - raise NotImplementedError("symlink() not available on this system") - else: - # Under POSIX, os.symlink() takes two args - @staticmethod - def symlink(a, b, target_is_directory): - return os.symlink(a, b) + # Under POSIX, os.symlink() takes two args + @staticmethod + def symlink(a, b, target_is_directory): + return os.symlink(a, b) utime = os.utime @@ -660,10 +479,9 @@ class PurePath(object): """Base class for manipulating paths without I/O. PurePath represents a filesystem path and offers operations which - don't imply any actual filesystem I/O. Depending on your system, - instantiating a PurePath will return either a PurePosixPath or a - PureWindowsPath object. You can also instantiate either of these classes - directly, regardless of your system. + don't imply any actual filesystem I/O. Instantiating a PurePath will + return a PurePosixPath object. You can also instantiate either of + these classes directly, regardless of your system. """ __slots__ = ( '_drv', '_root', '_parts', @@ -677,7 +495,7 @@ def __new__(cls, *args): new PurePath object. """ if cls is PurePath: - cls = PureWindowsPath if os.name == 'nt' else PurePosixPath + cls = PurePosixPath return cls._from_parts(args) def __reduce__(self): @@ -1059,16 +877,6 @@ class PurePosixPath(PurePath): __slots__ = () -class PureWindowsPath(PurePath): - """PurePath subclass for Windows systems. - - On a Windows system, instantiating a PurePath should return this object. - However, you can also instantiate it directly on any system. - """ - _flavour = _windows_flavour - __slots__ = () - - # Filesystem-accessing classes @@ -1076,10 +884,9 @@ class Path(PurePath): """PurePath subclass that can make system calls. Path represents a filesystem path but unlike PurePath, also offers - methods to do system calls on path objects. Depending on your system, - instantiating a Path will return either a PosixPath or a WindowsPath - object. You can also instantiate a PosixPath or WindowsPath directly, - but cannot instantiate a WindowsPath on a POSIX system or vice versa. + methods to do system calls on path objects. Instantiating a Path will + return a PosixPath object. You can also instantiate a PosixPath + directly. """ __slots__ = ( '_accessor', @@ -1087,7 +894,7 @@ class Path(PurePath): def __new__(cls, *args, **kwargs): if cls is Path: - cls = WindowsPath if os.name == 'nt' else PosixPath + cls = PosixPath self = cls._from_parts(args, init=False) if not self._flavour.is_supported: raise NotImplementedError("cannot instantiate %r on your system" @@ -1209,8 +1016,6 @@ def absolute(self): # XXX untested yet! 
if self.is_absolute(): return self - # FIXME this must defer to the specific flavour (and, under Windows, - # use nt._getfullpathname()) obj = self._from_parts([os.getcwd()] + self._parts, init=False) obj._init(template=self) return obj @@ -1585,12 +1390,3 @@ class PosixPath(Path, PurePosixPath): """ __slots__ = () -class WindowsPath(Path, PureWindowsPath): - """Path subclass for Windows systems. - - On a Windows system, instantiating a Path should return this object. - """ - __slots__ = () - - def is_mount(self): - raise NotImplementedError("Path.is_mount() is unsupported on this system") From 09bd9f27ac34102a66feccc46ba0415d88291d31 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Fri, 21 Apr 2023 12:31:06 +0200 Subject: [PATCH 65/65] more linting --- contextlib/contextlib.py | 6 +++--- functools/functools.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/contextlib/contextlib.py b/contextlib/contextlib.py index 751cb5500b..33c2300b78 100644 --- a/contextlib/contextlib.py +++ b/contextlib/contextlib.py @@ -11,7 +11,7 @@ from collections import deque from functools import wraps -from ucontextlib import * +from ucontextlib import AbstractContextManager, MethodType class AbstractAsyncContextManager(object): @@ -287,7 +287,7 @@ def _fix_exception_context(new_exc, old_exc): suppressed_exc = True pending_raise = False exc_details = (None, None, None) - except: + except: # noqa:E722 new_exc_details = sys.exc_info() # simulate the stack of exceptions by setting the context _fix_exception_context(new_exc_details[1], exc_details[1]) @@ -427,7 +427,7 @@ def _fix_exception_context(new_exc, old_exc): suppressed_exc = True pending_raise = False exc_details = (None, None, None) - except: + except: # noqa:E722 new_exc_details = sys.exc_info() # simulate the stack of exceptions by setting the context _fix_exception_context(new_exc_details[1], exc_details[1]) diff --git a/functools/functools.py b/functools/functools.py index 8b1e6b50ed..a3cc550f7d 100644 --- a/functools/functools.py +++ b/functools/functools.py @@ -34,21 +34,31 @@ def reduce(function, iterable, initializer=None): # CPython git v3.4.10 def cmp_to_key(mycmp): """Convert a cmp= function into a key= function""" + class K(object): __slots__ = ['obj'] + def __init__(self, obj): self.obj = obj + def __lt__(self, other): return mycmp(self.obj, other.obj) < 0 + def __gt__(self, other): return mycmp(self.obj, other.obj) > 0 + def __eq__(self, other): return mycmp(self.obj, other.obj) == 0 + def __le__(self, other): return mycmp(self.obj, other.obj) <= 0 + def __ge__(self, other): return mycmp(self.obj, other.obj) >= 0 + def __ne__(self, other): return mycmp(self.obj, other.obj) != 0 + __hash__ = None + return K
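
For reviewers who want to exercise the reworked queue (patches 61-63) end to
end, here is a minimal smoke test. It is a sketch, not part of the series: it
assumes a uasyncio v3 core that provides run() and create_task() alongside the
TaskQueue internals used above, and the producer/consumer names are
illustrative only.

    import uasyncio as asyncio
    from uasyncio.queues import Queue

    async def producer(q):
        for i in range(5):
            await q.put(i)  # parks on the 'full' TaskQueue once maxsize is hit

    async def consumer(q):
        for _ in range(5):
            print("got", await q.get())  # parks on 'empty' instead of polling

    async def main():
        q = Queue(maxsize=2)
        asyncio.create_task(producer(q))
        await consumer(q)  # the bounded producer finishes before the last get()

    asyncio.run(main())

With the old busy-wait implementation both loops would have spun on 100 ms
sleeps; after patch 61 each task is woken exactly when the state it waits on
changes.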