Skip to content

Commit 75a1a11

Browse files
authored
Merge pull request #3525 from aatle/sprite-stubs
Improve sprite stubs
2 parents fb74739 + 286e5ad commit 75a1a11

File tree

1 file changed

+65
-117
lines changed

1 file changed

+65
-117
lines changed

buildconfig/stubs/pygame/sprite.pyi

Lines changed: 65 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ from pygame.rect import FRect, Rect
2727
from pygame.surface import Surface
2828
from pygame.typing import Point, RectLike
2929

30-
# define some useful protocols first, which sprite functions accept
31-
# sprite functions don't need all sprite attributes to be present in the
32-
# arguments passed, they only use a few which are marked in the below protocols
30+
# Some sprite functions only need objects with certain attributes, not always a sprite
3331
class _HasRect(Protocol):
3432
@property
3533
def rect(self) -> Optional[Union[FRect, Rect]]: ...
@@ -41,51 +39,10 @@ class _HasImageAndRect(_HasRect, Protocol):
4139

4240
# mask in addition to rect
4341
class _HasMaskAndRect(_HasRect, Protocol):
44-
mask: Mask
45-
46-
# radius in addition to rect
47-
class _HasRadiusAndRect(_HasRect, Protocol):
48-
radius: float
49-
50-
# non-generic Group, used in Sprite
51-
_Group = AbstractGroup[Any]
52-
53-
# protocol helps with structural subtyping for typevars in sprite group generics
54-
# and allows the use of any class with the required attributes and methods
55-
class _SupportsSprite(_HasImageAndRect, Protocol):
56-
@property
57-
def image(self) -> Optional[Surface]: ...
58-
@image.setter
59-
def image(self, value: Optional[Surface]) -> None: ...
6042
@property
61-
def rect(self) -> Optional[Union[FRect, Rect]]: ...
62-
@rect.setter
63-
def rect(self, value: Optional[Union[FRect, Rect]]) -> None: ...
64-
@property
65-
def layer(self) -> int: ...
66-
@layer.setter
67-
def layer(self, value: int) -> None: ...
68-
def add_internal(self, group: _Group) -> None: ...
69-
def remove_internal(self, group: _Group) -> None: ...
70-
def update(self, *args: Any, **kwargs: Any) -> None: ...
71-
def add(self, *groups: _Group) -> None: ...
72-
def remove(self, *groups: _Group) -> None: ...
73-
def kill(self) -> None: ...
74-
def alive(self) -> bool: ...
75-
def groups(self) -> list[_Group]: ...
76-
77-
# also a protocol
78-
class _SupportsDirtySprite(_SupportsSprite, Protocol):
79-
dirty: int
80-
blendmode: int
81-
source_rect: Union[FRect, Rect]
82-
visible: int
83-
_layer: int
84-
def _set_visible(self, val: int) -> None: ...
85-
def _get_visible(self) -> int: ...
43+
def mask(self) -> Mask: ...
8644

87-
# concrete sprite implementation class
88-
class Sprite(_SupportsSprite):
45+
class Sprite(_HasImageAndRect):
8946
@property
9047
def image(self) -> Optional[Surface]: ...
9148
@image.setter
@@ -98,52 +55,47 @@ class Sprite(_SupportsSprite):
9855
def layer(self) -> int: ...
9956
@layer.setter
10057
def layer(self, value: int) -> None: ...
101-
def __init__(self, *groups: _Group) -> None: ...
102-
def add_internal(self, group: _Group) -> None: ...
103-
def remove_internal(self, group: _Group) -> None: ...
58+
def __init__(self, *groups: _GroupOrGroups[Any]) -> None: ...
59+
def add_internal(self, group: AbstractGroup[Any]) -> None: ...
60+
def remove_internal(self, group: AbstractGroup[Any]) -> None: ...
10461
def update(self, *args: Any, **kwargs: Any) -> None: ...
105-
def add(self, *groups: _Group) -> None: ...
106-
def remove(self, *groups: _Group) -> None: ...
62+
def add(self, *groups: _GroupOrGroups[Any]) -> None: ...
63+
def remove(self, *groups: _GroupOrGroups[Any]) -> None: ...
10764
def kill(self) -> None: ...
10865
def alive(self) -> bool: ...
109-
def groups(self) -> list[AbstractGroup[_SupportsSprite]]: ...
66+
def groups(self) -> list[AbstractGroup[Sprite]]: ...
11067

