diff --git a/release-notes/next-release.md b/release-notes/next-release.md index 9c2caf45a..c43f45fe8 100644 --- a/release-notes/next-release.md +++ b/release-notes/next-release.md @@ -11,6 +11,8 @@ ## Other +- Updated type hinting for the `DictList` class so that the type of `Object` contained by a `DictList` can be specified. For example, the hinted return type of `model.reactions.get_by_id` is now `Reaction` instead of `Object`. + ## Deprecated features - Changed the type of the `loopless` parameter in `flux_variability_analysis` from `bool` to `Optional[str]`. Using `loopless=False` or `loopless=True` (boolean) is now deprecated. diff --git a/src/cobra/core/dictlist.py b/src/cobra/core/dictlist.py index b4a5d0468..0684b1a34 100644 --- a/src/cobra/core/dictlist.py +++ b/src/cobra/core/dictlist.py @@ -12,13 +12,17 @@ Pattern, Tuple, Type, + TypeVar, Union, ) from .object import Object -class DictList(list): +CobraObject = TypeVar("CobraObject", bound=Object) + + +class DictList(List[CobraObject]): """ Define a combined dict and list. @@ -49,12 +53,12 @@ def __init__(self, *args): self.extend(other) # noinspection PyShadowingBuiltins - def has_id(self, id: Union[Object, str]) -> bool: + def has_id(self, id: Union[CobraObject, str]) -> bool: """Check if id is in DictList.""" return id in self._dict # noinspection PyShadowingBuiltins - def _check(self, id: Union[Object, str]) -> None: + def _check(self, id: Union[CobraObject, str]) -> None: """Make sure duplicate id's are not added. This function is called before adding in elements. @@ -68,7 +72,7 @@ def _generate_index(self) -> None: self._dict = {v.id: k for k, v in enumerate(self)} # noinspection PyShadowingBuiltins - def get_by_id(self, id: Union[Object, str]) -> Object: + def get_by_id(self, id: Union[CobraObject, str]) -> CobraObject: """Return the element with a matching id.""" return list.__getitem__(self, self._dict[id]) @@ -76,7 +80,9 @@ def list_attr(self, attribute: str) -> list: """Return a list of the given attribute for every object.""" return [getattr(i, attribute) for i in self] - def get_by_any(self, iterable: List[Union[str, Object, int]]) -> list: + def get_by_any( + self, iterable: List[Union[str, CobraObject, int]] + ) -> List[CobraObject]: """Get a list of members using several different ways of indexing. Parameters @@ -92,7 +98,7 @@ def get_by_any(self, iterable: List[Union[str, Object, int]]) -> list: a list of members """ - def get_item(item: Any) -> Any: + def get_item(item: Any) -> CobraObject: if isinstance(item, int): return self[item] elif isinstance(item, str): @@ -110,7 +116,7 @@ def query( self, search_function: Union[str, Pattern, Callable], attribute: Union[str, None] = None, - ) -> "DictList": + ) -> "DictList[CobraObject]": """Query the list. Parameters @@ -167,21 +173,21 @@ def select_attribute(x: Optional[Any]) -> Any: results._extend_nocheck(matches) return results - def _replace_on_id(self, new_object: Object) -> None: + def _replace_on_id(self, new_object: CobraObject) -> None: """Replace an object by another with the same id.""" the_id = new_object.id the_index = self._dict[the_id] list.__setitem__(self, the_index, new_object) # overriding default list functions with new ones - def append(self, entity: Object) -> None: + def append(self, entity: CobraObject) -> None: """Append object to end.""" the_id = entity.id self._check(the_id) self._dict[the_id] = len(self) list.append(self, entity) - def union(self, iterable: Iterable[Object]) -> None: + def union(self, iterable: Iterable[CobraObject]) -> None: """Add elements with id's not already in the model.""" _dict = self._dict append = self.append @@ -189,7 +195,7 @@ def union(self, iterable: Iterable[Object]) -> None: if i.id not in _dict: append(i) - def extend(self, iterable: Iterable[Object]) -> None: + def extend(self, iterable: Iterable[CobraObject]) -> None: """Extend list by appending elements from the iterable. Sometimes during initialization from an older pickle, _dict @@ -222,7 +228,7 @@ def extend(self, iterable: Iterable[Object]) -> None: f"Is it present twice?" ) - def _extend_nocheck(self, iterable: Iterable[Object]) -> None: + def _extend_nocheck(self, iterable: Iterable[CobraObject]) -> None: """Extend without checking for uniqueness. This function should only be used internally by DictList when it @@ -244,7 +250,7 @@ def _extend_nocheck(self, iterable: Iterable[Object]) -> None: for i, obj in enumerate(islice(self, current_length, None), current_length): _dict[obj.id] = i - def __sub__(self, other: Iterable[Object]) -> "DictList": + def __sub__(self, other: Iterable[CobraObject]) -> "DictList[CobraObject]": """Remove a value or values, and returns the new DictList. x.__sub__(y) <==> x - y @@ -264,7 +270,7 @@ def __sub__(self, other: Iterable[Object]) -> "DictList": total.remove(item) return total - def __isub__(self, other: Iterable[Object]) -> "DictList": + def __isub__(self, other: Iterable[CobraObject]) -> "DictList[CobraObject]": """Remove a value or values in place. x.__sub__(y) <==> x -= y @@ -278,7 +284,7 @@ def __isub__(self, other: Iterable[Object]) -> "DictList": self.remove(item) return self - def __add__(self, other: Iterable[Object]) -> "DictList": + def __add__(self, other: Iterable[CobraObject]) -> "DictList[CobraObject]": """Add item while returning a new DictList. x.__add__(y) <==> x + y @@ -294,7 +300,7 @@ def __add__(self, other: Iterable[Object]) -> "DictList": total.extend(other) return total - def __iadd__(self, other: Iterable[Object]) -> "DictList": + def __iadd__(self, other: Iterable[CobraObject]) -> "DictList[CobraObject]": """Add item while returning the same DictList. x.__iadd__(y) <==> x += y @@ -309,7 +315,7 @@ def __iadd__(self, other: Iterable[Object]) -> "DictList": self.extend(other) return self - def __reduce__(self) -> Tuple[Type["DictList"], Tuple, dict, Iterator]: + def __reduce__(self) -> Tuple[Type["DictList"], Tuple, dict, Iterator[CobraObject]]: """Return a reduced version of DictList. This reduced version details the class, an empty Tuple, a dictionary of the @@ -336,7 +342,7 @@ def __setstate__(self, state: dict) -> None: self._generate_index() # noinspection PyShadowingBuiltins - def index(self, id: Union[str, Object], *args) -> int: + def index(self, id: Union[str, CobraObject], *args) -> int: """Determine the position in the list. Parameters @@ -360,7 +366,7 @@ def index(self, id: Union[str, Object], *args) -> int: except KeyError: raise ValueError(f"{str(id)} not found") - def __contains__(self, entity: Union[str, Object]) -> bool: + def __contains__(self, entity: Union[str, CobraObject]) -> bool: """Ask if the DictList contain an entity. DictList.__contains__(entity) <==> entity in DictList @@ -377,14 +383,14 @@ def __contains__(self, entity: Union[str, Object]) -> bool: the_id = entity return the_id in self._dict - def __copy__(self) -> "DictList": + def __copy__(self) -> "DictList[CobraObject]": """Copy the DictList into a new one.""" the_copy = DictList() list.extend(the_copy, self) the_copy._dict = self._dict.copy() return the_copy - def insert(self, index: int, entity: Object) -> None: + def insert(self, index: int, entity: CobraObject) -> None: """Insert entity before index.""" self._check(entity.id) list.insert(self, index, entity) @@ -395,7 +401,7 @@ def insert(self, index: int, entity: Object) -> None: _dict[i] = j + 1 _dict[entity.id] = index - def pop(self, *args) -> Object: + def pop(self, *args) -> CobraObject: """Remove and return item at index (default last).""" value = list.pop(self, *args) index = self._dict.pop(value.id) @@ -409,11 +415,11 @@ def pop(self, *args) -> Object: _dict[i] = j - 1 return value - def add(self, x: Object) -> None: + def add(self, x: CobraObject) -> None: """Opposite of `remove`. Mirrors set.add.""" self.extend([x]) - def remove(self, x: Union[str, Object]) -> None: + def remove(self, x: Union[str, CobraObject]) -> None: """.. warning :: Internal use only. Each item is unique in the list which allows this @@ -445,8 +451,8 @@ def key(i): self._generate_index() def __getitem__( - self, i: Union[int, slice, Iterable, Object, "DictList"] - ) -> Union["DictList", Object]: + self, i: Union[int, slice, Iterable, CobraObject, "DictList[CobraObject]"] + ) -> Union["DictList[CobraObject]", CobraObject]: """Get item from DictList.""" if isinstance(i, int): return list.__getitem__(self, i) @@ -465,7 +471,9 @@ def __getitem__( else: return list.__getitem__(self, i) - def __setitem__(self, i: Union[slice, int], y: Union[list, Object]) -> None: + def __setitem__( + self, i: Union[slice, int], y: Union[List[CobraObject], CobraObject] + ) -> None: """Set an item via index or slice. Parameters @@ -507,11 +515,13 @@ def __delitem__(self, index: int) -> None: if j > index: _dict[i] = j - 1 - def __getslice__(self, i: int, j: int) -> "DictList": + def __getslice__(self, i: int, j: int) -> "DictList[CobraObject]": """Get a slice from it to j of DictList.""" return self.__getitem__(slice(i, j)) - def __setslice__(self, i: int, j: int, y: Union[list, Object]) -> None: + def __setslice__( + self, i: int, j: int, y: Union[List[CobraObject], CobraObject] + ) -> None: """Set slice, where y is an iterable.""" self.__setitem__(slice(i, j), y) @@ -519,7 +529,7 @@ def __delslice__(self, i: int, j: int) -> None: """Remove slice.""" self.__delitem__(slice(i, j)) - def __getattr__(self, attr: Any) -> Any: + def __getattr__(self, attr: Any) -> CobraObject: """Get an attribute by id.""" try: return DictList.get_by_id(self, attr) diff --git a/src/cobra/core/model.py b/src/cobra/core/model.py index 2d1215a12..8c2d26429 100644 --- a/src/cobra/core/model.py +++ b/src/cobra/core/model.py @@ -81,10 +81,10 @@ def __init__( self._solver = id_or_model.solver else: Object.__init__(self, id_or_model, name=name) - self.genes = DictList() - self.reactions = DictList() # A list of cobra.Reactions - self.metabolites = DictList() # A list of cobra.Metabolites - self.groups = DictList() # A list of cobra.Groups + self.genes: DictList[Gene] = DictList() + self.reactions: DictList[Reaction] = DictList() + self.metabolites: DictList[Metabolite] = DictList() + self.groups: DictList[Group] = DictList() # genes based on their ids {Gene.id: Gene} self._compartments = {} self._contexts = []