111-
# concrete dirty sprite implementation class
112-
class DirtySprite(Sprite, _SupportsDirtySprite):
68+
class DirtySprite(Sprite):
11369
dirty: int
11470
blendmode: int
11571
source_rect: Union[FRect, Rect]
11672
visible: int
11773
_layer: int
118-
def _set_visible(self, val: int) -> None: ...
119-
def _get_visible(self) -> int: ...
12074

121-
# typevar bound to Sprite, _SupportsSprite Protocol ensures sprite
122-
# subclass passed to group has image and rect attributes
123-
_TSprite = TypeVar("_TSprite", bound=_SupportsSprite)
124-
_TSprite2 = TypeVar("_TSprite2", bound=_SupportsSprite)
125-
_TDirtySprite = TypeVar("_TDirtySprite", bound=_SupportsDirtySprite)
75+
_SpriteT = TypeVar("_SpriteT", bound=Sprite)
76+
_SpriteT2 = TypeVar("_SpriteT2", bound=Sprite)
77+
_DirtySpriteT = TypeVar("_DirtySpriteT", bound=DirtySprite)
12678

127-
# typevar for sprite or iterable of sprites, used in Group init, add and remove
128-
_SpriteOrIterable = Union[_TSprite, Iterable[_SpriteOrIterable[_TSprite]]]
79+
_GroupOrGroups = Union[AbstractGroup[_SpriteT], Iterable[_GroupOrGroups[_SpriteT]]]
80+
_SpriteOrSprites = Union[_SpriteT, Iterable[_SpriteOrSprites[_SpriteT]]]
12981

130-
class AbstractGroup(Generic[_TSprite]):
131-
spritedict: dict[_TSprite, Optional[Union[FRect, Rect]]]
82+
class AbstractGroup(Generic[_SpriteT]):
83+
spritedict: dict[_SpriteT, Optional[Union[FRect, Rect]]]
13284
lostsprites: list[Union[FRect, Rect]]
13385
def __class_getitem__(cls, item: Any, /) -> types.GenericAlias: ...
13486
def __init__(self) -> None: ...
13587
def __len__(self) -> int: ...
136-
def __iter__(self) -> Iterator[_TSprite]: ...
88+
def __iter__(self) -> Iterator[_SpriteT]: ...
13789
def __bool__(self) -> bool: ...
13890
def __contains__(self, item: Any) -> bool: ...
139-
def add_internal(self, sprite: _TSprite, layer: None = None) -> None: ...
140-
def remove_internal(self, sprite: _TSprite) -> None: ...
141-
def has_internal(self, sprite: _TSprite) -> bool: ...
91+
def add_internal(self, sprite: _SpriteT, layer: Optional[int] = None) -> None: ...
92+
def remove_internal(self, sprite: _SpriteT) -> None: ...
93+
def has_internal(self, sprite: _SpriteT) -> bool: ...
14294
def copy(self) -> Self: ...
143-
def sprites(self) -> list[_TSprite]: ...
144-
def add(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
145-
def remove(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
146-
def has(self, *sprites: _SpriteOrIterable[_TSprite]) -> bool: ...
95+
def sprites(self) -> list[_SpriteT]: ...
96+
def add(self, *sprites: _SpriteOrSprites[_SpriteT]) -> None: ...
97+
def remove(self, *sprites: _SpriteOrSprites[_SpriteT]) -> None: ...
98+
def has(self, *sprites: _SpriteOrSprites[_SpriteT]) -> bool: ...
14799
def update(self, *args: Any, **kwargs: Any) -> None: ...
148100
def draw(
149101
self, surface: Surface, bgd: Optional[Surface] = None, special_flags: int = 0
@@ -155,41 +107,39 @@ class AbstractGroup(Generic[_TSprite]):
155107
) -> None: ...
156108
def empty(self) -> None: ...
157109

158-
class Group(AbstractGroup[_TSprite]):
159-
def __init__(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
110+
class Group(AbstractGroup[_SpriteT]):
111+
def __init__(self, *sprites: _SpriteOrSprites[_SpriteT]) -> None: ...
160112

161-
# these are aliased in the code too
113+
# These deprecated types are just aliases in the code too
162114
@deprecated("Use `pygame.sprite.Group` instead")
163-
class RenderPlain(Group[_TSprite]): ...
115+
class RenderPlain(Group[_SpriteT]): ...
164116

165117
@deprecated("Use `pygame.sprite.Group` instead")
166-
class RenderClear(Group[_TSprite]): ...
118+
class RenderClear(Group[_SpriteT]): ...
167119

168-
class RenderUpdates(Group[_TSprite]): ...
120+
class RenderUpdates(Group[_SpriteT]): ...
169121

170122
@deprecated("Use `pygame.sprite.RenderUpdates` instead")
171-
class OrderedUpdates(RenderUpdates[_TSprite]): ...
172-
173-
class LayeredUpdates(AbstractGroup[_TSprite]):
174-
def __init__(
175-
self, *sprites: _SpriteOrIterable[_TSprite], **kwargs: Any
176-
) -> None: ...
177-
def add(self, *sprites: _SpriteOrIterable[_TSprite], **kwargs: Any) -> None: ...
178-
def get_sprites_at(self, pos: Point) -> list[_TSprite]: ...
179-
def get_sprite(self, idx: int) -> _TSprite: ...
180-
def remove_sprites_of_layer(self, layer_nr: int) -> list[_TSprite]: ...
123+
class OrderedUpdates(RenderUpdates[_SpriteT]): ...
124+
125+
class LayeredUpdates(AbstractGroup[_SpriteT]):
126+
def __init__(self, *sprites: _SpriteOrSprites[_SpriteT], **kwargs: Any) -> None: ...
127+
def add(self, *sprites: _SpriteOrSprites[_SpriteT], **kwargs: Any) -> None: ...
128+
def get_sprites_at(self, pos: Point) -> list[_SpriteT]: ...
129+
def get_sprite(self, idx: int) -> _SpriteT: ...
130+
def remove_sprites_of_layer(self, layer_nr: int) -> list[_SpriteT]: ...
181131
def layers(self) -> list[int]: ...
182-
def change_layer(self, sprite: _TSprite, new_layer: int) -> None: ...
183-
def get_layer_of_sprite(self, sprite: _TSprite) -> int: ...
132+
def change_layer(self, sprite: _SpriteT, new_layer: int) -> None: ...
133+
def get_layer_of_sprite(self, sprite: _SpriteT) -> int: ...
184134
def get_top_layer(self) -> int: ...
185135
def get_bottom_layer(self) -> int: ...
186-
def move_to_front(self, sprite: _TSprite) -> None: ...
187-
def move_to_back(self, sprite: _TSprite) -> None: ...
188-
def get_top_sprite(self) -> _TSprite: ...
189-
def get_sprites_from_layer(self, layer: int) -> list[_TSprite]: ...
136+
def move_to_front(self, sprite: _SpriteT) -> None: ...
137+
def move_to_back(self, sprite: _SpriteT) -> None: ...
138+
def get_top_sprite(self) -> _SpriteT: ...
139+
def get_sprites_from_layer(self, layer: int) -> list[_SpriteT]: ...
190140
def switch_layer(self, layer1_nr: int, layer2_nr: int) -> None: ...
191141

192-
class LayeredDirty(LayeredUpdates[_TDirtySprite]):
142+
class LayeredDirty(LayeredUpdates[_DirtySpriteT]):
193143
def draw(
194144
self,
195145
surface: Surface,
@@ -207,20 +157,19 @@ class LayeredDirty(LayeredUpdates[_TDirtySprite]):
207157
)
208158
def set_timing_treshold(self, time_ms: SupportsFloat) -> None: ...
209159

210-
class GroupSingle(AbstractGroup[_TSprite]):
211-
sprite: _TSprite
212-
def __init__(self, sprite: Optional[_TSprite] = None) -> None: ...
160+
class GroupSingle(AbstractGroup[_SpriteT]):
161+
sprite: Optional[_SpriteT]
162+
def __init__(self, sprite: Optional[_SpriteT] = None) -> None: ...
213163

214-
# argument to collide_rect must have rect attribute
215164
def collide_rect(left: _HasRect, right: _HasRect) -> bool: ...
216165

217166
class collide_rect_ratio:
218167
ratio: float
219168
def __init__(self, ratio: float) -> None: ...
220169
def __call__(self, left: _HasRect, right: _HasRect) -> bool: ...
221170

222-
# must have rect attribute, may optionally have radius attribute
223-
_SupportsCollideCircle = Union[_HasRect, _HasRadiusAndRect]
171+
# Must have rect attribute, may optionally have radius attribute
172+
_SupportsCollideCircle = _HasRect
224173

225174
def collide_circle(
226175
left: _SupportsCollideCircle, right: _SupportsCollideCircle
@@ -233,32 +182,31 @@ class collide_circle_ratio:
233182
self, left: _SupportsCollideCircle, right: _SupportsCollideCircle
234183
) -> bool: ...
235184

236-
# argument to collide_mask must either have mask or have image attribute, in
185+
# Arguments to collide_mask must either have mask or have image attribute, in
237186
# addition to mandatorily having a rect attribute
238187
_SupportsCollideMask = Union[_HasImageAndRect, _HasMaskAndRect]
239188

240189
def collide_mask(
241190
left: _SupportsCollideMask, right: _SupportsCollideMask
242191
) -> Optional[tuple[int, int]]: ...
243192

244-
# _HasRect typevar for sprite collide functions
245-
_THasRect = TypeVar("_THasRect", bound=_HasRect)
193+
_HasRectT = TypeVar("_HasRectT", bound=_HasRect)
246194

247195
def spritecollide(
248-
sprite: _THasRect,
249-
group: AbstractGroup[_TSprite],
196+
sprite: _HasRectT,
197+
group: AbstractGroup[_SpriteT],
250198
dokill: bool,
251-
collided: Optional[Callable[[_THasRect, _TSprite], Any]] = None,
252-
) -> list[_TSprite]: ...
199+
collided: Optional[Callable[[_HasRectT, _SpriteT], bool]] = None,
200+
) -> list[_SpriteT]: ...
253201
def groupcollide(
254-
groupa: AbstractGroup[_TSprite],
255-
groupb: AbstractGroup[_TSprite2],
202+
groupa: AbstractGroup[_SpriteT],
203+
groupb: AbstractGroup[_SpriteT2],
256204
dokilla: bool,
257205
dokillb: bool,
258-
collided: Optional[Callable[[_TSprite, _TSprite2], Any]] = None,
259-
) -> dict[_TSprite, list[_TSprite2]]: ...
206+
collided: Optional[Callable[[_SpriteT, _SpriteT2], bool]] = None,
207+
) -> dict[_SpriteT, list[_SpriteT2]]: ...
260208
def spritecollideany(
261-
sprite: _THasRect,
262-
group: AbstractGroup[_TSprite],
263-
collided: Optional[Callable[[_THasRect, _TSprite], Any]] = None,
264-
) -> Optional[_TSprite]: ...
209+
sprite: _HasRectT,
210+
group: AbstractGroup[_SpriteT],
211+
collided: Optional[Callable[[_HasRectT, _SpriteT], bool]] = None,
212+
) -> Optional[_SpriteT]: ...

0 commit comments

Comments
 (0)