From 17766dfe8a5760714e346d873e41df2a8605b9e7 Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 29 May 2025 19:40:37 +0200 Subject: [PATCH 01/87] Move utils.py to utils/ --- discord/{utils.py => utils/__init__.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename discord/{utils.py => utils/__init__.py} (100%) diff --git a/discord/utils.py b/discord/utils/__init__.py similarity index 100% rename from discord/utils.py rename to discord/utils/__init__.py From 195e2a007d33fa741d5f10fa024256ec5ddef52f Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 29 May 2025 19:41:02 +0200 Subject: [PATCH 02/87] :fire: Remove `filter_params` --- discord/utils/__init__.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py index 5dea905617..f051011405 100644 --- a/discord/utils/__init__.py +++ b/discord/utils/__init__.py @@ -1385,34 +1385,3 @@ def _filter(ctx: AutocompleteContext, item: Any) -> bool: return iter(itertools.islice(gen, 25)) return autocomplete_callback - - -def filter_params(params, **kwargs): - """A helper function to filter out and replace certain keyword parameters - - Parameters - ---------- - params: Dict[str, Any] - The initial parameters to filter. - **kwargs: Dict[str, Optional[str]] - Key to value pairs where the key's contents would be moved to the - value, or if the value is None, remove key's contents (see code example). - - Example - ------- - .. code-block:: python3 - - >>> params = {"param1": 12, "param2": 13} - >>> filter_params(params, param1="param3", param2=None) - {'param3': 12} - # values of 'param1' is moved to 'param3' - # and values of 'param2' are completely removed. - """ - for old_param, new_param in kwargs.items(): - if old_param in params: - if new_param is None: - params.pop(old_param) - else: - params[new_param] = params.pop(old_param) - - return params From fafa149a318827f810e2e51318dbcf0448dcb210 Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 29 May 2025 20:20:14 +0200 Subject: [PATCH 03/87] :recycle: Merge `time_snowflake` and `generate_snowflake`, move `basic_autocomplete` to `utils/public.py` --- discord/utils/__init__.py | 179 +++----------------------------------- 1 file changed, 11 insertions(+), 168 deletions(-) diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py index f051011405..098401624b 100644 --- a/discord/utils/__init__.py +++ b/discord/utils/__init__.py @@ -29,7 +29,6 @@ import asyncio import collections.abc import datetime -from enum import Enum, auto import functools import itertools import json @@ -40,7 +39,7 @@ import warnings from base64 import b64encode from bisect import bisect_left -from dataclasses import field +from enum import Enum, auto from inspect import isawaitable as _isawaitable from inspect import signature as _signature from operator import attrgetter @@ -48,7 +47,6 @@ TYPE_CHECKING, Any, AsyncIterator, - Awaitable, Callable, Coroutine, ForwardRef, @@ -64,7 +62,8 @@ overload, ) -from .errors import HTTPException, InvalidArgument +from ..errors import HTTPException, InvalidArgument +from .public import basic_autocomplete, generate_snowflake, utcnow try: import msgspec @@ -80,7 +79,6 @@ "deprecated", "oauth_url", "snowflake_time", - "time_snowflake", "find", "get", "get_or_fetch", @@ -98,11 +96,8 @@ "format_dt", "generate_snowflake", "basic_autocomplete", - "filter_params", ) -DISCORD_EPOCH = 1420070400000 - class Undefined(Enum): MISSING = auto() @@ -113,6 +108,7 @@ def __bool__(self) -> Literal[False]: MISSING: Literal[Undefined.MISSING] = Undefined.MISSING + class _cached_property: def __init__(self, function): self.function = function @@ -445,37 +441,11 @@ def snowflake_time(id: int) -> datetime.datetime: return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) -def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: - """Returns a numeric snowflake pretending to be created at the given date. - - When using as the lower end of a range, use ``time_snowflake(high=False) - 1`` - to be inclusive, ``high=True`` to be exclusive. - - When using as the higher end of a range, use ``time_snowflake(high=True) + 1`` - to be inclusive, ``high=False`` to be exclusive - - Parameters - ---------- - dt: :class:`datetime.datetime` - A datetime object to convert to a snowflake. - If naive, the timezone is assumed to be local time. - high: :class:`bool` - Whether to set the lower 22 bit to high or low. - - Returns - ------- - :class:`int` - The snowflake representing the time given. - """ - discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) - return (discord_millis << 22) + (2**22 - 1 if high else 0) - - def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> T | None: """A helper to return the first element found in the sequence that meets the predicate. For example: :: - member = discord.utils.find(lambda m: m.name == 'Mighty', channel.guild.members) + member = discord.utils.find(lambda m: m.name == "Mighty", channel.guild.members) would find the first :class:`~discord.Member` whose name is 'Mighty' and return it. If an entry is not found, then ``None`` is returned. @@ -519,19 +489,19 @@ def get(iterable: Iterable[T], **attrs: Any) -> T | None: .. code-block:: python3 - member = discord.utils.get(message.guild.members, name='Foo') + member = discord.utils.get(message.guild.members, name="Foo") Multiple attribute matching: .. code-block:: python3 - channel = discord.utils.get(guild.voice_channels, name='Foo', bitrate=64000) + channel = discord.utils.get(guild.voice_channels, name="Foo", bitrate=64000) Nested attribute matching: .. code-block:: python3 - channel = discord.utils.get(client.get_all_channels(), guild__name='Cool', name='general') + channel = discord.utils.get(client.get_all_channels(), guild__name="Cool", name="general") Parameters ----------- @@ -602,11 +572,11 @@ async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> An Getting a guild from a guild ID: :: - guild = await utils.get_or_fetch(client, 'guild', guild_id) + guild = await utils.get_or_fetch(client, "guild", guild_id) Getting a channel from the guild. If the channel is not found, return None: :: - channel = await utils.get_or_fetch(guild, 'channel', channel_id, default=None) + channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None) """ getter = getattr(obj, f"get_{attr}")(id) if getter is None: @@ -749,22 +719,6 @@ async def sleep_until(when: datetime.datetime, result: T | None = None) -> T | N return await asyncio.sleep(delta, result) -def utcnow() -> datetime.datetime: - """A helper function to return an aware UTC datetime representing the current time. - - This should be preferred to :meth:`datetime.datetime.utcnow` since it is an aware - datetime, compared to the naive datetime in the standard library. - - .. versionadded:: 2.0 - - Returns - ------- - :class:`datetime.datetime` - The current aware datetime in UTC. - """ - return datetime.datetime.now(datetime.timezone.utc) - - def valid_icon_size(size: int) -> bool: """Icons must be power of 2 within [16, 4096].""" return not size & (size - 1) and 4096 >= size >= 16 @@ -1261,7 +1215,7 @@ def format_dt( ---------- dt: Union[:class:`datetime.datetime`, :class:`datetime.time`] The datetime to format. - style: :class:`str` + style: :class:`str`R The style to format the datetime with. Returns @@ -1274,114 +1228,3 @@ def format_dt( if style is None: return f"" return f"" - - -def generate_snowflake(dt: datetime.datetime | None = None) -> int: - """Returns a numeric snowflake pretending to be created at the given date but more accurate and random - than :func:`time_snowflake`. If dt is not passed, it makes one from the current time using utcnow. - - Parameters - ---------- - dt: :class:`datetime.datetime` - A datetime object to convert to a snowflake. - If naive, the timezone is assumed to be local time. - - Returns - ------- - :class:`int` - The snowflake representing the time given. - """ - - dt = dt or utcnow() - return int(dt.timestamp() * 1000 - DISCORD_EPOCH) << 22 | 0x3FFFFF - - -V = Union[Iterable[OptionChoice], Iterable[str], Iterable[int], Iterable[float]] -AV = Awaitable[V] -Values = Union[V, Callable[[AutocompleteContext], Union[V, AV]], AV] -AutocompleteFunc = Callable[[AutocompleteContext], AV] -FilterFunc = Callable[[AutocompleteContext, Any], Union[bool, Awaitable[bool]]] - - -def basic_autocomplete( - values: Values, *, filter: FilterFunc | None = None -) -> AutocompleteFunc: - """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and - will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is - callable, it will be called with the AutocompleteContext. - - This is meant to be passed into the :attr:`discord.Option.autocomplete` attribute. - - Parameters - ---------- - values: Union[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Callable[[:class:`.AutocompleteContext`], Union[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] - Possible values for the option. Accepts an iterable of :class:`str`, a callable (sync or async) that takes a - single argument of :class:`.AutocompleteContext`, or a coroutine. Must resolve to an iterable of :class:`str`. - filter: Optional[Callable[[:class:`.AutocompleteContext`, Any], Union[:class:`bool`, Awaitable[:class:`bool`]]]] - An optional callable (sync or async) used to filter the autocomplete options. It accepts two arguments: - the :class:`.AutocompleteContext` and an item from ``values`` iteration treated as callback parameters. If ``None`` is provided, a default filter is used that includes items whose string representation starts with the user's input value, case-insensitive. - - .. versionadded:: 2.7 - - Returns - ------- - Callable[[:class:`.AutocompleteContext`], Awaitable[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] - A wrapped callback for the autocomplete. - - Examples - -------- - - Basic usage: - - .. code-block:: python3 - - Option(str, "color", autocomplete=basic_autocomplete(("red", "green", "blue"))) - - # or - - async def autocomplete(ctx): - return "foo", "bar", "baz", ctx.interaction.user.name - - Option(str, "name", autocomplete=basic_autocomplete(autocomplete)) - - With filter parameter: - - .. code-block:: python3 - - Option(str, "color", autocomplete=basic_autocomplete(("red", "green", "blue"), filter=lambda c, i: str(c.value or "") in i)) - - .. versionadded:: 2.0 - - Note - ---- - Autocomplete cannot be used for options that have specified choices. - """ - - async def autocomplete_callback(ctx: AutocompleteContext) -> V: - _values = values # since we reassign later, python considers it local if we don't do this - - if callable(_values): - _values = _values(ctx) - if asyncio.iscoroutine(_values): - _values = await _values - - if filter is None: - - def _filter(ctx: AutocompleteContext, item: Any) -> bool: - item = getattr(item, "name", item) - return str(item).lower().startswith(str(ctx.value or "").lower()) - - gen = (val for val in _values if _filter(ctx, val)) - - elif asyncio.iscoroutinefunction(filter): - gen = (val for val in _values if await filter(ctx, val)) - - elif callable(filter): - gen = (val for val in _values if filter(ctx, val)) - - else: - raise TypeError("``filter`` must be callable.") - - return iter(itertools.islice(gen, 25)) - - return autocomplete_callback From e2d8eb1b285202b2ee6eb4370ce898de2f087499 Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 29 May 2025 20:29:59 +0200 Subject: [PATCH 04/87] :recycle: Merge `time_snowflake` and `generate_snowflake`, move `basic_autocomplete` to `utils/public.py` --- CHANGELOG-V3.md | 3 + discord/iterators.py | 32 ++--- discord/onboarding.py | 7 +- discord/utils/public.py | 176 +++++++++++++++++++++++ docs/locales/en/LC_MESSAGES/api/utils.po | 2 +- tests/test_utils.py | 2 +- 6 files changed, 201 insertions(+), 21 deletions(-) create mode 100644 CHANGELOG-V3.md create mode 100644 discord/utils/public.py diff --git a/CHANGELOG-V3.md b/CHANGELOG-V3.md new file mode 100644 index 0000000000..d9eb9623e0 --- /dev/null +++ b/CHANGELOG-V3.md @@ -0,0 +1,3 @@ +### Removed + +- `utils.filter_params` diff --git a/discord/iterators.py b/discord/iterators.py index 193b2ac078..77e5f23279 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -41,7 +41,7 @@ from .audit_logs import AuditLogEntry from .errors import NoMoreItems from .object import Object -from .utils import maybe_coroutine, snowflake_time, time_snowflake +from .utils import generate_snowflake, maybe_coroutine, snowflake_time __all__ = ( "ReactionIterator", @@ -341,11 +341,11 @@ def __init__( oldest_first=None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) if isinstance(around, datetime.datetime): - around = Object(id=time_snowflake(around)) + around = Object(id=generate_snowflake(around)) self.reverse = after is not None if oldest_first is None else oldest_first self.messageable = messageable @@ -483,9 +483,9 @@ def __init__( action_type=None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.guild = guild self.loop = guild._state.loop @@ -596,9 +596,9 @@ class GuildIterator(_AsyncIterator["Guild"]): def __init__(self, bot, limit, before=None, after=None, with_counts=True): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.bot = bot self.limit = limit @@ -687,7 +687,7 @@ async def _retrieve_guilds_after_strategy(self, retrieve): class MemberIterator(_AsyncIterator["Member"]): def __init__(self, guild, limit=1000, after=None): if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.guild = guild self.limit = limit @@ -821,7 +821,7 @@ def __init__( self.before = None elif isinstance(before, datetime.datetime): if joined: - self.before = str(time_snowflake(before, high=False)) + self.before = str(generate_snowflake(before, high=False)) else: self.before = before.isoformat() else: @@ -897,9 +897,9 @@ def __init__( after: datetime.datetime | int | None = None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.event = event self.limit = limit @@ -991,9 +991,9 @@ def __init__( self.sku_ids = sku_ids if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.before = before self.after = after @@ -1109,9 +1109,9 @@ def __init__( user_id: int | None = None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.state = state self.sku_id = sku_id diff --git a/discord/onboarding.py b/discord/onboarding.py index b82de31cfc..e86538bcc4 100644 --- a/discord/onboarding.py +++ b/discord/onboarding.py @@ -26,10 +26,11 @@ from typing import TYPE_CHECKING, Any +from discord import utils + from .enums import OnboardingMode, PromptType, try_enum from .partial_emoji import PartialEmoji from .utils import MISSING, cached_property, generate_snowflake, get -from discord import utils if TYPE_CHECKING: from .abc import Snowflake @@ -81,7 +82,7 @@ def __init__( id: int | None = None, ): # ID is required when making edits, but it can be any snowflake that isn't already used by another prompt during edits - self.id: int = int(id) if id else generate_snowflake() + self.id: int = int(id) if id else generate_snowflake(mode="realistic") self.title: str = title self.channels: list[Snowflake] = channels or [] self.roles: list[Snowflake] = roles or [] @@ -169,7 +170,7 @@ def __init__( id: int | None = None, # Currently optional as users can manually create these ): # ID is required when making edits, but it can be any snowflake that isn't already used by another prompt during edits - self.id: int = int(id) if id else generate_snowflake() + self.id: int = int(id) if id else generate_snowflake(mode="realistic") self.type: PromptType = type if isinstance(self.type, int): diff --git a/discord/utils/public.py b/discord/utils/public.py new file mode 100644 index 0000000000..4fc40874fc --- /dev/null +++ b/discord/utils/public.py @@ -0,0 +1,176 @@ +import asyncio +import datetime +import itertools +from collections.abc import Awaitable, Callable, Iterable +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from ..commands.context import AutocompleteContext + from ..commands.options import OptionChoice + +DISCORD_EPOCH = 1420070400000 + + +def utcnow() -> datetime.datetime: + """A helper function to return an aware UTC datetime representing the current time. + + This should be preferred to :meth:`datetime.datetime.utcnow` since it is an aware + datetime, compared to the naive datetime in the standard library. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`datetime.datetime` + The current aware datetime in UTC. + """ + return datetime.datetime.now(datetime.timezone.utc) + + +V = Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] +AV = Awaitable[V] +Values = V | Callable[["AutocompleteContext"], V | AV] | AV +AutocompleteFunc = Callable[["AutocompleteContext"], AV] +FilterFunc = Callable[["AutocompleteContext", Any], bool | Awaitable[bool]] + + +def basic_autocomplete( + values: Values, *, filter: FilterFunc | None = None +) -> AutocompleteFunc: + """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and + will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is + callable, it will be called with the AutocompleteContext. + + This is meant to be passed into the :attr:`discord.Option.autocomplete` attribute. + + Parameters + ---------- + values: Union[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Callable[[:class:`.AutocompleteContext`], Union[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] + Possible values for the option. Accepts an iterable of :class:`str`, a callable (sync or async) that takes a + single argument of :class:`.AutocompleteContext`, or a coroutine. Must resolve to an iterable of :class:`str`. + filter: Optional[Callable[[:class:`.AutocompleteContext`, Any], Union[:class:`bool`, Awaitable[:class:`bool`]]]] + An optional callable (sync or async) used to filter the autocomplete options. It accepts two arguments: + the :class:`.AutocompleteContext` and an item from ``values`` iteration treated as callback parameters. If ``None`` is provided, a default filter is used that includes items whose string representation starts with the user's input value, case-insensitive. + + .. versionadded:: 2.7 + + Returns + ------- + Callable[[:class:`.AutocompleteContext`], Awaitable[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] + A wrapped callback for the autocomplete. + + Examples + -------- + + Basic usage: + + .. code-block:: python3 + + Option(str, "color", autocomplete=basic_autocomplete(("red", "green", "blue"))) + + # or + + async def autocomplete(ctx): + return "foo", "bar", "baz", ctx.interaction.user.name + + Option(str, "name", autocomplete=basic_autocomplete(autocomplete)) + + With filter parameter: + + .. code-block:: python3 + + Option( + str, + "color", + autocomplete=basic_autocomplete(("red", "green", "blue"), filter=lambda c, i: str(c.value or "") in i), + ) + + .. versionadded:: 2.0 + + Note + ---- + Autocomplete cannot be used for options that have specified choices. + """ + + async def autocomplete_callback(ctx: AutocompleteContext) -> V: + _values = values # since we reassign later, python considers it local if we don't do this + + if callable(_values): + _values = _values(ctx) + if asyncio.iscoroutine(_values): + _values = await _values + + if filter is None: + + def _filter(ctx: AutocompleteContext, item: Any) -> bool: + item = getattr(item, "name", item) + return str(item).lower().startswith(str(ctx.value or "").lower()) + + gen = (val for val in _values if _filter(ctx, val)) + + elif asyncio.iscoroutinefunction(filter): + gen = (val for val in _values if await filter(ctx, val)) + + elif callable(filter): + gen = (val for val in _values if filter(ctx, val)) + + else: + raise TypeError("``filter`` must be callable.") + + return iter(itertools.islice(gen, 25)) + + return autocomplete_callback + + +def generate_snowflake( + dt: datetime.datetime | None = None, + *, + mode: Literal["boundary", "realistic"] = "boundary", + high: bool = False, +) -> int: + """Returns a numeric snowflake pretending to be created at the given date. + + This function can generate both realistic snowflakes (for general use) and + boundary snowflakes (for range queries). + + Parameters + ---------- + dt: :class:`datetime.datetime` + A datetime object to convert to a snowflake. + If naive, the timezone is assumed to be local time. + If None, uses current UTC time. + mode: :class:`str` + The type of snowflake to generate: + - "realistic": Creates a snowflake with random-like lower bits (default) + - "boundary": Creates a snowflake for range queries + high: :class:`bool` + Only used when mode="boundary". Whether to set the lower 22 bits + to high (True) or low (False). Default is False. + + Returns + ------- + :class:`int` + The snowflake representing the time given. + + Examples + -------- + # Generate realistic snowflake + snowflake = generate_snowflake(dt) + + # Generate boundary snowflakes + lower_bound = generate_snowflake(dt, mode="boundary", high=False) + upper_bound = generate_snowflake(dt, mode="boundary", high=True) + + # For inclusive ranges: + # Lower: generate_snowflake(dt, mode="boundary", high=False) - 1 + # Upper: generate_snowflake(dt, mode="boundary", high=True) + 1 + """ + dt = dt or utcnow() + discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) + + if mode == "realistic": + return (discord_millis << 22) | 0x3FFFFF + elif mode == "boundary": + return (discord_millis << 22) + (2**22 - 1 if high else 0) + else: + raise ValueError(f"Invalid mode '{mode}'. Must be 'realistic' or 'boundary'") diff --git a/docs/locales/en/LC_MESSAGES/api/utils.po b/docs/locales/en/LC_MESSAGES/api/utils.po index 54425aad81..c88c180521 100644 --- a/docs/locales/en/LC_MESSAGES/api/utils.po +++ b/docs/locales/en/LC_MESSAGES/api/utils.po @@ -613,7 +613,7 @@ msgstr "" msgid "The formatted string." msgstr "" -#: 3dbe398f94684d7d81111c3a9d78ddee discord.utils.time_snowflake:1 of +#: 3dbe398f94684d7d81111c3a9d78ddee discord.utils.:1 of msgid "Returns a numeric snowflake pretending to be created at the given date." msgstr "" diff --git a/tests/test_utils.py b/tests/test_utils.py index d0f94acb93..c68bf1fa36 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -36,10 +36,10 @@ async_all, copy_doc, find, + generate_snowflake, get, maybe_coroutine, snowflake_time, - time_snowflake, utcnow, ) From 86bd1683f5c5c2ac6cd911d03d2dd237df705190 Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 29 May 2025 20:51:08 +0200 Subject: [PATCH 05/87] :fire: Remove `utils.sleep_until` --- CHANGELOG-V3.md | 1 + discord/utils/__init__.py | 22 ---------------------- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/CHANGELOG-V3.md b/CHANGELOG-V3.md index d9eb9623e0..a7368c6955 100644 --- a/CHANGELOG-V3.md +++ b/CHANGELOG-V3.md @@ -1,3 +1,4 @@ ### Removed - `utils.filter_params` +- `utils.sleep_until` diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py index 098401624b..a9f36ab4ef 100644 --- a/discord/utils/__init__.py +++ b/discord/utils/__init__.py @@ -697,28 +697,6 @@ def compute_timedelta(dt: datetime.datetime): now = datetime.datetime.now(datetime.timezone.utc) return max((dt - now).total_seconds(), 0) - -async def sleep_until(when: datetime.datetime, result: T | None = None) -> T | None: - """|coro| - - Sleep until a specified time. - - If the time supplied is in the past this function will yield instantly. - - .. versionadded:: 1.3 - - Parameters - ---------- - when: :class:`datetime.datetime` - The timestamp in which to sleep until. If the datetime is naive then - it is assumed to be local time. - result: Any - If provided is returned to the caller when the coroutine completes. - """ - delta = compute_timedelta(when) - return await asyncio.sleep(delta, result) - - def valid_icon_size(size: int) -> bool: """Icons must be power of 2 within [16, 4096].""" return not size & (size - 1) and 4096 >= size >= 16 From 23c2c49575c5951d2181af9554ef5c43f3ab732c Mon Sep 17 00:00:00 2001 From: Paillat Date: Thu, 29 May 2025 21:00:29 +0200 Subject: [PATCH 06/87] chore: Start migration to uv & ruff & hatch (#4) * Start migration to uv * Setup ruff and hatch * Change pre-commit to use ruff * Format with ruff * Fix mistake * Add dev deps * Change workflows to use uv and ruff * :heavy_plus_sign: Add colorlog and remove requirements folder and fix build * :green_heart: Fix sphinx build ? * :bug: Add __version.py for version management and update import in __init__.py * :pencil2: Update lib-checks.yml to run ruff on ubuntu-latest * :bug: Update lib-checks.yml to run mypy with uv * :fire: Delete MANIFEST.in * :sparkles: Enhance lib-checks.yml to include ruff formatter check * :recycle: Refactor pyproject.toml and uv.lock to use optional-dependencies for voice and speed --- .github/workflows/docs-checks.yml | 17 +- .github/workflows/lib-checks.yml | 111 +- .gitignore | 3 +- .pre-commit-config.yaml | 75 +- MANIFEST.in | 21 - discord/__init__.py | 2 +- discord/__main__.py | 34 +- discord/{_version.py => __version.py} | 125 +- discord/abc.py | 157 +- discord/activity.py | 68 +- discord/appinfo.py | 20 +- discord/application_role_connection.py | 4 +- discord/asset.py | 28 +- discord/audit_logs.py | 104 +- discord/automod.py | 38 +- discord/banners.py | 64 +- discord/bot.py | 175 +- discord/channel.py | 264 +-- discord/client.py | 137 +- discord/cog.py | 110 +- discord/colour.py | 4 +- discord/commands/context.py | 19 +- discord/commands/core.py | 235 +- discord/commands/options.py | 75 +- discord/commands/permissions.py | 9 +- discord/components.py | 15 +- discord/embeds.py | 54 +- discord/emoji.py | 24 +- discord/enums.py | 36 +- discord/errors.py | 13 +- discord/ext/bridge/bot.py | 8 +- discord/ext/bridge/context.py | 12 +- discord/ext/bridge/core.py | 58 +- discord/ext/commands/_types.py | 4 +- discord/ext/commands/bot.py | 23 +- discord/ext/commands/context.py | 8 +- discord/ext/commands/converter.py | 120 +- discord/ext/commands/cooldowns.py | 24 +- discord/ext/commands/core.py | 250 +-- discord/ext/commands/errors.py | 52 +- discord/ext/commands/flags.py | 65 +- discord/ext/commands/help.py | 71 +- discord/ext/commands/view.py | 5 +- discord/ext/pages/pagination.py | 176 +- discord/ext/tasks/__init__.py | 66 +- discord/file.py | 11 +- discord/flags.py | 28 +- discord/gateway.py | 42 +- discord/guild.py | 333 +-- discord/http.py | 395 +--- discord/integrations.py | 12 +- discord/interactions.py | 109 +- discord/invite.py | 62 +- discord/iterators.py | 56 +- discord/member.py | 82 +- discord/mentions.py | 8 +- discord/message.py | 232 +- discord/monetization.py | 16 +- discord/object.py | 4 +- discord/onboarding.py | 38 +- discord/opus.py | 31 +- discord/partial_emoji.py | 14 +- discord/permissions.py | 19 +- discord/player.py | 51 +- discord/poll.py | 50 +- discord/raw_models.py | 24 +- discord/reaction.py | 8 +- discord/role.py | 32 +- discord/scheduled_events.py | 36 +- discord/shard.py | 40 +- discord/sinks/core.py | 4 +- discord/sinks/m4a.py | 16 +- discord/sinks/mka.py | 8 +- discord/sinks/mkv.py | 8 +- discord/sinks/mp3.py | 8 +- discord/sinks/mp4.py | 16 +- discord/sinks/ogg.py | 8 +- discord/sinks/wave.py | 4 +- discord/stage_instance.py | 21 +- discord/state.py | 261 +-- discord/sticker.py | 42 +- discord/team.py | 8 +- discord/template.py | 8 +- discord/threads.py | 58 +- discord/types/interactions.py | 24 +- discord/types/message.py | 4 +- discord/types/voice.py | 4 +- discord/ui/button.py | 12 +- discord/ui/input_text.py | 12 +- discord/ui/item.py | 4 +- discord/ui/modal.py | 25 +- discord/ui/select.py | 35 +- discord/ui/view.py | 46 +- discord/user.py | 40 +- discord/utils.py | 64 +- discord/voice_client.py | 51 +- discord/webhook/async_.py | 112 +- discord/webhook/sync.py | 79 +- discord/welcome_screen.py | 29 +- discord/widget.py | 45 +- docs/conf.py | 6 +- docs/extensions/attributetable.py | 12 +- docs/extensions/builder.py | 7 +- docs/extensions/details.py | 8 +- docs/extensions/nitpick_file_ignorer.py | 5 +- examples/app_commands/context_menus.py | 8 +- examples/app_commands/info.py | 4 +- examples/app_commands/slash_autocomplete.py | 21 +- examples/app_commands/slash_basic.py | 8 +- examples/app_commands/slash_cog.py | 4 +- examples/app_commands/slash_cog_groups.py | 16 +- examples/app_commands/slash_groups.py | 6 +- examples/app_commands/slash_options.py | 4 +- examples/audio_recording.py | 9 +- examples/audio_recording_merged.py | 4 +- examples/background_task.py | 4 +- examples/basic_voice.py | 20 +- examples/bridge_commands.py | 4 +- examples/converters.py | 20 +- examples/cooldown.py | 20 +- examples/create_private_emoji.py | 4 +- examples/custom_context.py | 16 +- examples/deleted.py | 4 +- examples/edits.py | 9 +- examples/modal_dialogs.py | 28 +- examples/new_member.py | 4 +- examples/reaction_roles.py | 16 +- examples/reply.py | 4 +- examples/secret.py | 25 +- examples/timeout.py | 4 +- examples/views/channel_select.py | 7 +- examples/views/confirm.py | 12 +- examples/views/counter.py | 4 +- examples/views/dropdown.py | 16 +- examples/views/ephemeral.py | 8 +- examples/views/paginator.py | 74 +- examples/views/persistent.py | 12 +- examples/views/role_select.py | 7 +- examples/wait_for_event.py | 12 +- pyproject.toml | 324 ++- requirements/_.txt | 3 - requirements/_locale.txt | 4 - requirements/_release.txt | 5 - requirements/all.txt | 4 - requirements/dev.txt | 11 - requirements/docs.txt | 11 - requirements/speed.txt | 2 - requirements/voice.txt | 1 - setup.py | 4 - tests/test_typing_annotated.py | 12 +- uv.lock | 2136 +++++++++++++++++++ 151 files changed, 4017 insertions(+), 4728 deletions(-) delete mode 100644 MANIFEST.in rename discord/{_version.py => __version.py} (52%) delete mode 100644 requirements/_.txt delete mode 100644 requirements/_locale.txt delete mode 100644 requirements/_release.txt delete mode 100644 requirements/all.txt delete mode 100644 requirements/dev.txt delete mode 100644 requirements/docs.txt delete mode 100644 requirements/speed.txt delete mode 100644 requirements/voice.txt delete mode 100644 setup.py create mode 100644 uv.lock diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index aa97542bfa..8b6c245602 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -41,19 +41,22 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.13" - cache: "pip" - cache-dependency-path: "requirements/docs.txt" - check-latest: true - - name: Install dependencies - run: | - python -m pip install -U pip - pip install ".[docs]" + - name: "Install uv" + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + - name: Sync dependencies + run: uv sync --no-python-downloads --group dev --group docs - name: "Check Links" + env: + SPHINXBUILD: ${{ github.workspace }}/.venv/bin/sphinx-build if: ${{ github.event_name == 'schedule' || inputs.with_linkcheck }} run: | cd docs make linkcheck - name: "Compile to html" + env: + SPHINXBUILD: ${{ github.workspace }}/.venv/bin/sphinx-build run: | cd docs make -e SPHINXOPTS="-D language='en'" html diff --git a/.github/workflows/lib-checks.yml b/.github/workflows/lib-checks.yml index ee48dbd47a..372b7fa281 100644 --- a/.github/workflows/lib-checks.yml +++ b/.github/workflows/lib-checks.yml @@ -37,18 +37,17 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.13" - cache: "pip" - cache-dependency-path: "requirements/dev.txt" - - name: "Install dependencies" - run: | - python -m pip install --upgrade pip - pip install -r requirements/dev.txt + - name: "Install uv" + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + - name: Sync dependencies + run: uv sync --no-python-downloads --group dev - name: "Run codespell" run: - codespell --ignore-words-list="groupt,nd,ot,ro,falsy,BU" \ + uv run codespell --ignore-words-list="groupt,nd,ot,ro,falsy,BU" \ --exclude-file=".github/workflows/codespell.yml" - bandit: - if: ${{ github.event_name != 'schedule' }} + ruff: runs-on: ubuntu-latest steps: - name: "Checkout Repository" @@ -57,38 +56,16 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.13" - cache: "pip" - cache-dependency-path: "requirements/dev.txt" - - name: "Install dependencies" - run: | - python -m pip install --upgrade pip - pip install -r requirements/dev.txt - - name: "Run bandit" - run: bandit --recursive --skip B101,B104,B105,B110,B307,B311,B404,B603,B607 . - pylint: - if: ${{ github.event_name != 'schedule' }} - runs-on: ubuntu-latest - steps: - - name: "Checkout Repository" - uses: actions/checkout@v4 - - name: "Setup Python" - uses: actions/setup-python@v5 + - name: "Install uv" + uses: astral-sh/setup-uv@v6 with: - python-version: "3.13" - cache: "pip" - cache-dependency-path: "requirements/dev.txt" - - name: "Install dependencies" - run: | - python -m pip install --upgrade pip - pip install -r requirements/dev.txt - - name: "Setup cache" - id: cache-pylint - uses: actions/cache@v4 - with: - path: .pylint.d - key: pylint - - name: "Run pylint" - run: pylint discord/ --exit-zero + enable-cache: true + - name: Sync dependencies + run: uv sync --no-python-downloads --group dev + - name: "Run ruff linter check" + run: uv run ruff check discord/ + - name: "Run ruff formatter check" + run: uv run ruff format --check discord/ mypy: if: ${{ github.event_name != 'schedule' }} runs-on: ubuntu-latest @@ -99,12 +76,12 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.13" - cache: "pip" - cache-dependency-path: "requirements/dev.txt" - - name: "Install dependencies" - run: | - python -m pip install --upgrade pip - pip install -r requirements/dev.txt + - name: "Install uv" + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + - name: Sync dependencies + run: uv sync --no-python-downloads --group dev - name: "Setup cache" id: cache-mypy uses: actions/cache@v4 @@ -115,44 +92,4 @@ jobs: id: cache-dir-mypy run: mkdir -p -v .mypy_cache - name: "Run mypy" - run: mypy --non-interactive discord/ - pytest: - strategy: - matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - exclude: - - { python-version: "3.9", os: "macos-latest" } - include: - - { python-version: "3.9", os: "macos-13" } - runs-on: ${{ matrix.os }} - env: - OS: ${{ matrix.os }} - PYTHON: ${{ matrix.python-version }} - steps: - - name: "Checkout Repository" - uses: actions/checkout@v4 - - name: "Setup Python" - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - cache-dependency-path: "requirements/dev.txt" - check-latest: true - - name: "Install dependencies" - run: | - python -m pip install --upgrade pip - pip install flake8 - pip install -r requirements/dev.txt - - name: "Setup cache" - id: cache-pytest - uses: actions/cache@v4 - with: - path: .pytest_cache - key: ${{ matrix.os }}-${{ matrix.python-version }}-pytest - - name: "Lint with flake8" - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=120 --statistics + run: uv run mypy --non-interactive discord/ diff --git a/.gitignore b/.gitignore index 96be4f61db..7c27d21ddf 100644 --- a/.gitignore +++ b/.gitignore @@ -177,7 +177,6 @@ docs/_build __pycache__ .vs/* .vscode/* -test.py node_modules/* # changelog is autogenerated from CHANGELOG.md @@ -192,3 +191,5 @@ docs/build/linkcheck !docs/locales/* /build/ /vscode/ + +/discord/_version.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f53f7cf96b..e43197fe9e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,65 +10,12 @@ repos: exclude: \.(po|pot|yml|yaml)$ - id: end-of-file-fixer exclude: \.(po|pot|yml|yaml)$ - - repo: https://github.com/PyCQA/autoflake - rev: v2.3.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.9 hooks: - - id: autoflake - # args: - # - --in-place - # - --remove-all-unused-imports - # - --expand-star-imports - # - --remove-duplicate-keys - # - --remove-unused-variables - - repo: https://github.com/asottile/pyupgrade - rev: v3.19.1 - hooks: - - id: pyupgrade - exclude: \.(po|pot|yml|yaml)$ - - repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort - exclude: \.(po|pot|yml|yaml)$ - - repo: https://github.com/psf/black - rev: 25.1.0 - hooks: - - id: black - args: [--safe, --quiet] - exclude: \.(po|pot|yml|yaml)$ - - repo: https://github.com/Pierre-Sassoulas/black-disable-checker - rev: v1.1.3 - hooks: - - id: black-disable-checker - # - repo: https://github.com/PyCQA/flake8 - # rev: 4.0.1 - # hooks: - # - id: flake8 - # additional_dependencies: [flake8-typing-imports==1.12.0] - # - repo: local - # hooks: - # - id: pylint - # name: pylint - # entry: pylint - # language: system - # types: [python] - # args: ["-rn", "-sn", "--rcfile=.pylintrc", "--fail-on=I"] - # # We define an additional manual step to allow running pylint with a spelling - # # checker in CI. - # - id: pylint - # alias: pylint-with-spelling - # name: pylint - # entry: pylint - # language: system - # types: [python] - # args: ["-rn", "-sn", "--rcfile=.pylintrc", "--fail-on=I", "--spelling-dict=en"] - # stages: [manual] - # - id: mypy - # name: mypy - # entry: mypy - # language: system - # types: [python] - # args: ["--non-interactive"] + - id: ruff + args: [ --fix ] + - id: ruff-format # - repo: https://github.com/myint/rstcheck # rev: "v5.0.0" # hooks: @@ -86,15 +33,3 @@ repos: - id: prettier args: [--prose-wrap=always, --print-width=88] exclude: \.(po|pot|yml|yaml)$ - - repo: https://github.com/DanielNoord/pydocstringformatter - rev: v0.7.3 - hooks: - - id: pydocstringformatter - exclude: \.(po|pot|yml|yaml)$ - args: - [ - --style=numpydoc, - --no-numpydoc-name-type-spacing, - --no-final-period, - --no-capitalize-first-letter, - ] diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index d37233a0c1..0000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,21 +0,0 @@ -include README.rst -include LICENSE -include requirements.txt -include discord/bin/*.dll -include discord/banner.txt -include discord/ibanner.txt -include discord/py.typed - -prune .github -prune docs -prune examples -prune tests -exclude discord/bin/COPYING -exclude .flake8 -exclude .gitignore -exclude .pre-commit-config.yaml -exclude .prettierrc -exclude .readthedocs.yml -exclude CHANGELOG.md -exclude FUNDING.yml -exclude requirements-dev.txt diff --git a/discord/__init__.py b/discord/__init__.py index d6031ce3ac..77604b0eff 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -19,7 +19,7 @@ # We need __version__ to be imported first # isort: off -from ._version import * +from .__version import * # isort: on diff --git a/discord/__main__.py b/discord/__main__.py index ed34bdf42a..ebf30cd7c9 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -36,16 +36,10 @@ def show_version() -> None: - entries = [ - "- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format( - sys.version_info - ) - ] + entries = ["- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(sys.version_info)] version_info = discord.version_info - entries.append( - "- py-cord v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info) - ) + entries.append("- py-cord v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info)) if version_info.releaselevel != "final": version = importlib.metadata.version("py-cord") if version: @@ -299,9 +293,7 @@ def newcog(parser, args) -> None: def add_newbot_args(subparser: argparse._SubParsersAction) -> None: - parser = subparser.add_parser( - "newbot", help="creates a command bot project quickly" - ) + parser = subparser.add_parser("newbot", help="creates a command bot project quickly") parser.set_defaults(func=newbot) parser.add_argument("name", help="the bot project name") @@ -311,12 +303,8 @@ def add_newbot_args(subparser: argparse._SubParsersAction) -> None: nargs="?", default=Path.cwd(), ) - parser.add_argument( - "--prefix", help="the bot prefix (default: $)", default="$", metavar="" - ) - parser.add_argument( - "--sharded", help="whether to use AutoShardedBot", action="store_true" - ) + parser.add_argument("--prefix", help="the bot prefix (default: $)", default="$", metavar="") + parser.add_argument("--sharded", help="whether to use AutoShardedBot", action="store_true") parser.add_argument( "--no-git", help="do not create a .gitignore file", @@ -347,18 +335,12 @@ def add_newcog_args(subparser: argparse._SubParsersAction) -> None: help="whether to hide all commands in the cog", action="store_true", ) - parser.add_argument( - "--full", help="add all special methods as well", action="store_true" - ) + parser.add_argument("--full", help="add all special methods as well", action="store_true") def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]: - parser = argparse.ArgumentParser( - prog="discord", description="Tools for helping with Pycord" - ) - parser.add_argument( - "-v", "--version", action="store_true", help="shows the library version" - ) + parser = argparse.ArgumentParser(prog="discord", description="Tools for helping with Pycord") + parser.add_argument("-v", "--version", action="store_true", help="shows the library version") parser.set_defaults(func=core) subparser = parser.add_subparsers(dest="subcommand", title="subcommands") diff --git a/discord/_version.py b/discord/__version.py similarity index 52% rename from discord/_version.py rename to discord/__version.py index ba68799dfe..88b3df1b5c 100644 --- a/discord/_version.py +++ b/discord/__version.py @@ -28,7 +28,6 @@ import datetime import re import warnings -from importlib.metadata import PackageNotFoundError, version from typing_extensions import TypedDict @@ -37,26 +36,7 @@ from typing import Literal, NamedTuple from .utils import deprecated - -try: - __version__ = version("py-cord") -except PackageNotFoundError: - # Package is not installed - try: - from setuptools_scm import get_version # type: ignore[import] - - __version__ = get_version() - except ImportError: - # setuptools_scm is not installed - __version__ = "0.0.0" - warnings.warn( - ( - "Package is not installed, and setuptools_scm is not installed. " - f"As a fallback, {__name__}.__version__ will be set to {__version__}" - ), - RuntimeWarning, - stacklevel=2, - ) +from ._version import __version__, __version_tuple__ class AdvancedVersionInfo(TypedDict): @@ -70,7 +50,7 @@ class VersionInfo(NamedTuple): major: int minor: int micro: int - releaselevel: Literal["alpha", "beta", "candidate", "final"] + releaselevel: Literal["alpha", "beta", "candidate", "final", "dev"] # We can't set instance attributes on a NamedTuple, so we have to use a # global variable to store the advanced version info. @@ -85,7 +65,7 @@ def advanced(self, value: object) -> None: @property @deprecated("releaselevel", "2.4") - def release_level(self) -> Literal["alpha", "beta", "candidate", "final"]: + def release_level(self) -> Literal["alpha", "beta", "candidate", "final", "dev"]: return self.releaselevel @property @@ -109,49 +89,70 @@ def date(self) -> datetime.date | None: return self.advanced["date"] -version_regex = re.compile( - r"^(?P\d+)(?:\.(?P\d+))?(?:\.(?P\d+))?" - r"(?:(?Prc|a|b)(?P\d+))?" - r"(?:\.dev(?P\d+))?" - r"(?:\+(?:(?:g(?P[a-fA-F0-9]{4,40})(?:\.d(?P\d{4}\d{2}\d{2})|))|d(?P\d{4}\d{2}\d{2})))?$" -) -version_match = version_regex.match(__version__) -if version_match is None: - raise RuntimeError(f"Invalid version string: {__version__}") -raw_info = version_match.groupdict() - -level_info: Literal["alpha", "beta", "candidate", "final"] - -if raw_info["level"] == "a": - level_info = "alpha" -elif raw_info["level"] == "b": - level_info = "beta" -elif raw_info["level"] == "rc": - level_info = "candidate" -elif raw_info["level"] is None: - level_info = "final" -else: - raise RuntimeError("Invalid release level") - -if (raw_date := raw_info["date"] or raw_info["date1"]) is not None: - date_info = datetime.date( - int(raw_date[:4]), - int(raw_date[4:6]), - int(raw_date[6:]), - ) -else: +def parse_version_tuple(version_tuple): + """Parse setuptools-scm version tuple into components.""" + major = version_tuple[0] if len(version_tuple) > 0 else 0 + minor = version_tuple[1] if len(version_tuple) > 1 else 0 + micro = 0 + releaselevel = "final" + serial = 0 + build = None + commit = None date_info = None + # Handle additional components + for i, component in enumerate(version_tuple[2:], start=2): + if isinstance(component, str): + # Parse development/pre-release info + if component.startswith("dev"): + releaselevel = "dev" # Keep dev as its own category + serial = int(component[3:]) if len(component) > 3 else 0 + elif component.startswith("a"): + releaselevel = "alpha" + serial = int(component[1:]) if len(component) > 1 else 0 + elif component.startswith("b"): + releaselevel = "beta" + serial = int(component[1:]) if len(component) > 1 else 0 + elif component.startswith("rc"): + releaselevel = "candidate" + serial = int(component[2:]) if len(component) > 2 else 0 + elif component.startswith("g") and "." in component: + # Parse git info like 'g901fb98.d20250526' + parts = component.split(".") + if parts[0].startswith("g"): + commit = parts[0][1:] # Remove 'g' prefix + if len(parts) > 1 and parts[1].startswith("d"): + date_str = parts[1][1:] # Remove 'd' prefix + if len(date_str) == 8: + date_info = datetime.date(int(date_str[:4]), int(date_str[4:6]), int(date_str[6:8])) + elif isinstance(component, int) and i == 2: + micro = component + + return { + "major": major, + "minor": minor, + "micro": micro, + "releaselevel": releaselevel, + "serial": serial, + "build": build, + "commit": commit, + "date": date_info, + } + + +# Parse the version tuple +parsed = parse_version_tuple(__version_tuple__) + version_info: VersionInfo = VersionInfo( - major=int(raw_info["major"] or 0) or None, - minor=int(raw_info["minor"] or 0) or None, - micro=int(raw_info["patch"] or 0) or None, - releaselevel=level_info, + major=parsed["major"], + minor=parsed["minor"], + micro=parsed["micro"], + releaselevel=parsed["releaselevel"], ) _advanced = AdvancedVersionInfo( - serial=raw_info["serial"], - build=int(raw_info["build"] or 0) or None, - commit=raw_info["commit"], - date=date_info, + serial=parsed["serial"], + build=parsed["build"], + commit=parsed["commit"], + date=parsed["date"], ) diff --git a/discord/abc.py b/discord/abc.py index 4b99e49034..92c6b2b67d 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -98,18 +98,14 @@ from .ui.view import View from .user import ClientUser - PartialMessageableChannel = Union[ - TextChannel, VoiceChannel, StageChannel, Thread, DMChannel, PartialMessageable - ] + PartialMessageableChannel = Union[TextChannel, VoiceChannel, StageChannel, Thread, DMChannel, PartialMessageable] MessageableChannel = Union[PartialMessageableChannel, GroupChannel] SnowflakeTime = Union["Snowflake", datetime] MISSING = utils.MISSING -async def _single_delete_strategy( - messages: Iterable[Message], *, reason: str | None = None -): +async def _single_delete_strategy(messages: Iterable[Message], *, reason: str | None = None): for m in messages: await m.delete(reason=reason) @@ -339,9 +335,7 @@ class GuildChannel: if TYPE_CHECKING: - def __init__( - self, *, state: ConnectionState, guild: Guild, data: dict[str, Any] - ): ... + def __init__(self, *, state: ConnectionState, guild: Guild, data: dict[str, Any]): ... def __str__(self) -> str: return self.name @@ -366,9 +360,7 @@ async def _move( http = self._state.http bucket = self._sorting_bucket - channels: list[GuildChannel] = [ - c for c in self.guild.channels if c._sorting_bucket == bucket - ] + channels: list[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket] channels.sort(key=lambda c: c.position) @@ -395,9 +387,7 @@ async def _move( await http.bulk_channel_update(self.guild.id, payload, reason=reason) - async def _edit( - self, options: dict[str, Any], reason: str | None - ) -> ChannelPayload | None: + async def _edit(self, options: dict[str, Any], reason: str | None) -> ChannelPayload | None: try: parent = options.pop("category") except KeyError: @@ -411,9 +401,7 @@ async def _edit( pass try: - options["default_thread_rate_limit_per_user"] = options.pop( - "default_thread_slowmode_delay" - ) + options["default_thread_rate_limit_per_user"] = options.pop("default_thread_slowmode_delay") except KeyError: pass @@ -423,9 +411,7 @@ async def _edit( pass try: - options["available_tags"] = [ - tag.to_dict() for tag in options.pop("available_tags") - ] + options["available_tags"] = [tag.to_dict() for tag in options.pop("available_tags")] except KeyError: pass @@ -452,18 +438,14 @@ async def _edit( if lock_permissions: category = self.guild.get_channel(parent_id) if category: - options["permission_overwrites"] = [ - c._asdict() for c in category._overwrites - ] + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] options["parent_id"] = parent_id elif lock_permissions and self.category_id is not None: # if we're syncing permissions on a pre-existing channel category without changing it # we need to update the permissions to point to the pre-existing category category = self.guild.get_channel(self.category_id) if category: - options["permission_overwrites"] = [ - c._asdict() for c in category._overwrites - ] + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] else: await self._move( position, @@ -477,21 +459,14 @@ async def _edit( perms = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): - raise InvalidArgument( - "Expected PermissionOverwrite received" - f" {perm.__class__.__name__}" - ) + raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}") allow, deny = perm.pair() payload = { "allow": allow.value, "deny": deny.value, "id": target.id, - "type": ( - _Overwrites.ROLE - if isinstance(target, Role) - else _Overwrites.MEMBER - ), + "type": (_Overwrites.ROLE if isinstance(target, Role) else _Overwrites.MEMBER), } perms.append(payload) @@ -511,33 +486,23 @@ async def _edit( except KeyError: pass else: - if isinstance( - default_reaction_emoji, _EmojiTag - ): # GuildEmoji, PartialEmoji + if isinstance(default_reaction_emoji, _EmojiTag): # GuildEmoji, PartialEmoji default_reaction_emoji = default_reaction_emoji._to_partial() elif isinstance(default_reaction_emoji, int): - default_reaction_emoji = PartialEmoji( - name=None, id=default_reaction_emoji - ) + default_reaction_emoji = PartialEmoji(name=None, id=default_reaction_emoji) elif isinstance(default_reaction_emoji, str): default_reaction_emoji = PartialEmoji.from_str(default_reaction_emoji) elif default_reaction_emoji is None: pass else: - raise InvalidArgument( - "default_reaction_emoji must be of type: GuildEmoji | int | str | None" - ) + raise InvalidArgument("default_reaction_emoji must be of type: GuildEmoji | int | str | None") options["default_reaction_emoji"] = ( - default_reaction_emoji._to_forum_reaction_payload() - if default_reaction_emoji - else None + default_reaction_emoji._to_forum_reaction_payload() if default_reaction_emoji else None ) if options: - return await self._state.http.edit_channel( - self.id, reason=reason, **options - ) + return await self._state.http.edit_channel(self.id, reason=reason, **options) def _fill_overwrites(self, data: GuildChannelPayload) -> None: self._overwrites = [] @@ -752,9 +717,7 @@ def permissions_for(self, obj: Member | Role, /) -> Permissions: try: maybe_everyone = self._overwrites[0] if maybe_everyone.id == self.guild.id: - base.handle_overwrite( - allow=maybe_everyone.allow, deny=maybe_everyone.deny - ) + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) except IndexError: pass @@ -785,9 +748,7 @@ def permissions_for(self, obj: Member | Role, /) -> Permissions: try: maybe_everyone = self._overwrites[0] if maybe_everyone.id == self.guild.id: - base.handle_overwrite( - allow=maybe_everyone.allow, deny=maybe_everyone.deny - ) + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) remaining_overwrites = self._overwrites[1:] else: remaining_overwrites = self._overwrites @@ -868,9 +829,7 @@ async def set_permissions( **permissions: bool, ) -> None: ... - async def set_permissions( - self, target, *, overwrite=MISSING, reason=None, **permissions - ): + async def set_permissions(self, target, *, overwrite=MISSING, reason=None, **permissions): r"""|coro| Sets the channel specific permission overwrites for a target in the @@ -899,8 +858,7 @@ async def set_permissions( Setting allow and deny: :: - await message.channel.set_permissions(message.author, read_messages=True, - send_messages=False) + await message.channel.set_permissions(message.author, read_messages=True, send_messages=False) Deleting overwrites :: @@ -964,9 +922,7 @@ async def set_permissions( await http.delete_channel_permissions(self.id, target.id, reason=reason) elif isinstance(overwrite, PermissionOverwrite): (allow, deny) = overwrite.pair() - await http.edit_channel_permissions( - self.id, target.id, allow.value, deny.value, perm_type, reason=reason - ) + await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) else: raise InvalidArgument("Invalid overwrite type provided.") @@ -982,18 +938,14 @@ async def _clone_impl( base_attrs["name"] = name or self.name guild_id = self.guild.id cls = self.__class__ - data = await self._state.http.create_channel( - guild_id, self.type.value, reason=reason, **base_attrs - ) + data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) obj = cls(state=self._state, guild=self.guild, data=data) # temporarily add it to the cache self.guild._channels[obj.id] = obj # type: ignore return obj - async def clone( - self: GCH, *, name: str | None = None, reason: str | None = None - ) -> GCH: + async def clone(self: GCH, *, name: str | None = None, reason: str | None = None) -> GCH: """|coro| Clones this channel. This creates a channel with the same properties @@ -1136,9 +1088,7 @@ async def move(self, **kwargs) -> None: before, after = kwargs.get("before"), kwargs.get("after") offset = kwargs.get("offset", 0) if sum(bool(a) for a in (beginning, end, before, after)) > 1: - raise InvalidArgument( - "Only one of [before, after, end, beginning] can be used." - ) + raise InvalidArgument("Only one of [before, after, end, beginning] can be used.") bucket = self._sorting_bucket parent_id = kwargs.get("category", MISSING) @@ -1146,15 +1096,11 @@ async def move(self, **kwargs) -> None: if parent_id not in (MISSING, None): parent_id = parent_id.id channels = [ - ch - for ch in self.guild.channels - if ch._sorting_bucket == bucket and ch.category_id == parent_id + ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == parent_id ] else: channels = [ - ch - for ch in self.guild.channels - if ch._sorting_bucket == bucket and ch.category_id == self.category_id + ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == self.category_id ] channels.sort(key=lambda c: (c.position, c.id)) @@ -1174,9 +1120,7 @@ async def move(self, **kwargs) -> None: elif before: index = next((i for i, c in enumerate(channels) if c.id == before.id), None) elif after: - index = next( - (i + 1 for i, c in enumerate(channels) if c.id == after.id), None - ) + index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) if index is None: raise InvalidArgument("Could not resolve appropriate move position") @@ -1191,9 +1135,7 @@ async def move(self, **kwargs) -> None: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) - await self._state.http.bulk_channel_update( - self.guild.id, payload, reason=reason - ) + await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) async def create_invite( self, @@ -1311,10 +1253,7 @@ async def invites(self) -> list[Invite]: state = self._state data = await state.http.invites_from_channel(self.id) guild = self.guild - return [ - Invite(state=state, data=invite, channel=self, guild=guild) - for invite in data - ] + return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data] class Messageable: @@ -1554,18 +1493,14 @@ async def send( content = str(content) if content is not None else None if embed is not None and embeds is not None: - raise InvalidArgument( - "cannot pass both embed and embeds parameter to send()" - ) + raise InvalidArgument("cannot pass both embed and embeds parameter to send()") if embed is not None: embed = embed.to_dict() elif embeds is not None: if len(embeds) > 10: - raise InvalidArgument( - "embeds parameter must be a list of up to 10 elements" - ) + raise InvalidArgument("embeds parameter must be a list of up to 10 elements") embeds = [embed.to_dict() for embed in embeds] flags = MessageFlags( @@ -1577,9 +1512,7 @@ async def send( stickers = [sticker.id for sticker in stickers] if allowed_mentions is None: - allowed_mentions = ( - state.allowed_mentions and state.allowed_mentions.to_dict() - ) + allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() elif state.allowed_mentions is not None: allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() else: @@ -1604,15 +1537,12 @@ async def send( ) except AttributeError: raise InvalidArgument( - "reference parameter must be Message, MessageReference, or" - " PartialMessage" + "reference parameter must be Message, MessageReference, or PartialMessage" ) from None if view: if not hasattr(view, "__discord_ui_view__"): - raise InvalidArgument( - f"view parameter must be View not {view.__class__!r}" - ) + raise InvalidArgument(f"view parameter must be View not {view.__class__!r}") components = view.to_components() else: @@ -1630,16 +1560,12 @@ async def send( files = [file] elif files is not None: if len(files) > 10: - raise InvalidArgument( - "files parameter must be a list of up to 10 elements" - ) + raise InvalidArgument("files parameter must be a list of up to 10 elements") elif not all(isinstance(file, File) for file in files): raise InvalidArgument("files parameter must be a list of File") if files is not None: - flags = flags + MessageFlags( - is_voice_message=any(isinstance(f, VoiceMessage) for f in files) - ) + flags = flags + MessageFlags(is_voice_message=any(isinstance(f, VoiceMessage) for f in files)) try: data = await state.http.send_files( channel.id, @@ -1713,7 +1639,7 @@ def typing(self) -> Typing: # simulate something heavy await asyncio.sleep(10) - await channel.send('done!') + await channel.send("done!") """ return Typing(self) @@ -1811,15 +1737,10 @@ def can_send(self, *objects) -> bool: if obj is None: permission = mapping["Message"] else: - permission = ( - mapping.get(type(obj).__name__) or mapping[obj.__name__] - ) + permission = mapping.get(type(obj).__name__) or mapping[obj.__name__] if type(obj).__name__ == "GuildEmoji": - if ( - obj._to_partial().is_unicode_emoji - or obj.guild_id == channel.guild.id - ): + if obj._to_partial().is_unicode_emoji or obj.guild_id == channel.guild.id: continue elif type(obj).__name__ == "GuildSticker": if obj.guild_id == channel.guild.id: diff --git a/discord/activity.py b/discord/activity.py index 81128a5fb0..9cda6bc3f2 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -129,9 +129,7 @@ def created_at(self) -> datetime.datetime | None: .. versionadded:: 1.3 """ if self._created_at is not None: - return datetime.datetime.fromtimestamp( - self._created_at / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) def to_dict(self) -> ActivityPayload: raise NotImplementedError @@ -237,18 +235,12 @@ def __init__(self, **kwargs): activity_type = kwargs.pop("type", -1) self.type: ActivityType = ( - activity_type - if isinstance(activity_type, ActivityType) - else try_enum(ActivityType, activity_type) - ) - self.name: str | None = kwargs.pop( - "name", "Custom Status" if self.type == ActivityType.custom else None + activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type) ) + self.name: str | None = kwargs.pop("name", "Custom Status" if self.type == ActivityType.custom else None) emoji = kwargs.pop("emoji", None) - self.emoji: PartialEmoji | None = ( - PartialEmoji.from_dict(emoji) if emoji is not None else None - ) + self.emoji: PartialEmoji | None = PartialEmoji.from_dict(emoji) if emoji is not None else None def __repr__(self) -> str: attrs = ( @@ -398,18 +390,14 @@ def type(self) -> ActivityType: def start(self) -> datetime.datetime | None: """When the user started playing this game in UTC, if applicable.""" if self._start: - return datetime.datetime.fromtimestamp( - self._start / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._start / 1000, tz=datetime.timezone.utc) return None @property def end(self) -> datetime.datetime | None: """When the user will stop playing this game in UTC, if applicable.""" if self._end: - return datetime.datetime.fromtimestamp( - self._end / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._end / 1000, tz=datetime.timezone.utc) return None def __str__(self) -> str: @@ -536,11 +524,7 @@ def to_dict(self) -> dict[str, Any]: return ret def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, Streaming) - and other.name == self.name - and other.url == self.url - ) + return isinstance(other, Streaming) and other.name == self.name and other.url == self.url def __hash__(self) -> int: return hash(self.name) @@ -605,9 +589,7 @@ def created_at(self) -> datetime.datetime | None: .. versionadded:: 1.3 """ if self._created_at is not None: - return datetime.datetime.fromtimestamp( - self._created_at / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) @property def colour(self) -> Colour: @@ -658,10 +640,7 @@ def __str__(self) -> str: return "Spotify" def __repr__(self) -> str: - return ( - "" - ) + return f"" @property def title(self) -> str: @@ -712,16 +691,12 @@ def track_url(self) -> str: @property def start(self) -> datetime.datetime: """When the user started playing this song in UTC.""" - return datetime.datetime.fromtimestamp( - self._timestamps["start"] / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._timestamps["start"] / 1000, tz=datetime.timezone.utc) @property def end(self) -> datetime.datetime: """When the user will stop playing this song in UTC.""" - return datetime.datetime.fromtimestamp( - self._timestamps["end"] / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._timestamps["end"] / 1000, tz=datetime.timezone.utc) @property def duration(self) -> datetime.timedelta: @@ -769,9 +744,7 @@ class CustomActivity(BaseActivity): __slots__ = ("name", "emoji", "state") - def __init__( - self, name: str | None, *, emoji: PartialEmoji | None = None, **extra: Any - ): + def __init__(self, name: str | None, *, emoji: PartialEmoji | None = None, **extra: Any): super().__init__(**extra) self.name: str | None = name self.state: str | None = extra.pop("state", name) @@ -788,10 +761,7 @@ def __init__( elif isinstance(emoji, PartialEmoji): self.emoji = emoji else: - raise TypeError( - "Expected str, PartialEmoji, or None, received" - f" {type(emoji)!r} instead." - ) + raise TypeError(f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.") @property def type(self) -> ActivityType: @@ -819,11 +789,7 @@ def to_dict(self) -> dict[str, Any]: return o def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, CustomActivity) - and other.name == self.name - and other.emoji == self.emoji - ) + return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji def __hash__(self) -> int: return hash((self.name, str(self.emoji))) @@ -872,10 +838,6 @@ def create_activity(data: ActivityPayload | None) -> ActivityTypes | None: # the url won't be None here return Streaming(**data) # type: ignore return Activity(**data) - elif ( - game_type is ActivityType.listening - and "sync_id" in data - and "session_id" in data - ): + elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data: return Spotify(**data) return Activity(**data) diff --git a/discord/appinfo.py b/discord/appinfo.py index 034b1bb158..6009cb7c80 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -202,29 +202,19 @@ def __init__(self, state: ConnectionState, data: AppInfoPayload): self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") - self.primary_sku_id: int | None = utils._get_as_snowflake( - data, "primary_sku_id" - ) + self.primary_sku_id: int | None = utils._get_as_snowflake(data, "primary_sku_id") self.slug: str | None = data.get("slug") self._cover_image: str | None = data.get("cover_image") self.terms_of_service_url: str | None = data.get("terms_of_service_url") self.privacy_policy_url: str | None = data.get("privacy_policy_url") self.approximate_guild_count: int | None = data.get("approximate_guild_count") - self.approximate_user_install_count: int | None = data.get( - "approximate_user_install_count" - ) + self.approximate_user_install_count: int | None = data.get("approximate_user_install_count") self.redirect_uris: list[str] | None = data.get("redirect_uris", []) - self.interactions_endpoint_url: str | None = data.get( - "interactions_endpoint_url" - ) - self.role_connections_verification_url: str | None = data.get( - "role_connections_verification_url" - ) + self.interactions_endpoint_url: str | None = data.get("interactions_endpoint_url") + self.role_connections_verification_url: str | None = data.get("role_connections_verification_url") install_params = data.get("install_params") - self.install_params: AppInstallParams | None = ( - AppInstallParams(install_params) if install_params else None - ) + self.install_params: AppInstallParams | None = AppInstallParams(install_params) if install_params else None self.tags: list[str] | None = data.get("tags", []) self.custom_install_url: str | None = data.get("custom_install_url") diff --git a/discord/application_role_connection.py b/discord/application_role_connection.py index 3e22d55cc1..e8ac4d5b99 100644 --- a/discord/application_role_connection.py +++ b/discord/application_role_connection.py @@ -102,9 +102,7 @@ def __str__(self): return self.name @classmethod - def from_dict( - cls, data: ApplicationRoleConnectionMetadataPayload - ) -> ApplicationRoleConnectionMetadata: + def from_dict(cls, data: ApplicationRoleConnectionMetadataPayload) -> ApplicationRoleConnectionMetadata: return cls( type=try_enum(ApplicationRoleConnectionMetadataType, data["type"]), key=data["key"], diff --git a/discord/asset.py b/discord/asset.py index eeb78e2e9c..3bc04faa7c 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -183,9 +183,7 @@ def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: ) @classmethod - def _from_avatar_decoration( - cls, state, user_id: int, avatar_decoration: str - ) -> Asset: + def _from_avatar_decoration(cls, state, user_id: int, avatar_decoration: str) -> Asset: animated = avatar_decoration.startswith("a_") endpoint = ( "avatar-decoration-presets" @@ -200,9 +198,7 @@ def _from_avatar_decoration( ) @classmethod - def _from_guild_avatar( - cls, state, guild_id: int, member_id: int, avatar: str - ) -> Asset: + def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: animated = avatar.startswith("a_") format = "gif" if animated else "png" return cls( @@ -213,9 +209,7 @@ def _from_guild_avatar( ) @classmethod - def _from_guild_banner( - cls, state, guild_id: int, member_id: int, banner: str - ) -> Asset: + def _from_guild_banner(cls, state, guild_id: int, member_id: int, banner: str) -> Asset: animated = banner.startswith("a_") format = "gif" if animated else "png" return cls( @@ -290,9 +284,7 @@ def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: ) @classmethod - def _from_scheduled_event_image( - cls, state, event_id: int, cover_hash: str - ) -> Asset: + def _from_scheduled_event_image(cls, state, event_id: int, cover_hash: str) -> Asset: return cls( state, url=f"{cls.BASE}/guild-events/{event_id}/{cover_hash}.png", @@ -366,22 +358,16 @@ def replace( if format is not MISSING: if self._animated: if format not in VALID_ASSET_FORMATS: - raise InvalidArgument( - f"format must be one of {VALID_ASSET_FORMATS}" - ) + raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}") url = url.with_path(f"{path}.{format}") elif static_format is MISSING: if format not in VALID_STATIC_FORMATS: - raise InvalidArgument( - f"format must be one of {VALID_STATIC_FORMATS}" - ) + raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}") url = url.with_path(f"{path}.{format}") if static_format is not MISSING and not self._animated: if static_format not in VALID_STATIC_FORMATS: - raise InvalidArgument( - f"static_format must be one of {VALID_STATIC_FORMATS}" - ) + raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}") url = url.with_path(f"{path}.{static_format}") if size is not MISSING: diff --git a/discord/audit_logs.py b/discord/audit_logs.py index e2ff277dfe..794d1ccfdf 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -78,33 +78,25 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int: return int(data) -def _transform_channel( - entry: AuditLogEntry, data: Snowflake | None -) -> abc.GuildChannel | Object | None: +def _transform_channel(entry: AuditLogEntry, data: Snowflake | None) -> abc.GuildChannel | Object | None: if data is None: return None return entry.guild.get_channel(int(data)) or Object(id=data) -def _transform_channels( - entry: AuditLogEntry, data: list[Snowflake] | None -) -> list[abc.GuildChannel | Object] | None: +def _transform_channels(entry: AuditLogEntry, data: list[Snowflake] | None) -> list[abc.GuildChannel | Object] | None: if data is None: return None return [_transform_channel(entry, channel) for channel in data] -def _transform_roles( - entry: AuditLogEntry, data: list[Snowflake] | None -) -> list[Role | Object] | None: +def _transform_roles(entry: AuditLogEntry, data: list[Snowflake] | None) -> list[Role | Object] | None: if data is None: return None return [entry.guild.get_role(int(r)) or Object(id=r) for r in data] -def _transform_member_id( - entry: AuditLogEntry, data: Snowflake | None -) -> Member | User | None: +def _transform_member_id(entry: AuditLogEntry, data: Snowflake | None) -> Member | User | None: if data is None: return None return entry._get_member(int(data)) @@ -153,9 +145,7 @@ def _transform_avatar(entry: AuditLogEntry, data: str | None) -> Asset | None: return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore -def _transform_scheduled_event_image( - entry: AuditLogEntry, data: str | None -) -> Asset | None: +def _transform_scheduled_event_image(entry: AuditLogEntry, data: str | None) -> Asset | None: if data is None: return None return Asset._from_scheduled_event_image(entry._state, entry._target_id, data) @@ -182,18 +172,14 @@ def _transform(entry: AuditLogEntry, data: int) -> T: return _transform -def _transform_type( - entry: AuditLogEntry, data: int -) -> enums.ChannelType | enums.StickerType: +def _transform_type(entry: AuditLogEntry, data: int) -> enums.ChannelType | enums.StickerType: if entry.action.name.startswith("sticker_"): return enums.try_enum(enums.StickerType, data) else: return enums.try_enum(enums.ChannelType, data) -def _transform_actions( - entry: AuditLogEntry, data: list[AutoModActionPayload] | None -) -> list[AutoModAction] | None: +def _transform_actions(entry: AuditLogEntry, data: list[AutoModActionPayload] | None) -> list[AutoModAction] | None: if data is None: return None else: @@ -309,7 +295,11 @@ def __init__( "$add_allow_list", ]: self._handle_trigger_metadata( - self.before, self.after, entry, elem["new_value"], attr # type: ignore + self.before, + self.after, + entry, + elem["new_value"], + attr, # type: ignore ) continue elif attr in [ @@ -318,7 +308,11 @@ def __init__( "$remove_allow_list", ]: self._handle_trigger_metadata( - self.after, self.before, entry, elem["new_value"], attr # type: ignore + self.after, + self.before, + entry, + elem["new_value"], + attr, # type: ignore ) continue @@ -343,15 +337,10 @@ def __init__( if attr == "location" and hasattr(self.before, "location_type"): from .scheduled_events import ScheduledEventLocation - if ( - self.before.location_type - is enums.ScheduledEventLocationType.external - ): + if self.before.location_type is enums.ScheduledEventLocationType.external: before = ScheduledEventLocation(state=state, value=before) elif hasattr(self.before, "channel"): - before = ScheduledEventLocation( - state=state, value=self.before.channel - ) + before = ScheduledEventLocation(state=state, value=self.before.channel) setattr(self.before, attr, before) @@ -366,15 +355,10 @@ def __init__( if attr == "location" and hasattr(self.after, "location_type"): from .scheduled_events import ScheduledEventLocation - if ( - self.after.location_type - is enums.ScheduledEventLocationType.external - ): + if self.after.location_type is enums.ScheduledEventLocationType.external: after = ScheduledEventLocation(state=state, value=after) elif hasattr(self.after, "channel"): - after = ScheduledEventLocation( - state=state, value=self.after.channel - ) + after = ScheduledEventLocation(state=state, value=self.after.channel) setattr(self.after, attr, after) @@ -498,9 +482,7 @@ class AuditLogEntry(Hashable): which actions have this field filled out. """ - def __init__( - self, *, users: dict[int, User], data: AuditLogEntryPayload, guild: Guild - ): + def __init__(self, *, users: dict[int, User], data: AuditLogEntryPayload, guild: Guild): self._state = guild._state self.guild = guild self._users = users @@ -520,38 +502,27 @@ def _from_data(self, data: AuditLogEntryPayload) -> None: self.extra: _AuditLogProxyMemberPrune = type( "_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()} )() - elif ( - self.action is enums.AuditLogAction.member_move - or self.action is enums.AuditLogAction.message_delete - ): + elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete: channel_id = int(self.extra["channel_id"]) elems = { "count": int(self.extra["count"]), - "channel": self.guild.get_channel(channel_id) - or Object(id=channel_id), + "channel": self.guild.get_channel(channel_id) or Object(id=channel_id), } - self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type( - "_AuditLogProxy", (), elems - )() + self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type("_AuditLogProxy", (), elems)() elif self.action is enums.AuditLogAction.member_disconnect: # The member disconnect action has a dict with some information elems = { "count": int(self.extra["count"]), } - self.extra: _AuditLogProxyMemberDisconnect = type( - "_AuditLogProxy", (), elems - )() + self.extra: _AuditLogProxyMemberDisconnect = type("_AuditLogProxy", (), elems)() elif self.action.name.endswith("pin"): # the pin actions have a dict with some information channel_id = int(self.extra["channel_id"]) elems = { - "channel": self.guild.get_channel(channel_id) - or Object(id=channel_id), + "channel": self.guild.get_channel(channel_id) or Object(id=channel_id), "message_id": int(self.extra["message_id"]), } - self.extra: _AuditLogProxyPinAction = type( - "_AuditLogProxy", (), elems - )() + self.extra: _AuditLogProxyPinAction = type("_AuditLogProxy", (), elems)() elif self.action.name.startswith("overwrite_"): # the overwrite_ actions have a dict with some information instance_id = int(self.extra["id"]) @@ -566,13 +537,8 @@ def _from_data(self, data: AuditLogEntryPayload) -> None: self.extra: Role = role elif self.action.name.startswith("stage_instance"): channel_id = int(self.extra["channel_id"]) - elems = { - "channel": self.guild.get_channel(channel_id) - or Object(id=channel_id) - } - self.extra: _AuditLogProxyStageInstanceAction = type( - "_AuditLogProxy", (), elems - )() + elems = {"channel": self.guild.get_channel(channel_id) or Object(id=channel_id)} + self.extra: _AuditLogProxyStageInstanceAction = type("_AuditLogProxy", (), elems)() self.extra: ( _AuditLogProxyMemberPrune @@ -668,11 +634,7 @@ def _convert_target_role(self, target_id: int) -> Role | Object: def _convert_target_invite(self, target_id: int) -> Invite: # invites have target_id set to null # so figure out which change has the full invite data - changeset = ( - self.before - if self.action is enums.AuditLogAction.invite_delete - else self.after - ) + changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after fake_payload = { "max_age": changeset.max_age, @@ -704,7 +666,5 @@ def _convert_target_sticker(self, target_id: int) -> GuildSticker | Object: def _convert_target_thread(self, target_id: int) -> Thread | Object: return self.guild.get_thread(target_id) or Object(id=target_id) - def _convert_target_scheduled_event( - self, target_id: int - ) -> ScheduledEvent | Object: + def _convert_target_scheduled_event(self, target_id: int) -> ScheduledEvent | Object: return self.guild.get_scheduled_event(target_id) or Object(id=target_id) diff --git a/discord/automod.py b/discord/automod.py index 27c9e52f9f..f7e935eeec 100644 --- a/discord/automod.py +++ b/discord/automod.py @@ -290,9 +290,7 @@ def from_dict(cls, data: AutoModTriggerMetadataPayload): kwargs["regex_patterns"] = regex_patterns if (presets := data.get("presets")) is not None: - kwargs["presets"] = [ - try_enum(AutoModKeywordPresetType, wordset) for wordset in presets - ] + kwargs["presets"] = [try_enum(AutoModKeywordPresetType, wordset) for wordset in presets] if (allow_list := data.get("allow_list")) is not None: kwargs["allow_list"] = allow_list @@ -394,18 +392,10 @@ def __init__( self.guild_id: int = int(data["guild_id"]) self.name: str = data["name"] self.creator_id: int = int(data["creator_id"]) - self.event_type: AutoModEventType = try_enum( - AutoModEventType, data["event_type"] - ) - self.trigger_type: AutoModTriggerType = try_enum( - AutoModTriggerType, data["trigger_type"] - ) - self.trigger_metadata: AutoModTriggerMetadata = ( - AutoModTriggerMetadata.from_dict(data["trigger_metadata"]) - ) - self.actions: list[AutoModAction] = [ - AutoModAction.from_dict(d) for d in data["actions"] - ] + self.event_type: AutoModEventType = try_enum(AutoModEventType, data["event_type"]) + self.trigger_type: AutoModTriggerType = try_enum(AutoModTriggerType, data["trigger_type"]) + self.trigger_metadata: AutoModTriggerMetadata = AutoModTriggerMetadata.from_dict(data["trigger_metadata"]) + self.actions: list[AutoModAction] = [AutoModAction.from_dict(d) for d in data["actions"]] self.enabled: bool = data["enabled"] self.exempt_role_ids: list[int] = [int(r) for r in data["exempt_roles"]] self.exempt_channel_ids: list[int] = [int(c) for c in data["exempt_channels"]] @@ -438,10 +428,7 @@ def exempt_roles(self) -> list[Role | Object]: """ if self.guild is None: return [Object(role_id) for role_id in self.exempt_role_ids] - return [ - self.guild.get_role(role_id) or Object(role_id) - for role_id in self.exempt_role_ids - ] + return [self.guild.get_role(role_id) or Object(role_id) for role_id in self.exempt_role_ids] @cached_property def exempt_channels( @@ -454,10 +441,7 @@ def exempt_channels( """ if self.guild is None: return [Object(channel_id) for channel_id in self.exempt_channel_ids] - return [ - self.guild.get_channel(channel_id) or Object(channel_id) - for channel_id in self.exempt_channel_ids - ] + return [self.guild.get_channel(channel_id) or Object(channel_id) for channel_id in self.exempt_channel_ids] async def delete(self, reason: str | None = None) -> None: """|coro| @@ -476,9 +460,7 @@ async def delete(self, reason: str | None = None) -> None: HTTPException The operation failed. """ - await self._state.http.delete_auto_moderation_rule( - self.guild_id, self.id, reason=reason - ) + await self._state.http.delete_auto_moderation_rule(self.guild_id, self.id, reason=reason) async def edit( self, @@ -554,7 +536,5 @@ async def edit( payload["exempt_channels"] = [c.id for c in exempt_channels] if payload: - data = await http.edit_auto_moderation_rule( - self.guild_id, self.id, payload, reason=reason - ) + data = await http.edit_auto_moderation_rule(self.guild_id, self.id, payload, reason=reason) return AutoModRule(state=self._state, data=data) diff --git a/discord/banners.py b/discord/banners.py index e26cf43120..c94a2028de 100644 --- a/discord/banners.py +++ b/discord/banners.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import importlib.resources import logging @@ -14,20 +16,20 @@ import colorlog import colorlog.escape_codes -__all__: Sequence[str] = ('start_logging', 'print_banner') +__all__: Sequence[str] = ("start_logging", "print_banner") day_prefixes: dict[int, str] = { - 1: 'st', - 2: 'nd', - 3: 'rd', - 4: 'th', - 5: 'th', - 6: 'th', - 7: 'th', - 8: 'th', - 9: 'th', - 0: 'th', + 1: "st", + 2: "nd", + 3: "rd", + 4: "th", + 5: "th", + 6: "th", + 7: "th", + 8: "th", + 9: "th", + 0: "th", } @@ -41,7 +43,7 @@ def start_logging(flavor: None | int | str | dict, debug: bool = False): if isinstance(flavor, dict): logging.config.dictConfig(flavor) - if flavor.get('handler'): + if flavor.get("handler"): return flavor = None @@ -52,18 +54,18 @@ def start_logging(flavor: None | int | str | dict, debug: bool = False): colorlog.basicConfig( level=flavor, - format='%(log_color)s%(bold)s%(levelname)-1.1s%(thin)s %(asctime)23.23s %(bold)s%(name)s: ' - '%(thin)s%(message)s%(reset)s', + format="%(log_color)s%(bold)s%(levelname)-1.1s%(thin)s %(asctime)23.23s %(bold)s%(name)s: " + "%(thin)s%(message)s%(reset)s", stream=sys.stderr, log_colors={ - 'DEBUG': 'cyan', - 'INFO': 'green', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'red, bg_white', + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red, bg_white", }, ) - warnings.simplefilter('always', DeprecationWarning) + warnings.simplefilter("always", DeprecationWarning) logging.captureWarnings(True) @@ -73,27 +75,27 @@ def get_day_prefix(num: int) -> str: def print_banner( - bot_name: str = 'A bot', - module: str | None = 'pycord', + bot_name: str = "A bot", + module: str | None = "pycord", ): banners = importlib.resources.files(module) for trav in banners.iterdir(): - if trav.name == 'banner.txt': + if trav.name == "banner.txt": banner = trav.read_text() - elif trav.name == 'ibanner.txt': + elif trav.name == "ibanner.txt": info_banner = trav.read_text() - today = datetime.date.today() + today = datetime.datetime.now().date() args = { # the # prefix only works on Windows, and the - prefix only works on linux/unix systems - 'current_time': today.strftime(f'%B the %#d{get_day_prefix(today.day)} of %Y') - if os.name == 'nt' - else today.strftime(f'%B the %-d{get_day_prefix(today.day)} of %Y'), - 'py_version': platform.python_version(), - 'botname': bot_name, - 'version': __version__ + "current_time": today.strftime(f"%B the %#d{get_day_prefix(today.day)} of %Y") + if os.name == "nt" + else today.strftime(f"%B the %-d{get_day_prefix(today.day)} of %Y"), + "py_version": platform.python_version(), + "botname": bot_name, + "version": __version__, } args |= colorlog.escape_codes.escape_codes diff --git a/discord/bot.py b/discord/bot.py index 0f9b30480c..c83dbe78bf 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -109,9 +109,7 @@ def pending_application_commands(self): @property def commands(self) -> list[ApplicationCommand | Any]: commands = self.application_commands - if self._bot._supports_prefixed_commands and hasattr( - self._bot, "prefixed_commands" - ): + if self._bot._supports_prefixed_commands and hasattr(self._bot, "prefixed_commands"): commands += getattr(self._bot, "prefixed_commands") return commands @@ -139,10 +137,7 @@ def add_application_command(self, command: ApplicationCommand) -> None: command.guild_ids = self._bot.debug_guilds if self._bot.default_command_contexts and command.contexts is None: command.contexts = self._bot.default_command_contexts - if ( - self._bot.default_command_integration_types - and command.integration_types is None - ): + if self._bot.default_command_integration_types and command.integration_types is None: command.integration_types = self._bot.default_command_integration_types for cmd in self.pending_application_commands: @@ -152,9 +147,7 @@ def add_application_command(self, command: ApplicationCommand) -> None: break self._pending_application_commands.append(command) - def remove_application_command( - self, command: ApplicationCommand - ) -> ApplicationCommand | None: + def remove_application_command(self, command: ApplicationCommand) -> ApplicationCommand | None: """Remove an :class:`.ApplicationCommand` from the internal list of commands. @@ -221,9 +214,7 @@ def get_application_command( if guild_ids is not None and command.guild_ids != guild_ids: return return command - elif (names := name.split())[0] == command.name and isinstance( - command, SlashCommandGroup - ): + elif (names := name.split())[0] == command.name and isinstance(command, SlashCommandGroup): while len(names) > 1: command = get(commands, name=names.pop(0)) if not isinstance(command, SlashCommandGroup) or ( @@ -232,9 +223,7 @@ def get_application_command( return commands = command.subcommands command = get(commands, name=names.pop()) - if not isinstance(command, type) or ( - guild_ids is not None and command.guild_ids != guild_ids - ): + if not isinstance(command, type) or (guild_ids is not None and command.guild_ids != guild_ids): return return command @@ -277,11 +266,7 @@ def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool: return True for i, subcommand in enumerate(cmd.subcommands): match_ = next( - ( - data - for data in match["options"] - if data["name"] == subcommand.name - ), + (data for data in match["options"] if data["name"] == subcommand.name), MISSING, ) if match_ is not MISSING and _check_command(subcommand, match_): @@ -313,28 +298,18 @@ def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool: # The API considers False (autocomplete) and [] (choices) to be falsy values falsy_vals = (False, []) for opt in value: - cmd_vals = ( - [val.get(opt, MISSING) for val in as_dict[check]] - if check in as_dict - else [] - ) + cmd_vals = [val.get(opt, MISSING) for val in as_dict[check]] if check in as_dict else [] for i, val in enumerate(cmd_vals): if val in falsy_vals: cmd_vals[i] = MISSING - if match.get( - check, MISSING - ) is not MISSING and cmd_vals != [ + if match.get(check, MISSING) is not MISSING and cmd_vals != [ val.get(opt, MISSING) for val in match[check] ]: # We have a difference return True elif getattr(cmd, check, None) != match.get(check): # We have a difference - if ( - check == "default_permission" - and getattr(cmd, check) is True - and match.get(check) is None - ): + if check == "default_permission" and getattr(cmd, check) is True and match.get(check) is None: # This is a special case # TODO: Remove for perms v2 continue @@ -347,24 +322,16 @@ def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool: if guild_id is None: pending = [cmd for cmd in cmds if cmd.guild_ids is None] else: - pending = [ - cmd - for cmd in cmds - if cmd.guild_ids is not None and guild_id in cmd.guild_ids - ] + pending = [cmd for cmd in cmds if cmd.guild_ids is not None and guild_id in cmd.guild_ids] registered_commands: list[interactions.ApplicationCommand] = [] if prefetched is not None: registered_commands = prefetched elif self._bot.user: if guild_id is None: - registered_commands = await self._bot.http.get_global_commands( - self._bot.user.id - ) + registered_commands = await self._bot.http.get_global_commands(self._bot.user.id) else: - registered_commands = await self._bot.http.get_guild_commands( - self._bot.user.id, guild_id - ) + registered_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id) registered_commands_dict = {cmd["name"]: cmd for cmd in registered_commands} # First let's check if the commands we have locally are the same as the ones on discord @@ -383,9 +350,7 @@ def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool: ) else: # We have this command registered but it's the same - return_value.append( - {"command": cmd, "action": None, "id": int(match["id"])} - ) + return_value.append({"command": cmd, "action": None, "id": int(match["id"])}) # Now let's see if there are any commands on discord that we need to delete for cmd, value_ in registered_commands_dict.items(): @@ -488,12 +453,8 @@ async def register_commands( "edit": self._bot.http.edit_global_command, } - def _register( - method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs - ): - return registration_methods[method]( - self._bot.user and self._bot.user.id, *args, **kwargs - ) + def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs): + return registration_methods[method](self._bot.user and self._bot.user.id, *args, **kwargs) else: pending = list( @@ -509,12 +470,8 @@ def _register( "edit": self._bot.http.edit_guild_command, } - def _register( - method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs - ): - return registration_methods[method]( - self._bot.user and self._bot.user.id, guild_id, *args, **kwargs - ) + def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs): + return registration_methods[method](self._bot.user and self._bot.user.id, guild_id, *args, **kwargs) def register( method: Literal["bulk", "upsert", "delete", "edit"], @@ -525,10 +482,7 @@ def register( ): if kwargs.pop("_log", True): if method == "bulk": - _log.debug( - f"Bulk updating commands {[c['name'] for c in args[0]]} for" - f" guild {guild_id}" - ) + _log.debug(f"Bulk updating commands {[c['name'] for c in args[0]]} for guild {guild_id}") elif method == "upsert": _log.debug(f"Creating command {cmd_name} for guild {guild_id}") # type: ignore elif method == "edit": @@ -543,25 +497,17 @@ def register( prefetched_commands: list[interactions.ApplicationCommand] = [] if self._bot.user: if guild_id is None: - prefetched_commands = await self._bot.http.get_global_commands( - self._bot.user.id - ) + prefetched_commands = await self._bot.http.get_global_commands(self._bot.user.id) else: - prefetched_commands = await self._bot.http.get_guild_commands( - self._bot.user.id, guild_id - ) - desynced = await self.get_desynced_commands( - guild_id=guild_id, prefetched=prefetched_commands - ) + prefetched_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id) + desynced = await self.get_desynced_commands(guild_id=guild_id, prefetched=prefetched_commands) for cmd in desynced: if cmd["action"] == "delete": pending_actions.append( { "action": "delete" if delete_existing else None, - "command": collections.namedtuple("Command", ["name"])( - name=cmd["command"] - ), + "command": collections.namedtuple("Command", ["name"])(name=cmd["command"]), "id": cmd["id"], } ) @@ -594,15 +540,9 @@ def register( ) else: raise ValueError(f"Unknown action: {cmd['action']}") - filtered_no_action = list( - filter(lambda c: c["action"] is not None, pending_actions) - ) - filtered_deleted = list( - filter(lambda a: a["action"] != "delete", pending_actions) - ) - if method == "bulk" or ( - method == "auto" and len(filtered_deleted) == len(pending) - ): + filtered_no_action = list(filter(lambda c: c["action"] is not None, pending_actions)) + filtered_deleted = list(filter(lambda a: a["action"] != "delete", pending_actions)) + if method == "bulk" or (method == "auto" and len(filtered_deleted) == len(pending)): # Either the method is bulk or all the commands need to be modified, so we can just do a bulk upsert data = [cmd["command"].to_dict() for cmd in filtered_deleted] # If there's nothing to update, don't bother @@ -654,13 +594,9 @@ def register( if method != "bulk": if self._bot.user: if guild_id is None: - registered = await self._bot.http.get_global_commands( - self._bot.user.id - ) + registered = await self._bot.http.get_global_commands(self._bot.user.id) else: - registered = await self._bot.http.get_guild_commands( - self._bot.user.id, guild_id - ) + registered = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id) else: data = [cmd.to_dict() for cmd in pending] registered = await register("bulk", data, guild_id=guild_id) @@ -672,10 +608,7 @@ def register( type=i.get("type"), ) if not cmd: - raise ValueError( - f"Registered command {i['name']}, type {i.get('type')} not found in" - " pending commands" - ) + raise ValueError(f"Registered command {i['name']}, type {i.get('type')} not found in pending commands") cmd.id = i["id"] self._application_commands[cmd.id] = cmd @@ -765,11 +698,7 @@ async def on_connect(): if check_guilds is not None: cmd_guild_ids.extend(check_guilds) for guild_id in set(cmd_guild_ids): - guild_commands = [ - cmd - for cmd in commands - if cmd.guild_ids is not None and guild_id in cmd.guild_ids - ] + guild_commands = [cmd for cmd in commands if cmd.guild_ids is not None and guild_id in cmd.guild_ids] app_cmds = await self.register_commands( guild_commands, guild_id=guild_id, @@ -807,9 +736,7 @@ async def on_connect(): cmd.id = i["id"] self._application_commands[cmd.id] = cmd - async def process_application_commands( - self, interaction: Interaction, auto_sync: bool | None = None - ) -> None: + async def process_application_commands(self, interaction: Interaction, auto_sync: bool | None = None) -> None: """|coro| This function processes the commands that have been registered @@ -855,11 +782,7 @@ async def process_application_commands( if guild_id: guild_id = int(guild_id) if cmd.name == interaction.data["name"] and ( # type: ignore - guild_id == cmd.guild_ids - or ( - isinstance(cmd.guild_ids, list) - and guild_id in cmd.guild_ids - ) + guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids) ): command = cmd break @@ -873,18 +796,14 @@ async def process_application_commands( return self._bot.dispatch("unknown_application_command", interaction) if interaction.type is InteractionType.auto_complete: - return self._bot.dispatch( - "application_command_auto_complete", interaction, command - ) + return self._bot.dispatch("application_command_auto_complete", interaction, command) ctx = await self.get_application_context(interaction) if command: ctx.command = command await self.invoke_application_command(ctx) - async def on_application_command_auto_complete( - self, interaction: Interaction, command: ApplicationCommand - ) -> None: + async def on_application_command_auto_complete(self, interaction: Interaction, command: ApplicationCommand) -> None: async def callback() -> None: ctx = await self.get_autocomplete_context(interaction) ctx.command = command @@ -1195,26 +1114,18 @@ def __init__(self, description=None, *args, **options): if self.owner_id and self.owner_ids: raise TypeError("Both owner_id and owner_ids are set.") - if self.owner_ids and not isinstance( - self.owner_ids, collections.abc.Collection - ): - raise TypeError( - f"owner_ids must be a collection not {self.owner_ids.__class__!r}" - ) + if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): + raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}") if not isinstance(self.default_command_contexts, collections.abc.Collection): raise TypeError( f"default_command_contexts must be a collection not {self.default_command_contexts.__class__!r}" ) - if not isinstance( - self.default_command_integration_types, collections.abc.Collection - ): + if not isinstance(self.default_command_integration_types, collections.abc.Collection): raise TypeError( f"default_command_integration_types must be a collection not {self.default_command_integration_types.__class__!r}" ) self.default_command_contexts = set(self.default_command_contexts) - self.default_command_integration_types = set( - self.default_command_integration_types - ) + self.default_command_integration_types = set(self.default_command_integration_types) self._checks = [] self._check_once = [] @@ -1228,9 +1139,7 @@ async def on_connect(self): async def on_interaction(self, interaction): await self.process_application_commands(interaction) - async def on_application_command_error( - self, context: ApplicationContext, exception: DiscordException - ) -> None: + async def on_application_command_error(self, context: ApplicationContext, exception: DiscordException) -> None: """|coro| The default command error handler provided by the bot. @@ -1251,9 +1160,7 @@ async def on_application_command_error( return print(f"Ignoring exception in command {context.command}:", file=sys.stderr) - traceback.print_exception( - type(exception), exception, exception.__traceback__, file=sys.stderr - ) + traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) # global check registration # TODO: Remove these from commands.Bot @@ -1346,9 +1253,7 @@ def whitelist(ctx): self.add_check(func, call_once=True) return func - async def can_run( - self, ctx: ApplicationContext, *, call_once: bool = False - ) -> bool: + async def can_run(self, ctx: ApplicationContext, *, call_once: bool = False) -> bool: data = self._check_once if call_once else self._checks if not data: diff --git a/discord/channel.py b/discord/channel.py index 8fef92eaca..3470acc981 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -131,9 +131,7 @@ class ForumTag(Hashable): __slots__ = ("name", "id", "moderated", "emoji") - def __init__( - self, *, name: str, emoji: EmojiInputType, moderated: bool = False - ) -> None: + def __init__(self, *, name: str, emoji: EmojiInputType, moderated: bool = False) -> None: self.name: str = name self.id: int = 0 self.moderated: bool = moderated @@ -143,16 +141,10 @@ def __init__( elif isinstance(emoji, str): self.emoji = PartialEmoji.from_str(emoji) else: - raise TypeError( - "emoji must be a GuildEmoji, PartialEmoji, or str and not" - f" {emoji.__class__!r}" - ) + raise TypeError(f"emoji must be a GuildEmoji, PartialEmoji, or str and not {emoji.__class__!r}") def __repr__(self) -> str: - return ( - "" - ) + return f"" def __str__(self) -> str: return self.name @@ -223,9 +215,7 @@ def __repr__(self) -> str: joined = " ".join("%s=%r" % t for t in attrs) return f"<{self.__class__.__name__} {joined}>" - def _update( - self, guild: Guild, data: TextChannelPayload | ForumChannelPayload - ) -> None: + def _update(self, guild: Guild, data: TextChannelPayload | ForumChannelPayload) -> None: # This data will always exist self.guild: Guild = guild self.name: str = data["name"] @@ -238,15 +228,9 @@ def _update( self.nsfw: bool = data.get("nsfw", False) # Does this need coercion into `int`? No idea yet. self.slowmode_delay: int = data.get("rate_limit_per_user", 0) - self.default_auto_archive_duration: ThreadArchiveDuration = data.get( - "default_auto_archive_duration", 1440 - ) - self.default_thread_slowmode_delay: int | None = data.get( - "default_thread_rate_limit_per_user" - ) - self.last_message_id: int | None = utils._get_as_snowflake( - data, "last_message_id" - ) + self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440) + self.default_thread_slowmode_delay: int | None = data.get("default_thread_rate_limit_per_user") + self.last_message_id: int | None = utils._get_as_snowflake(data, "last_message_id") self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) self._fill_overwrites(data) @@ -279,11 +263,7 @@ def threads(self) -> list[Thread]: .. versionadded:: 2.0 """ - return [ - thread - for thread in self.guild._threads.values() - if thread.parent_id == self.id - ] + return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id] def is_nsfw(self) -> bool: """Checks if the channel is NSFW.""" @@ -308,20 +288,14 @@ def last_message(self) -> Message | None: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return ( - self._state._get_message(self.last_message_id) - if self.last_message_id - else None - ) + return self._state._get_message(self.last_message_id) if self.last_message_id else None async def edit(self, **options) -> _TextChannel: """Edits the channel.""" raise NotImplementedError @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: str | None = None, reason: str | None = None - ) -> TextChannel: + async def clone(self, *, name: str | None = None, reason: str | None = None) -> TextChannel: return await self._clone_impl( { "topic": self.topic, @@ -332,9 +306,7 @@ async def clone( reason=reason, ) - async def delete_messages( - self, messages: Iterable[Snowflake], *, reason: str | None = None - ) -> None: + async def delete_messages(self, messages: Iterable[Snowflake], *, reason: str | None = None) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -451,8 +423,9 @@ async def purge( def is_me(m): return m.author == client.user + deleted = await channel.purge(limit=100, check=is_me) - await channel.send(f'Deleted {len(deleted)} message(s)') + await channel.send(f"Deleted {len(deleted)} message(s)") """ return await discord.abc._purge_messages_helper( self, @@ -489,9 +462,7 @@ async def webhooks(self) -> list[Webhook]: data = await self._state.http.channel_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def create_webhook( - self, *, name: str, avatar: bytes | None = None, reason: str | None = None - ) -> Webhook: + async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason: str | None = None) -> Webhook: """|coro| Creates a webhook for this channel. @@ -529,14 +500,10 @@ async def create_webhook( if avatar is not None: avatar = utils._bytes_to_base64_data(avatar) # type: ignore - data = await self._state.http.create_webhook( - self.id, name=str(name), avatar=avatar, reason=reason - ) + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) - async def follow( - self, *, destination: TextChannel, reason: str | None = None - ) -> Webhook: + async def follow(self, *, destination: TextChannel, reason: str | None = None) -> Webhook: """ Follows a channel using a webhook. @@ -575,15 +542,11 @@ async def follow( raise ClientException("The channel must be a news channel.") if not isinstance(destination, TextChannel): - raise InvalidArgument( - f"Expected TextChannel received {destination.__class__.__name__}" - ) + raise InvalidArgument(f"Expected TextChannel received {destination.__class__.__name__}") from .webhook import Webhook - data = await self._state.http.follow_webhook( - self.id, webhook_channel_id=destination.id, reason=reason - ) + data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason) return Webhook._as_follower(data, channel=destination, user=self._state.user) def get_partial_message(self, message_id: int, /) -> PartialMessage: @@ -741,9 +704,7 @@ class TextChannel(discord.abc.Messageable, _TextChannel): .. versionadded:: 2.3 """ - def __init__( - self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload - ): + def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): super().__init__(state=state, guild=guild, data=data) @property @@ -923,8 +884,7 @@ async def create_thread( data = await self._state.http.start_thread_without_message( self.id, name=name, - auto_archive_duration=auto_archive_duration - or self.default_auto_archive_duration, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, type=type.value, rate_limit_per_user=slowmode_delay or 0, invitable=invitable, @@ -935,8 +895,7 @@ async def create_thread( self.id, message.id, name=name, - auto_archive_duration=auto_archive_duration - or self.default_auto_archive_duration, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, rate_limit_per_user=slowmode_delay or 0, reason=reason, ) @@ -1024,16 +983,13 @@ class ForumChannel(_TextChannel): .. versionadded:: 2.5 """ - def __init__( - self, *, state: ConnectionState, guild: Guild, data: ForumChannelPayload - ): + def __init__(self, *, state: ConnectionState, guild: Guild, data: ForumChannelPayload): super().__init__(state=state, guild=guild, data=data) def _update(self, guild: Guild, data: ForumChannelPayload) -> None: super()._update(guild, data) self.available_tags: list[ForumTag] = [ - ForumTag.from_data(state=self._state, data=tag) - for tag in (data.get("available_tags") or []) + ForumTag.from_data(state=self._state, data=tag) for tag in (data.get("available_tags") or []) ] self.default_sort_order: SortOrder | None = data.get("default_sort_order", None) if self.default_sort_order is not None: @@ -1266,27 +1222,21 @@ async def create_thread( message_content = str(content) if content is not None else None if embed is not None and embeds is not None: - raise InvalidArgument( - "cannot pass both embed and embeds parameter to create_thread()" - ) + raise InvalidArgument("cannot pass both embed and embeds parameter to create_thread()") if embed is not None: embed = embed.to_dict() elif embeds is not None: if len(embeds) > 10: - raise InvalidArgument( - "embeds parameter must be a list of up to 10 elements" - ) + raise InvalidArgument("embeds parameter must be a list of up to 10 elements") embeds = [embed.to_dict() for embed in embeds] if stickers is not None: stickers = [sticker.id for sticker in stickers] if allowed_mentions is None: - allowed_mentions = ( - state.allowed_mentions and state.allowed_mentions.to_dict() - ) + allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() elif state.allowed_mentions is not None: allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() else: @@ -1294,9 +1244,7 @@ async def create_thread( if view: if not hasattr(view, "__discord_ui_view__"): - raise InvalidArgument( - f"view parameter must be View not {view.__class__!r}" - ) + raise InvalidArgument(f"view parameter must be View not {view.__class__!r}") components = view.to_components() else: @@ -1310,9 +1258,7 @@ async def create_thread( if files is not None: if len(files) > 10: - raise InvalidArgument( - "files parameter must be a list of up to 10 elements" - ) + raise InvalidArgument("files parameter must be a list of up to 10 elements") elif not all(isinstance(file, File) for file in files): raise InvalidArgument("files parameter must be a list of File") @@ -1333,8 +1279,7 @@ async def create_thread( allowed_mentions=allowed_mentions, stickers=stickers, components=components, - auto_archive_duration=auto_archive_duration - or self.default_auto_archive_duration, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, rate_limit_per_user=slowmode_delay or self.slowmode_delay, applied_tags=applied_tags, reason=reason, @@ -1580,9 +1525,7 @@ def _get_voice_client_key(self) -> tuple[int, str]: def _get_voice_state_pair(self) -> tuple[int, int]: return self.guild.id, self.id - def _update( - self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload - ) -> None: + def _update(self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload) -> None: # This data will always exist self.guild = guild self.name: str = data["name"] @@ -1591,15 +1534,9 @@ def _update( # This data may be missing depending on how this object is being created/updated if not data.pop("_invoke_flag", False): rtc = data.get("rtc_region") - self.rtc_region: VoiceRegion | None = ( - try_enum(VoiceRegion, rtc) if rtc is not None else None - ) - self.video_quality_mode: VideoQualityMode = try_enum( - VideoQualityMode, data.get("video_quality_mode", 1) - ) - self.last_message_id: int | None = utils._get_as_snowflake( - data, "last_message_id" - ) + self.rtc_region: VoiceRegion | None = try_enum(VoiceRegion, rtc) if rtc is not None else None + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) + self.last_message_id: int | None = utils._get_as_snowflake(data, "last_message_id") self.position: int = data.get("position") self.slowmode_delay = data.get("rate_limit_per_user", 0) self.bitrate: int = data.get("bitrate") @@ -1782,11 +1719,7 @@ def last_message(self) -> Message | None: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return ( - self._state._get_message(self.last_message_id) - if self.last_message_id - else None - ) + return self._state._get_message(self.last_message_id) if self.last_message_id else None def get_partial_message(self, message_id: int, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the message ID. @@ -1811,9 +1744,7 @@ def get_partial_message(self, message_id: int, /) -> PartialMessage: return PartialMessage(channel=self, id=message_id) - async def delete_messages( - self, messages: Iterable[Snowflake], *, reason: str | None = None - ) -> None: + async def delete_messages(self, messages: Iterable[Snowflake], *, reason: str | None = None) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -1930,8 +1861,9 @@ async def purge( def is_me(m): return m.author == client.user + deleted = await channel.purge(limit=100, check=is_me) - await channel.send(f'Deleted {len(deleted)} message(s)') + await channel.send(f"Deleted {len(deleted)} message(s)") """ return await discord.abc._purge_messages_helper( self, @@ -1968,9 +1900,7 @@ async def webhooks(self) -> list[Webhook]: data = await self._state.http.channel_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def create_webhook( - self, *, name: str, avatar: bytes | None = None, reason: str | None = None - ) -> Webhook: + async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason: str | None = None) -> Webhook: """|coro| Creates a webhook for this channel. @@ -2008,9 +1938,7 @@ async def create_webhook( if avatar is not None: avatar = utils._bytes_to_base64_data(avatar) # type: ignore - data = await self._state.http.create_webhook( - self.id, name=str(name), avatar=avatar, reason=reason - ) + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @property @@ -2019,9 +1947,7 @@ def type(self) -> ChannelType: return ChannelType.voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: str | None = None, reason: str | None = None - ) -> VoiceChannel: + async def clone(self, *, name: str | None = None, reason: str | None = None) -> VoiceChannel: return await self._clone_impl( {"bitrate": self.bitrate, "user_limit": self.user_limit}, name=name, @@ -2113,9 +2039,7 @@ async def edit(self, *, reason=None, **options): # the payload will always be the proper channel payload return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore - async def create_activity_invite( - self, activity: EmbeddedActivity | int, **kwargs - ) -> Invite: + async def create_activity_invite(self, activity: EmbeddedActivity | int, **kwargs) -> Invite: """|coro| A shortcut method that creates an instant activity invite. @@ -2168,9 +2092,7 @@ async def create_activity_invite( **kwargs, ) - async def set_status( - self, status: str | None, *, reason: str | None = None - ) -> None: + async def set_status(self, status: str | None, *, reason: str | None = None) -> None: """|coro| Sets the status of the voice channel. @@ -2276,11 +2198,7 @@ def __repr__(self) -> str: @property def requesting_to_speak(self) -> list[Member]: """A list of members who are requesting to speak in the stage channel.""" - return [ - member - for member in self.members - if member.voice and member.voice.requested_to_speak_at is not None - ] + return [member for member in self.members if member.voice and member.voice.requested_to_speak_at is not None] @property def speakers(self) -> list[Member]: @@ -2291,9 +2209,7 @@ def speakers(self) -> list[Member]: return [ member for member in self.members - if member.voice - and not member.voice.suppress - and member.voice.requested_to_speak_at is None + if member.voice and not member.voice.suppress and member.voice.requested_to_speak_at is None ] @property @@ -2302,9 +2218,7 @@ def listeners(self) -> list[Member]: .. versionadded:: 2.0 """ - return [ - member for member in self.members if member.voice and member.voice.suppress - ] + return [member for member in self.members if member.voice and member.voice.suppress] async def _get_channel(self): return self @@ -2332,11 +2246,7 @@ def last_message(self) -> Message | None: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return ( - self._state._get_message(self.last_message_id) - if self.last_message_id - else None - ) + return self._state._get_message(self.last_message_id) if self.last_message_id else None def get_partial_message(self, message_id: int, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the message ID. @@ -2361,9 +2271,7 @@ def get_partial_message(self, message_id: int, /) -> PartialMessage: return PartialMessage(channel=self, id=message_id) - async def delete_messages( - self, messages: Iterable[Snowflake], *, reason: str | None = None - ) -> None: + async def delete_messages(self, messages: Iterable[Snowflake], *, reason: str | None = None) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -2480,8 +2388,9 @@ async def purge( def is_me(m): return m.author == client.user + deleted = await channel.purge(limit=100, check=is_me) - await channel.send(f'Deleted {len(deleted)} message(s)') + await channel.send(f"Deleted {len(deleted)} message(s)") """ return await discord.abc._purge_messages_helper( self, @@ -2518,9 +2427,7 @@ async def webhooks(self) -> list[Webhook]: data = await self._state.http.channel_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def create_webhook( - self, *, name: str, avatar: bytes | None = None, reason: str | None = None - ) -> Webhook: + async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason: str | None = None) -> Webhook: """|coro| Creates a webhook for this channel. @@ -2558,9 +2465,7 @@ async def create_webhook( if avatar is not None: avatar = utils._bytes_to_base64_data(avatar) # type: ignore - data = await self._state.http.create_webhook( - self.id, name=str(name), avatar=avatar, reason=reason - ) + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @property @@ -2570,11 +2475,7 @@ def moderators(self) -> list[Member]: .. versionadded:: 2.0 """ required_permissions = Permissions.stage_moderator() - return [ - member - for member in self.members - if self.permissions_for(member) >= required_permissions - ] + return [member for member in self.members if self.permissions_for(member) >= required_permissions] @property def type(self) -> ChannelType: @@ -2582,9 +2483,7 @@ def type(self) -> ChannelType: return ChannelType.stage_voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: str | None = None, reason: str | None = None - ) -> StageChannel: + async def clone(self, *, name: str | None = None, reason: str | None = None) -> StageChannel: return await self._clone_impl({}, name=name, reason=reason) @property @@ -2647,9 +2546,7 @@ async def create_instance( if privacy_level is not MISSING: if not isinstance(privacy_level, StagePrivacyLevel): - raise InvalidArgument( - "privacy_level field must be of type PrivacyLevel" - ) + raise InvalidArgument("privacy_level field must be of type PrivacyLevel") payload["privacy_level"] = privacy_level.value @@ -2814,18 +2711,13 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): "flags", ) - def __init__( - self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload - ): + def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): self._state: ConnectionState = state self.id: int = int(data["id"]) self._update(guild, data) def __repr__(self) -> str: - return ( - "" - ) + return f"" def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: # This data will always exist @@ -2854,9 +2746,7 @@ def is_nsfw(self) -> bool: return self.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: str | None = None, reason: str | None = None - ) -> CategoryChannel: + async def clone(self, *, name: str | None = None, reason: str | None = None) -> CategoryChannel: return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @overload @@ -2943,22 +2833,14 @@ def comparator(channel): @property def text_channels(self) -> list[TextChannel]: """Returns the text channels that are under this category.""" - ret = [ - c - for c in self.guild.channels - if c.category_id == self.id and isinstance(c, TextChannel) - ] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)] ret.sort(key=lambda c: (c.position or -1, c.id)) return ret @property def voice_channels(self) -> list[VoiceChannel]: """Returns the voice channels that are under this category.""" - ret = [ - c - for c in self.guild.channels - if c.category_id == self.id and isinstance(c, VoiceChannel) - ] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)] ret.sort(key=lambda c: (c.position or -1, c.id)) return ret @@ -2968,11 +2850,7 @@ def stage_channels(self) -> list[StageChannel]: .. versionadded:: 1.7 """ - ret = [ - c - for c in self.guild.channels - if c.category_id == self.id and isinstance(c, StageChannel) - ] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)] ret.sort(key=lambda c: (c.position or -1, c.id)) return ret @@ -2982,11 +2860,7 @@ def forum_channels(self) -> list[ForumChannel]: .. versionadded:: 2.0 """ - ret = [ - c - for c in self.guild.channels - if c.category_id == self.id and isinstance(c, ForumChannel) - ] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, ForumChannel)] ret.sort(key=lambda c: (c.position or -1, c.id)) return ret @@ -3081,9 +2955,7 @@ class DMChannel(discord.abc.Messageable, Hashable): __slots__ = ("id", "recipient", "me", "_state") - def __init__( - self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload - ): + def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): self._state: ConnectionState = state self.recipient: User | None = None if r := data.get("recipients"): @@ -3234,9 +3106,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): "_state", ) - def __init__( - self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload - ): + def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload): self._state: ConnectionState = state self.id: int = int(data["id"]) self.me: ClientUser = me @@ -3246,9 +3116,7 @@ def _update_group(self, data: GroupChannelPayload) -> None: self.owner_id: int | None = utils._get_as_snowflake(data, "owner_id") self._icon: str | None = data.get("icon") self.name: str | None = data.get("name") - self.recipients: list[User] = [ - self._state.store_user(u) for u in data.get("recipients", []) - ] + self.recipients: list[User] = [self._state.store_user(u) for u in data.get("recipients", [])] self.owner: BaseUser | None if self.owner_id == self.me.id: @@ -3380,9 +3248,7 @@ class PartialMessageable(discord.abc.Messageable, Hashable): The channel type associated with this partial messageable, if given. """ - def __init__( - self, state: ConnectionState, id: int, type: ChannelType | None = None - ): + def __init__(self, state: ConnectionState, id: int, type: ChannelType | None = None): self._state: ConnectionState = state self._channel: Object = Object(id=id) self.id: int = id diff --git a/discord/client.py b/discord/client.py index b21f54dc30..4def3449d6 100644 --- a/discord/client.py +++ b/discord/client.py @@ -231,12 +231,8 @@ def __init__( self._banner_module = options.get("banner_module") # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) - self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = ( - {} - ) + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop + self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = {} self.shard_id: int | None = options.get("shard_id") self.shard_count: int | None = options.get("shard_count") @@ -254,9 +250,7 @@ def __init__( self._handlers: dict[str, Callable] = {"ready": self._handle_ready} - self._hooks: dict[str, Callable] = { - "before_identify": self._call_before_identify_hook - } + self._hooks: dict[str, Callable] = {"before_identify": self._call_before_identify_hook} self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) @@ -295,9 +289,7 @@ async def __aexit__( # internals - def _get_websocket( - self, guild_id: int | None = None, *, shard_id: int | None = None - ) -> DiscordWebSocket: + def _get_websocket(self, guild_id: int | None = None, *, shard_id: int | None = None) -> DiscordWebSocket: return self.ws def _get_state(self, **options: Any) -> ConnectionState: @@ -548,17 +540,13 @@ async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: # hooks - async def _call_before_identify_hook( - self, shard_id: int | None, *, initial: bool = False - ) -> None: + async def _call_before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None: # This hook is an internal hook that actually calls the public one. # It allows the library to have its own hook without stepping on the # toes of those who need to override their own hook. await self.before_identify_hook(shard_id, initial=initial) - async def before_identify_hook( - self, shard_id: int | None, *, initial: bool = False - ) -> None: + async def before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None: """|coro| A hook that is called before IDENTIFYing a session. This is useful @@ -605,16 +593,14 @@ async def login(self, token: str) -> None: passing status code. """ if not isinstance(token, str): - raise TypeError( - f"token must be of type str, not {token.__class__.__name__}" - ) + raise TypeError(f"token must be of type str, not {token.__class__.__name__}") _log.info("logging in using static token") data = await self.http.static_login(token.strip()) self._connection.user = ClientUser(state=self._connection, data=data) - print_banner(bot_name=self._connection.user.display_name, module=self._banner_module or 'discord') + print_banner(bot_name=self._connection.user.display_name, module=self._banner_module or "discord") start_logging(self._flavor, debug=self._debug) async def connect(self, *, reconnect: bool = True) -> None: @@ -711,9 +697,7 @@ async def connect(self, *, reconnect: bool = True) -> None: # This is apparently what the official Discord client does. if self.ws is None: continue - ws_params.update( - sequence=self.ws.sequence, resume=True, session=self.ws.session_id - ) + ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) async def close(self) -> None: """|coro| @@ -881,9 +865,7 @@ def allowed_mentions(self, value: AllowedMentions | None) -> None: if value is None or isinstance(value, AllowedMentions): self._connection.allowed_mentions = value else: - raise TypeError( - f"allowed_mentions must be AllowedMentions not {value.__class__!r}" - ) + raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}") @property def intents(self) -> Intents: @@ -957,9 +939,7 @@ def get_message(self, id: int, /) -> Message | None: """ return self._connection._get_message(id) - def get_partial_messageable( - self, id: int, *, type: ChannelType | None = None - ) -> PartialMessageable: + def get_partial_messageable(self, id: int, *, type: ChannelType | None = None) -> PartialMessageable: """Returns a partial messageable with the given channel ID. This is useful if you have a channel_id but don't want to do an API call @@ -1206,33 +1186,33 @@ def wait_for( @client.event async def on_message(message): - if message.content.startswith('$greet'): + if message.content.startswith("$greet"): channel = message.channel - await channel.send('Say hello!') + await channel.send("Say hello!") def check(m): - return m.content == 'hello' and m.channel == channel + return m.content == "hello" and m.channel == channel - msg = await client.wait_for('message', check=check) - await channel.send(f'Hello {msg.author}!') + msg = await client.wait_for("message", check=check) + await channel.send(f"Hello {msg.author}!") Waiting for a thumbs up reaction from the message author: :: @client.event async def on_message(message): - if message.content.startswith('$thumb'): + if message.content.startswith("$thumb"): channel = message.channel - await channel.send('Send me that \N{THUMBS UP SIGN} reaction, mate') + await channel.send("Send me that \N{THUMBS UP SIGN} reaction, mate") def check(reaction, user): - return user == message.author and str(reaction.emoji) == '\N{THUMBS UP SIGN}' + return user == message.author and str(reaction.emoji) == "\N{THUMBS UP SIGN}" try: - reaction, user = await client.wait_for('reaction_add', timeout=60.0, check=check) + reaction, user = await client.wait_for("reaction_add", timeout=60.0, check=check) except asyncio.TimeoutError: - await channel.send('\N{THUMBS DOWN SIGN}') + await channel.send("\N{THUMBS DOWN SIGN}") else: - await channel.send('\N{THUMBS UP SIGN}') + await channel.send("\N{THUMBS UP SIGN}") """ future = self.loop.create_future() @@ -1276,11 +1256,16 @@ def add_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> Non .. code-block:: python3 - async def on_ready(): pass - async def my_message(message): pass + async def on_ready(): + pass + + + async def my_message(message): + pass + client.add_listener(on_ready) - client.add_listener(my_message, 'on_message') + client.add_listener(my_message, "on_message") """ name = func.__name__ if name is MISSING else name @@ -1342,18 +1327,21 @@ def listen(self, name: str | utils.Undefined = MISSING, once: bool = False) -> C @client.listen() async def on_message(message): - print('one') + print("one") + # in some other file... - @client.listen('on_message') + + @client.listen("on_message") async def my_message(message): - print('two') + print("two") + # listen to the first event only - @client.listen('on_ready', once=True) + @client.listen("on_ready", once=True) async def on_ready(): - print('ready!') + print("ready!") Would print one and two in an unspecified order. """ @@ -1399,7 +1387,7 @@ def event(self, coro: Coro) -> Coro: @client.event async def on_ready(): - print('Ready!') + print("Ready!") """ if not asyncio.iscoroutinefunction(coro): @@ -1530,9 +1518,7 @@ def fetch_guilds( All parameters are optional. """ - return GuildIterator( - self, limit=limit, before=before, after=after, with_counts=with_counts - ) + return GuildIterator(self, limit=limit, before=before, after=after, with_counts=with_counts) async def fetch_template(self, code: Template | str) -> Template: """|coro| @@ -1851,9 +1837,7 @@ async def fetch_user(self, user_id: int, /) -> User: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_channel( - self, channel_id: int, / - ) -> GuildChannel | PrivateChannel | Thread: + async def fetch_channel(self, channel_id: int, /) -> GuildChannel | PrivateChannel | Thread: """|coro| Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. @@ -1884,9 +1868,7 @@ async def fetch_channel( factory, ch_type = _threaded_channel_factory(data["type"]) if factory is None: - raise InvalidData( - "Unknown channel type {type} for channel ID {id}.".format_map(data) - ) + raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): # the factory will be a DMChannel or GroupChannel here @@ -1960,10 +1942,7 @@ async def fetch_premium_sticker_packs(self) -> list[StickerPack]: Retrieving the sticker packs failed. """ data = await self.http.list_premium_sticker_packs() - return [ - StickerPack(state=self._connection, data=pack) - for pack in data["sticker_packs"] - ] + return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]] async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| @@ -2023,10 +2002,7 @@ def add_view(self, view: View, *, message_id: int | None = None) -> None: raise TypeError(f"expected an instance of View not {view.__class__!r}") if not view.is_persistent(): - raise ValueError( - "View is not persistent. Items need to have a custom_id set and View" - " must have no timeout" - ) + raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout") self._connection.store_view(view, message_id) @@ -2052,9 +2028,7 @@ async def fetch_role_connection_metadata_records( List[:class:`.ApplicationRoleConnectionMetadata`] The bot's role connection metadata records. """ - data = await self._connection.http.get_application_role_connection_metadata_records( - self.application_id - ) + data = await self._connection.http.get_application_role_connection_metadata_records(self.application_id) return [ApplicationRoleConnectionMetadata.from_dict(r) for r in data] async def update_role_connection_metadata_records( @@ -2193,13 +2167,8 @@ async def fetch_emojis(self) -> list[AppEmoji]: List[:class:`AppEmoji`] The retrieved emojis. """ - data = await self._connection.http.get_all_application_emojis( - self.application_id - ) - return [ - self._connection.maybe_store_app_emoji(self.application_id, d) - for d in data["items"] - ] + data = await self._connection.http.get_all_application_emojis(self.application_id) + return [self._connection.maybe_store_app_emoji(self.application_id, d) for d in data["items"]] async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji: """|coro| @@ -2223,9 +2192,7 @@ async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji: HTTPException An error occurred fetching the emoji. """ - data = await self._connection.http.get_application_emoji( - self.application_id, emoji_id - ) + data = await self._connection.http.get_application_emoji(self.application_id, emoji_id) return self._connection.maybe_store_app_emoji(self.application_id, data) async def create_emoji( @@ -2260,9 +2227,7 @@ async def create_emoji( """ img = utils._bytes_to_base64_data(image) - data = await self._connection.http.create_application_emoji( - self.application_id, name, img - ) + data = await self._connection.http.create_application_emoji(self.application_id, name, img) return self._connection.maybe_store_app_emoji(self.application_id, data) async def delete_emoji(self, emoji: Snowflake) -> None: @@ -2281,8 +2246,6 @@ async def delete_emoji(self, emoji: Snowflake) -> None: An error occurred deleting the emoji. """ - await self._connection.http.delete_application_emoji( - self.application_id, emoji.id - ) + await self._connection.http.delete_application_emoji(self.application_id, emoji.id) if self._connection.cache_app_emojis and self._connection.get_emoji(emoji.id): self._connection.remove_emoji(emoji) diff --git a/discord/cog.py b/discord/cog.py index b402e6f7c3..2200237501 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -72,12 +72,15 @@ class CogMeta(type): import abc + class CogABCMeta(discord.CogMeta, abc.ABCMeta): pass + class SomeMixin(metaclass=abc.ABCMeta): pass + class SomeCogMixin(SomeMixin, discord.Cog, metaclass=CogABCMeta): pass @@ -89,7 +92,7 @@ class SomeCogMixin(SomeMixin, discord.Cog, metaclass=CogABCMeta): .. code-block:: python3 - class MyCog(discord.Cog, name='My Cog'): + class MyCog(discord.Cog, name="My Cog"): pass Attributes @@ -112,11 +115,11 @@ class MyCog(discord.Cog, name='My Cog'): class MyCog(discord.Cog, command_attrs=dict(hidden=True)): @discord.slash_command() async def foo(self, ctx): - pass # hidden -> True + pass # hidden -> True @discord.slash_command(hidden=False) async def bar(self, ctx): - pass # hidden -> False + pass # hidden -> False guild_ids: Optional[List[:class:`int`]] A shortcut to :attr:`.command_attrs`, what ``guild_ids`` should all application commands have @@ -144,10 +147,7 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: commands = {} listeners = {} - no_bot_cog = ( - "Commands or listeners must not start with cog_ or bot_ (in method" - " {0.__name__}.{1})" - ) + no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})" new_cls = super().__new__(cls, name, bases, attrs, **kwargs) @@ -158,9 +158,7 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: if elem in listeners: del listeners[elem] - if getattr(value, "parent", None) and isinstance( - value, ApplicationCommand - ): + if getattr(value, "parent", None) and isinstance(value, ApplicationCommand): # Skip commands if they are a part of a group continue @@ -169,10 +167,7 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: value = value.__func__ if isinstance(value, _BaseCommand): if is_static_method: - raise TypeError( - f"Command in method {base}.{elem!r} must not be" - " staticmethod." - ) + raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.") if elem.startswith(("cog_", "bot_")): raise TypeError(no_bot_cog.format(base, elem)) commands[elem] = value @@ -180,10 +175,7 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: # a test to see if this value is a BridgeCommand if hasattr(value, "add_to") and not getattr(value, "parent", None): if is_static_method: - raise TypeError( - f"Command in method {base}.{elem!r} must not be" - " staticmethod." - ) + raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.") if elem.startswith(("cog_", "bot_")): raise TypeError(no_bot_cog.format(base, elem)) @@ -191,9 +183,7 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: commands[f"app_{elem}"] = value.slash_variant commands[elem] = value for cmd in getattr(value, "subcommands", []): - commands[f"ext_{cmd.ext_variant.qualified_name}"] = ( - cmd.ext_variant - ) + commands[f"ext_{cmd.ext_variant.qualified_name}"] = cmd.ext_variant if inspect.iscoroutinefunction(value): try: @@ -220,31 +210,22 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: # Either update the command with the cog provided defaults or copy it. # r.e type ignore, type-checker complains about overriding a ClassVar - new_cls.__cog_commands__ = tuple(c._update_copy(cmd_attrs) if not hasattr(c, "add_to") else c for c in new_cls.__cog_commands__) # type: ignore + new_cls.__cog_commands__ = tuple( + c._update_copy(cmd_attrs) if not hasattr(c, "add_to") else c for c in new_cls.__cog_commands__ + ) # type: ignore name_filter = lambda c: ( - "app" - if isinstance(c, ApplicationCommand) - else ("bridge" if not hasattr(c, "add_to") else "ext") + "app" if isinstance(c, ApplicationCommand) else ("bridge" if not hasattr(c, "add_to") else "ext") ) - lookup = { - f"{name_filter(cmd)}_{cmd.qualified_name}": cmd - for cmd in new_cls.__cog_commands__ - } + lookup = {f"{name_filter(cmd)}_{cmd.qualified_name}": cmd for cmd in new_cls.__cog_commands__} # Update the Command instances dynamically as well for command in new_cls.__cog_commands__: - if ( - isinstance(command, ApplicationCommand) - and not command.guild_ids - and new_cls.__cog_guild_ids__ - ): + if isinstance(command, ApplicationCommand) and not command.guild_ids and new_cls.__cog_guild_ids__: command.guild_ids = new_cls.__cog_guild_ids__ - if not isinstance(command, SlashCommandGroup) and not hasattr( - command, "add_to" - ): + if not isinstance(command, SlashCommandGroup) and not hasattr(command, "add_to"): # ignore bridge commands cmd = getattr(new_cls, command.callback.__name__, None) if hasattr(cmd, "add_to"): @@ -315,11 +296,7 @@ def get_commands(self) -> list[ApplicationCommand]: This does not include subcommands. """ - return [ - c - for c in self.__cog_commands__ - if isinstance(c, ApplicationCommand) and c.parent is None - ] + return [c for c in self.__cog_commands__ if isinstance(c, ApplicationCommand) and c.parent is None] @property def qualified_name(self) -> str: @@ -355,22 +332,15 @@ def get_listeners(self) -> list[tuple[str, Callable[..., Any]]]: List[Tuple[:class:`str`, :ref:`coroutine `]] The listeners defined in this cog. """ - return [ - (name, getattr(self, method_name)) - for name, method_name in self.__cog_listeners__ - ] + return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__] @classmethod def _get_overridden_method(cls, method: FuncT) -> FuncT | None: """Return None if the method is not overridden. Otherwise, returns the overridden method.""" - return getattr( - getattr(method, "__func__", method), "__cog_special_method__", method - ) + return getattr(getattr(method, "__func__", method), "__cog_special_method__", method) @classmethod - def listener( - cls, name: str | discord.utils.Undefined = MISSING, once: bool = False - ) -> Callable[[FuncT], FuncT]: + def listener(cls, name: str | discord.utils.Undefined = MISSING, once: bool = False) -> Callable[[FuncT], FuncT]: """A decorator that marks a function as a listener. This is the cog equivalent of :meth:`.Bot.listen`. @@ -392,10 +362,7 @@ def listener( """ if name is not MISSING and not isinstance(name, str): - raise TypeError( - "Cog.listener expected str but received" - f" {name.__class__.__name__!r} instead." - ) + raise TypeError(f"Cog.listener expected str but received {name.__class__.__name__!r} instead.") def decorator(func: FuncT) -> FuncT: actual = func @@ -481,9 +448,7 @@ def cog_check(self, ctx: ApplicationContext) -> bool: return True @_cog_special_method - async def cog_command_error( - self, ctx: ApplicationContext, error: Exception - ) -> None: + async def cog_command_error(self, ctx: ApplicationContext, error: Exception) -> None: """A special method that is called whenever an error is dispatched inside this cog. @@ -738,8 +703,7 @@ def _remove_module_references(self, name: str) -> None: remove = [ index for index, event in enumerate(event_list) - if event.__module__ is not None - and _is_submodule(name, event.__module__) + if event.__module__ is not None and _is_submodule(name, event.__module__) ] for index in reversed(remove): @@ -763,9 +727,7 @@ def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: if _is_submodule(name, module): del sys.modules[module] - def _load_from_module_spec( - self, spec: importlib.machinery.ModuleSpec, key: str - ) -> None: + def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: # precondition: key not in self.__extensions lib = importlib.util.module_from_spec(spec) sys.modules[key] = lib @@ -919,9 +881,7 @@ def load_extension( parts = list(ext_file.parts[:-1]) # Gets the file name without the extension parts.append(ext_file.stem) - loaded = self.load_extension( - ".".join(parts), package=package, recursive=recursive, store=store - ) + loaded = self.load_extension(".".join(parts), package=package, recursive=recursive, store=store) final_out.update(loaded) if store else final_out.extend(loaded) if isinstance(final_out, Exception): @@ -1011,14 +971,8 @@ def load_extensions( loaded_extensions = {} if store else [] for ext_path in names: - loaded = self.load_extension( - ext_path, package=package, recursive=recursive, store=store - ) - ( - loaded_extensions.update(loaded) - if store - else loaded_extensions.extend(loaded) - ) + loaded = self.load_extension(ext_path, package=package, recursive=recursive, store=store) + (loaded_extensions.update(loaded) if store else loaded_extensions.extend(loaded)) return loaded_extensions @@ -1104,11 +1058,7 @@ def reload_extension(self, name: str, *, package: str | None = None) -> None: raise errors.ExtensionNotLoaded(name) # get the previous module states from sys modules - modules = { - name: module - for name, module in sys.modules.items() - if _is_submodule(lib.__name__, name) - } + modules = {name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name)} try: # Unload and then load the module... diff --git a/discord/colour.py b/discord/colour.py index 0ec77d5786..9cf98c8af8 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -75,9 +75,7 @@ class Colour: def __init__(self, value: int): if not isinstance(value, int): - raise TypeError( - f"Expected int parameter, received {value.__class__.__name__} instead." - ) + raise TypeError(f"Expected int parameter, received {value.__class__.__name__} instead.") self.value: int = value diff --git a/discord/commands/context.py b/discord/commands/context.py index 532d8abe2a..a046afdde6 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -188,11 +188,7 @@ def me(self) -> Member | ClientUser | None: Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message message contexts, or when :meth:`Intents.guilds` is absent. """ - return ( - self.interaction.guild.me - if self.interaction.guild is not None - else self.bot.user - ) + return self.interaction.guild.me if self.interaction.guild is not None else self.bot.user @cached_property def message(self) -> Message | None: @@ -255,8 +251,7 @@ def unselected_options(self) -> list[Option] | None: return [ option for option in self.command.options # type: ignore - if option.to_dict()["name"] - not in [opt["name"] for opt in self.selected_options] + if option.to_dict()["name"] not in [opt["name"] for opt in self.selected_options] ] else: return self.command.options # type: ignore @@ -269,9 +264,7 @@ def send_modal(self) -> Callable[..., Awaitable[Interaction]]: @property @discord.utils.copy_doc(Interaction.respond) - def respond( - self, *args, **kwargs - ) -> Callable[..., Awaitable[Interaction | WebhookMessage]]: + def respond(self, *args, **kwargs) -> Callable[..., Awaitable[Interaction | WebhookMessage]]: return self.interaction.respond @property @@ -281,8 +274,7 @@ def send_response(self) -> Callable[..., Awaitable[Interaction]]: return self.interaction.response.send_message else: raise RuntimeError( - "Interaction was already issued a response. Try using" - f" {type(self).__name__}.send_followup() instead." + f"Interaction was already issued a response. Try using {type(self).__name__}.send_followup() instead." ) @property @@ -292,8 +284,7 @@ def send_followup(self) -> Callable[..., Awaitable[WebhookMessage]]: return self.followup.send else: raise RuntimeError( - "Interaction was not yet issued a response. Try using" - f" {type(self).__name__}.respond() first." + f"Interaction was not yet issued a response. Try using {type(self).__name__}.respond() first." ) @property diff --git a/discord/commands/core.py b/discord/commands/core.py index bee43341fd..881d20ed13 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -145,10 +145,7 @@ async def wrapped(arg): except Exception as exc: raise ApplicationCommandInvokeError(exc) from exc finally: - if ( - hasattr(command, "_max_concurrency") - and command._max_concurrency is not None - ): + if hasattr(command, "_max_concurrency") and command._max_concurrency is not None: await command._max_concurrency.release(ctx) await command.call_after_hooks(ctx) return ret @@ -199,15 +196,11 @@ def __init__(self, func: Callable, **kwargs) -> None: elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError( - "Cooldown must be a an instance of CooldownMapping or None." - ) + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") self._buckets: CooldownMapping = buckets - max_concurrency = getattr( - func, "__commands_max_concurrency__", kwargs.get("max_concurrency") - ) + max_concurrency = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency")) self._max_concurrency: MaxConcurrency | None = max_concurrency @@ -235,9 +228,7 @@ def __init__(self, func: Callable, **kwargs) -> None: ) self.nsfw: bool | None = getattr(func, "__nsfw__", kwargs.get("nsfw", None)) - integration_types = getattr( - func, "__integration_types__", kwargs.get("integration_types", None) - ) + integration_types = getattr(func, "__integration_types__", kwargs.get("integration_types", None)) contexts = getattr(func, "__contexts__", kwargs.get("contexts", None)) guild_only = getattr(func, "__guild_only__", kwargs.get("guild_only", MISSING)) if guild_only is not MISSING: @@ -248,12 +239,8 @@ def __init__(self, func: Callable, **kwargs) -> None: reference="https://discord.com/developers/docs/change-log#userinstallable-apps-preview", ) if contexts and guild_only: - raise InvalidArgument( - "cannot pass both 'contexts' and 'guild_only' to ApplicationCommand" - ) - if self.guild_ids and ( - (contexts is not None) or guild_only or integration_types - ): + raise InvalidArgument("cannot pass both 'contexts' and 'guild_only' to ApplicationCommand") + if self.guild_ids and ((contexts is not None) or guild_only or integration_types): raise InvalidArgument( "the 'contexts' and 'integration_types' parameters are not available for guild commands" ) @@ -351,9 +338,7 @@ async def prepare(self, ctx: ApplicationContext) -> None: ctx.command = self if not await self.can_run(ctx): - raise CheckFailure( - f"The check functions for the command {self.name} failed" - ) + raise CheckFailure(f"The check functions for the command {self.name} failed") if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -436,9 +421,7 @@ async def invoke(self, ctx: ApplicationContext) -> None: async def can_run(self, ctx: ApplicationContext) -> bool: if not await ctx.bot.can_run(ctx): - raise CheckFailure( - f"The global check functions for command {self.name} failed." - ) + raise CheckFailure(f"The global check functions for command {self.name} failed.") predicates = self.checks if self.parent is not None: @@ -738,21 +721,15 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: raise TypeError("Callback must be a coroutine.") self.callback = func - self.name_localizations: dict[str, str] = kwargs.get( - "name_localizations", MISSING - ) + self.name_localizations: dict[str, str] = kwargs.get("name_localizations", MISSING) _validate_names(self) description = kwargs.get("description") or ( - inspect.cleandoc(func.__doc__).splitlines()[0] - if func.__doc__ is not None - else "No description provided" + inspect.cleandoc(func.__doc__).splitlines()[0] if func.__doc__ is not None else "No description provided" ) self.description: str = description - self.description_localizations: dict[str, str] = kwargs.get( - "description_localizations", MISSING - ) + self.description_localizations: dict[str, str] = kwargs.get("description_localizations", MISSING) _validate_descriptions(self) self.attached_to_group: bool = False @@ -781,16 +758,12 @@ def _validate_parameters(self): def _check_required_params(self, params): params = iter(params.items()) - required_params = ( - ["self", "context"] if self.attached_to_group or self.cog else ["context"] - ) + required_params = ["self", "context"] if self.attached_to_group or self.cog else ["context"] for p in required_params: try: next(params) except StopIteration: - raise ClientException( - f'Callback for {self.name} command is missing "{p}" parameter.' - ) + raise ClientException(f'Callback for {self.name} command is missing "{p}" parameter.') return params @@ -814,9 +787,7 @@ def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: option = next(option_gen, Option()) # Handle Optional if self._is_typing_optional(type_hint): - option.input_type = SlashCommandOptionType.from_datatype( - get_args(type_hint)[0] - ) + option.input_type = SlashCommandOptionType.from_datatype(get_args(type_hint)[0]) option.default = None else: option.input_type = SlashCommandOptionType.from_datatype(type_hint) @@ -830,9 +801,7 @@ def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: if not isinstance(option, Option): if isinstance(p_obj.default, Option): if p_obj.default.input_type is None: - p_obj.default.input_type = SlashCommandOptionType.from_datatype( - option - ) + p_obj.default.input_type = SlashCommandOptionType.from_datatype(option) option = p_obj.default else: option = Option(option) @@ -840,9 +809,7 @@ def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: if option.default is None and not p_obj.default == inspect.Parameter.empty: if isinstance(p_obj.default, Option): pass - elif isinstance(p_obj.default, type) and issubclass( - p_obj.default, (DiscordEnum, Enum) - ): + elif isinstance(p_obj.default, type) and issubclass(p_obj.default, (DiscordEnum, Enum)): option = Option(p_obj.default) else: option.default = p_obj.default @@ -866,15 +833,10 @@ def _match_option_param_names(self, params, options): check_annotations: list[Callable[[Option, type], bool]] = [ lambda o, a: o.input_type == SlashCommandOptionType.string and o.converter is not None, # pass on converters - lambda o, a: isinstance( - o.input_type, SlashCommandOptionType - ), # pass on slash cmd option type enums + lambda o, a: isinstance(o.input_type, SlashCommandOptionType), # pass on slash cmd option type enums lambda o, a: isinstance(o._raw_type, tuple) and a == Union[o._raw_type], # type: ignore # union types - lambda o, a: self._is_typing_optional(a) - and not o.required - and o._raw_type in a.__args__, # optional - lambda o, a: isinstance(a, type) - and issubclass(a, o._raw_type), # 'normal' types + lambda o, a: self._is_typing_optional(a) and not o.required and o._raw_type in a.__args__, # optional + lambda o, a: isinstance(a, type) and issubclass(a, o._raw_type), # 'normal' types ] for o in options: _validate_names(o) @@ -886,9 +848,7 @@ def _match_option_param_names(self, params, options): p_obj = p_obj.annotation if not any(check(o, p_obj) for check in check_annotations): - raise TypeError( - f"Parameter {p_name} does not match input type of {o.name}." - ) + raise TypeError(f"Parameter {p_name} does not match input type of {o.name}.") o._parameter_name = p_name left_out_params = OrderedDict() @@ -899,9 +859,7 @@ def _match_option_param_names(self, params, options): return options def _is_typing_union(self, annotation): - return getattr(annotation, "__origin__", None) is Union or type( - annotation - ) is getattr( + return getattr(annotation, "__origin__", None) is Union or type(annotation) is getattr( types, "UnionType", Union ) # type: ignore @@ -920,12 +878,7 @@ def cog(self, value): old_cog = self.cog self._cog = value - if ( - old_cog is None - and value is not None - or value is None - and old_cog is not None - ): + if old_cog is None and value is not None or value is None and old_cog is not None: self._validate_parameters() @property @@ -953,9 +906,7 @@ def to_dict(self) -> dict: as_dict["nsfw"] = self.nsfw if self.default_member_permissions is not None: - as_dict["default_member_permissions"] = ( - self.default_member_permissions.value - ) + as_dict["default_member_permissions"] = self.default_member_permissions.value if not self.guild_ids and not self.is_subcommand: as_dict["integration_types"] = [it.value for it in self.integration_types] @@ -982,8 +933,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: ): resolved = ctx.interaction.data.get("resolved", {}) if ( - op.input_type - in (SlashCommandOptionType.user, SlashCommandOptionType.mentionable) + op.input_type in (SlashCommandOptionType.user, SlashCommandOptionType.mentionable) and (_data := resolved.get("members", {}).get(arg)) is not None ): # The option type is a user, we resolved a member from the snowflake and assigned it to _data @@ -996,25 +946,16 @@ async def _invoke(self, ctx: ApplicationContext) -> None: if (_data := resolved.get("users", {}).get(arg)) is not None: arg = User(state=ctx.interaction._state, data=_data) elif (_data := resolved.get("roles", {}).get(arg)) is not None: - arg = Role( - state=ctx.interaction._state, data=_data, guild=ctx.guild - ) + arg = Role(state=ctx.interaction._state, data=_data, guild=ctx.guild) else: arg = Object(id=int(arg)) - elif ( - _data := resolved.get(f"{op.input_type.name}s", {}).get(arg) - ) is not None: + elif (_data := resolved.get(f"{op.input_type.name}s", {}).get(arg)) is not None: if op.input_type is SlashCommandOptionType.channel and ( - int(arg) in ctx.guild._channels - or int(arg) in ctx.guild._threads + int(arg) in ctx.guild._channels or int(arg) in ctx.guild._threads ): arg = ctx.guild.get_channel_or_thread(int(arg)) _data["_invoke_flag"] = True - ( - arg._update(_data) - if isinstance(arg, Thread) - else arg._update(ctx.guild, _data) - ) + (arg._update(_data) if isinstance(arg, Thread) else arg._update(ctx.guild, _data)) else: obj_type = None kw = {} @@ -1039,10 +980,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: # We couldn't resolve the object, so we just return an empty object arg = Object(id=int(arg)) - elif ( - op.input_type == SlashCommandOptionType.string - and (converter := op.converter) is not None - ): + elif op.input_type == SlashCommandOptionType.string and (converter := op.converter) is not None: from discord.ext.commands import Converter if isinstance(converter, Converter): @@ -1087,9 +1025,7 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): for op in ctx.interaction.data.get("options", []): if op.get("focused", False): option = find(lambda o: o.name == op["name"], self.options) - values.update( - {i["name"]: i["value"] for i in ctx.interaction.data["options"]} - ) + values.update({i["name"]: i["value"] for i in ctx.interaction.data["options"]}) ctx.command = self ctx.focused = option ctx.value = op.get("value") @@ -1104,13 +1040,8 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): if asyncio.iscoroutinefunction(option.autocomplete): result = await result - choices = [ - o if isinstance(o, OptionChoice) else OptionChoice(o) - for o in result - ][:25] - return await ctx.interaction.response.send_autocomplete_result( - choices=choices - ) + choices = [o if isinstance(o, OptionChoice) else OptionChoice(o) for o in result][:25] + return await ctx.interaction.response.send_autocomplete_result(choices=choices) def copy(self): """Creates a copy of this command. @@ -1238,9 +1169,7 @@ def __init__( validate_chat_input_name(self.name) validate_chat_input_description(self.description) self.input_type = SlashCommandOptionType.sub_command_group - self.subcommands: list[SlashCommand | SlashCommandGroup] = ( - self.__initial_commands__ - ) + self.subcommands: list[SlashCommand | SlashCommandGroup] = self.__initial_commands__ self.guild_ids = guild_ids self.parent = parent self.attached_to_group: bool = False @@ -1252,9 +1181,7 @@ def __init__( self.id = None # Permissions - self.default_member_permissions: Permissions | None = kwargs.get( - "default_member_permissions", None - ) + self.default_member_permissions: Permissions | None = kwargs.get("default_member_permissions", None) self.nsfw: bool | None = kwargs.get("nsfw", None) integration_types = kwargs.get("integration_types", None) @@ -1263,12 +1190,8 @@ def __init__( if guild_only is not MISSING: warn_deprecated("guild_only", "contexts", "2.6") if contexts and guild_only: - raise InvalidArgument( - "cannot pass both 'contexts' and 'guild_only' to ApplicationCommand" - ) - if self.guild_ids and ( - (contexts is not None) or guild_only or integration_types - ): + raise InvalidArgument("cannot pass both 'contexts' and 'guild_only' to ApplicationCommand") + if self.guild_ids and ((contexts is not None) or guild_only or integration_types): raise InvalidArgument( "the 'contexts' and 'integration_types' parameters are not available for guild commands" ) @@ -1279,12 +1202,8 @@ def __init__( self.guild_only: bool | None = guild_only self.integration_types: set[IntegrationType] | None = integration_types - self.name_localizations: dict[str, str] = kwargs.get( - "name_localizations", MISSING - ) - self.description_localizations: dict[str, str] = kwargs.get( - "description_localizations", MISSING - ) + self.name_localizations: dict[str, str] = kwargs.get("name_localizations", MISSING) + self.description_localizations: dict[str, str] = kwargs.get("description_localizations", MISSING) # similar to ApplicationCommand from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency @@ -1296,20 +1215,14 @@ def __init__( elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError( - "Cooldown must be a an instance of CooldownMapping or None." - ) + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") self._buckets: CooldownMapping = buckets # no need to getattr, since slash cmds groups cant be created using a decorator - if max_concurrency is not None and not isinstance( - max_concurrency, MaxConcurrency - ): - raise TypeError( - "max_concurrency must be an instance of MaxConcurrency or None" - ) + if max_concurrency is not None and not isinstance(max_concurrency, MaxConcurrency): + raise TypeError("max_concurrency must be an instance of MaxConcurrency or None") self._max_concurrency: MaxConcurrency | None = max_concurrency @@ -1352,9 +1265,7 @@ def to_dict(self) -> dict: as_dict["nsfw"] = self.nsfw if self.default_member_permissions is not None: - as_dict["default_member_permissions"] = ( - self.default_member_permissions.value - ) + as_dict["default_member_permissions"] = self.default_member_permissions.value if not self.guild_ids and self.parent is None: as_dict["integration_types"] = [it.value for it in self.integration_types] @@ -1368,9 +1279,7 @@ def add_command(self, command: SlashCommand | SlashCommandGroup) -> None: self.subcommands.append(command) - def command( - self, cls: type[T] = SlashCommand, **kwargs - ) -> Callable[[Callable], SlashCommand]: + def command(self, cls: type[T] = SlashCommand, **kwargs) -> Callable[[Callable], SlashCommand]: def wrap(func) -> T: command = cls(func, parent=self, **kwargs) self.add_command(command) @@ -1427,9 +1336,7 @@ def create_subgroup( if self.parent is not None: raise Exception("A subcommand group cannot be added to a subcommand group") - sub_command_group = SlashCommandGroup( - name, description, guild_ids, parent=self, **kwargs - ) + sub_command_group = SlashCommandGroup(name, description, guild_ids, parent=self, **kwargs) self.subcommands.append(sub_command_group) return sub_command_group @@ -1633,9 +1540,7 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: raise TypeError("Callback must be a coroutine.") self.callback = func - self.name_localizations: dict[str, str] = kwargs.get( - "name_localizations", MISSING - ) + self.name_localizations: dict[str, str] = kwargs.get("name_localizations", MISSING) # Discord API doesn't support setting descriptions for context menu commands, so it must be empty self.description = "" @@ -1665,25 +1570,19 @@ def validate_parameters(self): try: next(params) except StopIteration: - raise ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) + raise ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') # next we have the 'user/message' as the next parameter try: next(params) except StopIteration: cmd = "user" if type(self) == UserCommand else "message" - raise ClientException( - f'Callback for {self.name} command is missing "{cmd}" parameter.' - ) + raise ClientException(f'Callback for {self.name} command is missing "{cmd}" parameter.') # next there should be no more parameters try: next(params) - raise ClientException( - f"Callback for {self.name} command has too many parameters." - ) + raise ClientException(f"Callback for {self.name} command has too many parameters.") except StopIteration: pass @@ -1706,9 +1605,7 @@ def to_dict(self) -> dict[str, str | int]: as_dict["nsfw"] = self.nsfw if self.default_member_permissions is not None: - as_dict["default_member_permissions"] = ( - self.default_member_permissions.value - ) + as_dict["default_member_permissions"] = self.default_member_permissions.value if self.name_localizations: as_dict["name_localizations"] = self.name_localizations @@ -1892,9 +1789,7 @@ async def _invoke(self, ctx: ApplicationContext): channel = ctx.interaction.channel if channel.id != int(message["channel_id"]): # we got weird stuff going on, make up a channel - channel = PartialMessageable( - state=ctx.interaction._state, id=int(message["channel_id"]) - ) + channel = PartialMessageable(state=ctx.interaction._state, id=int(message["channel_id"])) target = Message(state=ctx.interaction._state, channel=channel, data=message) @@ -2016,9 +1911,7 @@ def decorator(func: Callable) -> cls: if isinstance(func, ApplicationCommand): func = func.callback elif not callable(func): - raise TypeError( - "func needs to be a callable or a subclass of ApplicationCommand." - ) + raise TypeError("func needs to be a callable or a subclass of ApplicationCommand.") return cls(func, **attrs) return decorator @@ -2082,14 +1975,11 @@ def validate_chat_input_name(name: Any, locale: str | None = None): # Must meet the regex ^[-_\w\d\u0901-\u097D\u0E00-\u0E7F]{1,32}$ if locale is not None and locale not in valid_locales: raise ValidationError( - f"Locale '{locale}' is not a valid locale, see {docs}/reference#locales for" - " list of supported locales." + f"Locale '{locale}' is not a valid locale, see {docs}/reference#locales for list of supported locales." ) error = None if not isinstance(name, str): - error = TypeError( - f'Command names and options must be of type str. Received "{name}"' - ) + error = TypeError(f'Command names and options must be of type str. Received "{name}"') elif not re.match(r"^[-_\w\d\u0901-\u097D\u0E00-\u0E7F]{1,32}$", name): error = ValidationError( r"Command names and options must follow the regex" @@ -2098,12 +1988,8 @@ def validate_chat_input_name(name: Any, locale: str | None = None): f" {docs}/interactions/application-commands#application-command-object-" f'application-command-naming. Received "{name}"' ) - elif ( - name.lower() != name - ): # Can't use islower() as it fails if none of the chars can be lowered. See #512. - error = ValidationError( - f'Command names and options must be lowercase. Received "{name}"' - ) + elif name.lower() != name: # Can't use islower() as it fails if none of the chars can be lowered. See #512. + error = ValidationError(f'Command names and options must be lowercase. Received "{name}"') if error: if locale: @@ -2114,19 +2000,14 @@ def validate_chat_input_name(name: Any, locale: str | None = None): def validate_chat_input_description(description: Any, locale: str | None = None): if locale is not None and locale not in valid_locales: raise ValidationError( - f"Locale '{locale}' is not a valid locale, see {docs}/reference#locales for" - " list of supported locales." + f"Locale '{locale}' is not a valid locale, see {docs}/reference#locales for list of supported locales." ) error = None if not isinstance(description, str): - error = TypeError( - "Command and option description must be of type str. Received" - f' "{description}"' - ) + error = TypeError(f'Command and option description must be of type str. Received "{description}"') elif not 1 <= len(description) <= 100: error = ValidationError( - "Command and option description must be 1-100 characters long. Received" - f' "{description}"' + f'Command and option description must be 1-100 characters long. Received "{description}"' ) if error: diff --git a/discord/commands/options.py b/discord/commands/options.py index 2809602a70..87d9770e0b 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -176,7 +176,7 @@ class Option: async def hello( ctx: discord.ApplicationContext, name: Option(str, "Enter your name"), - age: Option(int, "Enter your age", min_value=1, max_value=99, default=18) + age: Option(int, "Enter your age", min_value=1, max_value=99, default=18), # passing the default value makes an argument optional # you also can create optional argument using: # age: Option(int, "Enter your age") = 18 @@ -189,9 +189,7 @@ async def hello( input_type: SlashCommandOptionType converter: Converter | type[Converter] | None = None - def __init__( - self, input_type: InputType = str, /, description: str | None = None, **kwargs - ) -> None: + def __init__(self, input_type: InputType = str, /, description: str | None = None, **kwargs) -> None: self.name: str | None = kwargs.pop("name", None) if self.name is not None: self.name = str(self.name) @@ -215,9 +213,7 @@ def __init__( if value_class in SlashCommandOptionType.__members__ and all( isinstance(elem.value, value_class) for elem in enum_choices ): - input_type = SlashCommandOptionType.from_datatype( - enum_choices[0].value.__class__ - ) + input_type = SlashCommandOptionType.from_datatype(enum_choices[0].value.__class__) else: enum_choices = [OptionChoice(e.name, str(e.value)) for e in input_type] input_type = SlashCommandOptionType.string @@ -230,18 +226,10 @@ def __init__( else: from ..ext.commands import Converter - if isinstance(input_type, tuple) and any( - issubclass(op, ApplicationContext) for op in input_type - ): - input_type = next( - op for op in input_type if issubclass(op, ApplicationContext) - ) + if isinstance(input_type, tuple) and any(issubclass(op, ApplicationContext) for op in input_type): + input_type = next(op for op in input_type if issubclass(op, ApplicationContext)) - if ( - isinstance(input_type, Converter) - or input_type_is_class - and issubclass(input_type, Converter) - ): + if isinstance(input_type, Converter) or input_type_is_class and issubclass(input_type, Converter): self.converter = input_type self._raw_type = str self.input_type = SlashCommandOptionType.string @@ -264,14 +252,8 @@ def __init__( else: self._raw_type = (input_type,) if not self.channel_types: - self.channel_types = [ - CHANNEL_TYPE_MAP[t] - for t in self._raw_type - if t is not GuildChannel - ] - self.required: bool = ( - kwargs.pop("required", True) if "default" not in kwargs else False - ) + self.channel_types = [CHANNEL_TYPE_MAP[t] for t in self._raw_type if t is not GuildChannel] + self.required: bool = kwargs.pop("required", True) if "default" not in kwargs else False self.default = kwargs.pop("default", None) self.autocomplete = kwargs.pop("autocomplete", None) @@ -283,8 +265,7 @@ def __init__( self.input_type = SlashCommandOptionType.string else: self.choices: list[OptionChoice] = enum_choices or [ - o if isinstance(o, OptionChoice) else OptionChoice(o) - for o in kwargs.pop("choices", []) + o if isinstance(o, OptionChoice) else OptionChoice(o) for o in kwargs.pop("choices", []) ] if self.input_type == SlashCommandOptionType.integer: @@ -318,47 +299,31 @@ def __init__( "Option does not take min_value or max_value if not of type " "SlashCommandOptionType.integer or SlashCommandOptionType.number" ) - if self.input_type != SlashCommandOptionType.string and ( - self.min_length or self.max_length - ): - raise AttributeError( - "Option does not take min_length or max_length if not of type str" - ) + if self.input_type != SlashCommandOptionType.string and (self.min_length or self.max_length): + raise AttributeError("Option does not take min_length or max_length if not of type str") if self.min_value is not None and not isinstance(self.min_value, minmax_types): - raise TypeError( - f"Expected {minmax_typehint} for min_value, got" - f' "{type(self.min_value).__name__}"' - ) + raise TypeError(f'Expected {minmax_typehint} for min_value, got "{type(self.min_value).__name__}"') if self.max_value is not None and not isinstance(self.max_value, minmax_types): - raise TypeError( - f"Expected {minmax_typehint} for max_value, got" - f' "{type(self.max_value).__name__}"' - ) + raise TypeError(f'Expected {minmax_typehint} for max_value, got "{type(self.max_value).__name__}"') if self.min_length is not None: if not isinstance(self.min_length, minmax_length_types): raise TypeError( - f"Expected {minmax_length_typehint} for min_length," - f' got "{type(self.min_length).__name__}"' + f'Expected {minmax_length_typehint} for min_length, got "{type(self.min_length).__name__}"' ) if self.min_length < 0 or self.min_length > 6000: - raise AttributeError( - "min_length must be between 0 and 6000 (inclusive)" - ) + raise AttributeError("min_length must be between 0 and 6000 (inclusive)") if self.max_length is not None: if not isinstance(self.max_length, minmax_length_types): raise TypeError( - f"Expected {minmax_length_typehint} for max_length," - f' got "{type(self.max_length).__name__}"' + f'Expected {minmax_length_typehint} for max_length, got "{type(self.max_length).__name__}"' ) if self.max_length < 1 or self.max_length > 6000: raise AttributeError("max_length must between 1 and 6000 (inclusive)") self.name_localizations = kwargs.pop("name_localizations", MISSING) - self.description_localizations = kwargs.pop( - "description_localizations", MISSING - ) + self.description_localizations = kwargs.pop("description_localizations", MISSING) if input_type is None: raise TypeError("input_type cannot be NoneType.") @@ -442,11 +407,7 @@ def option(name, input_type=None, **kwargs): def decorator(func): resolved_name = kwargs.pop("parameter_name", None) or name - itype = ( - kwargs.pop("type", None) - or input_type - or func.__annotations__.get(resolved_name, str) - ) + itype = kwargs.pop("type", None) or input_type or func.__annotations__.get(resolved_name, str) func.__annotations__[resolved_name] = Option(itype, name=name, **kwargs) return func diff --git a/discord/commands/permissions.py b/discord/commands/permissions.py index daf633b05a..2a2145673e 100644 --- a/discord/commands/permissions.py +++ b/discord/commands/permissions.py @@ -56,10 +56,11 @@ def default_permissions(**perms: bool) -> Callable: from discord import default_permissions + @bot.slash_command() @default_permissions(manage_messages=True) async def test(ctx): - await ctx.respond('You can manage messages.') + await ctx.respond("You can manage messages.") """ invalid = set(perms) - set(Permissions.VALID_FLAGS) @@ -69,9 +70,7 @@ async def test(ctx): def inner(command: Callable): if isinstance(command, ApplicationCommand): if command.parent is not None: - raise RuntimeError( - "Permission restrictions can only be set on top-level commands" - ) + raise RuntimeError("Permission restrictions can only be set on top-level commands") command.default_member_permissions = Permissions(**perms) else: command.__default_member_permissions__ = Permissions(**perms) @@ -91,6 +90,7 @@ def guild_only() -> Callable: from discord import guild_only + @bot.slash_command() @guild_only() async def test(ctx): @@ -122,6 +122,7 @@ def is_nsfw() -> Callable: from discord import is_nsfw + @bot.slash_command() @is_nsfw() async def test(ctx): diff --git a/discord/components.py b/discord/components.py index 9ce6bfad7c..0d9a927a66 100644 --- a/discord/components.py +++ b/discord/components.py @@ -119,9 +119,7 @@ class ActionRow(Component): def __init__(self, data: ComponentPayload): self.type: ComponentType = try_enum(ComponentType, data["type"]) - self.children: list[Component] = [ - _component_factory(d) for d in data.get("components", []) - ] + self.children: list[Component] = [_component_factory(d) for d in data.get("components", [])] def to_dict(self) -> ActionRowPayload: return { @@ -349,12 +347,8 @@ def __init__(self, data: SelectMenuPayload): self.min_values: int = data.get("min_values", 1) self.max_values: int = data.get("max_values", 1) self.disabled: bool = data.get("disabled", False) - self.options: list[SelectOption] = [ - SelectOption.from_dict(option) for option in data.get("options", []) - ] - self.channel_types: list[ChannelType] = [ - try_enum(ChannelType, ct) for ct in data.get("channel_types", []) - ] + self.options: list[SelectOption] = [SelectOption.from_dict(option) for option in data.get("options", [])] + self.channel_types: list[ChannelType] = [try_enum(ChannelType, ct) for ct in data.get("channel_types", [])] def to_dict(self) -> SelectMenuPayload: payload: SelectMenuPayload = { @@ -457,8 +451,7 @@ def emoji(self, value) -> None: value = value._to_partial() else: raise TypeError( - "expected emoji to be str, GuildEmoji, AppEmoji, or PartialEmoji, not" - f" {value.__class__}" + f"expected emoji to be str, GuildEmoji, AppEmoji, or PartialEmoji, not {value.__class__}" ) self._emoji = value diff --git a/discord/embeds.py b/discord/embeds.py index 81424050e2..cab91a5176 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -186,15 +186,15 @@ def __init__(self, url: str): def from_dict(cls, data: dict[str, str | int]) -> EmbedMedia: self = cls.__new__(cls) self.url = str(data.get("url")) - self.proxy_url = ( - str(proxy_url) if (proxy_url := data.get("proxy_url")) else None - ) + self.proxy_url = str(proxy_url) if (proxy_url := data.get("proxy_url")) else None self.height = int(height) if (height := data.get("height")) else None self.width = int(width) if (width := data.get("width")) else None return self def __repr__(self) -> str: - return f" height={self.height!r} width={self.width!r}>" + return ( + f" height={self.height!r} width={self.width!r}>" + ) class EmbedProvider: @@ -527,10 +527,7 @@ def colour(self, value: int | Colour | None): # type: ignore elif isinstance(value, int): self._colour = Colour(value=value) else: - raise TypeError( - "Expected discord.Colour, int, or None but received" - f" {value.__class__.__name__} instead." - ) + raise TypeError(f"Expected discord.Colour, int, or None but received {value.__class__.__name__} instead.") color = colour @@ -547,10 +544,7 @@ def timestamp(self, value: datetime.datetime | None): elif value is None: self._timestamp = value else: - raise TypeError( - "Expected datetime.datetime or None. Received" - f" {value.__class__.__name__} instead" - ) + raise TypeError(f"Expected datetime.datetime or None. Received {value.__class__.__name__} instead") @property def footer(self) -> EmbedFooter | None: @@ -572,10 +566,7 @@ def footer(self, value: EmbedFooter | None): elif isinstance(value, EmbedFooter): self._footer = value.to_dict() else: - raise TypeError( - "Expected EmbedFooter or None. Received" - f" {value.__class__.__name__} instead" - ) + raise TypeError(f"Expected EmbedFooter or None. Received {value.__class__.__name__} instead") def set_footer( self: E, @@ -648,10 +639,7 @@ def image(self, value: str | EmbedMedia | None): elif isinstance(value, EmbedMedia): self.set_image(url=value.url) else: - raise TypeError( - "Expected discord.EmbedMedia, or None but received" - f" {value.__class__.__name__} instead." - ) + raise TypeError(f"Expected discord.EmbedMedia, or None but received {value.__class__.__name__} instead.") def set_image(self: E, *, url: Any | None) -> E: """Sets the image for the embed content. @@ -722,10 +710,7 @@ def thumbnail(self, value: str | EmbedMedia | None): elif isinstance(value, EmbedMedia): self.set_thumbnail(url=value.url) else: - raise TypeError( - "Expected discord.EmbedMedia, or None but received" - f" {value.__class__.__name__} instead." - ) + raise TypeError(f"Expected discord.EmbedMedia, or None but received {value.__class__.__name__} instead.") def set_thumbnail(self: E, *, url: Any | None) -> E: """Sets the thumbnail for the embed content. @@ -819,10 +804,7 @@ def author(self, value: EmbedAuthor | None): elif isinstance(value, EmbedAuthor): self._author = value.to_dict() else: - raise TypeError( - "Expected discord.EmbedAuthor, or None but received" - f" {value.__class__.__name__} instead." - ) + raise TypeError(f"Expected discord.EmbedAuthor, or None but received {value.__class__.__name__} instead.") def set_author( self: E, @@ -934,9 +916,7 @@ def add_field(self: E, *, name: str, value: str, inline: bool = True) -> E: return self - def insert_field_at( - self: E, index: int, *, name: Any, value: Any, inline: bool = True - ) -> E: + def insert_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E: """Inserts a field before a specified index to the embed. This function returns the class instance to allow for fluent-style @@ -989,9 +969,7 @@ def remove_field(self, index: int) -> None: except IndexError: pass - def set_field_at( - self: E, index: int, *, name: Any, value: Any, inline: bool = True - ) -> E: + def set_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E: """Modifies a field to the embed object. The index must point to a valid pre-existing field. There must be 25 fields or fewer. @@ -1064,13 +1042,9 @@ def to_dict(self) -> EmbedData: else: if timestamp: if timestamp.tzinfo: - result["timestamp"] = timestamp.astimezone( - tz=datetime.timezone.utc - ).isoformat() + result["timestamp"] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat() else: - result["timestamp"] = timestamp.replace( - tzinfo=datetime.timezone.utc - ).isoformat() + result["timestamp"] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat() # add in the non-raw attribute ones if self.type: diff --git a/discord/emoji.py b/discord/emoji.py index 9324976f4d..5dd1eafd3d 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -49,7 +49,6 @@ class BaseEmoji(_EmojiTag, AssetMixin): - __slots__: tuple[str, ...] = ( "require_colons", "animated", @@ -172,10 +171,7 @@ def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload): super().__init__(state=state, data=data) def __repr__(self) -> str: - return ( - "" - ) + return f"" @property def roles(self) -> list[Role]: @@ -227,9 +223,7 @@ async def delete(self, *, reason: str | None = None) -> None: An error occurred deleting the emoji. """ - await self._state.http.delete_custom_emoji( - self.guild.id, self.id, reason=reason - ) + await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) async def edit( self, @@ -276,9 +270,7 @@ async def edit( if roles is not MISSING: payload["roles"] = [role.id for role in roles] - data = await self._state.http.edit_custom_emoji( - self.guild.id, self.id, payload=payload, reason=reason - ) + data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason) return GuildEmoji(guild=self.guild, data=data, state=self._state) @@ -338,14 +330,12 @@ class AppEmoji(BaseEmoji): __slots__: tuple[str, ...] = ("application_id",) - def __init__( - self, *, application_id: int, state: ConnectionState, data: EmojiPayload - ): + def __init__(self, *, application_id: int, state: ConnectionState, data: EmojiPayload): self.application_id: int = application_id super().__init__(state=state, data=data) def __repr__(self) -> str: - return "" + return f"" @property def guild(self) -> Guild: @@ -413,7 +403,5 @@ async def edit( if name is not MISSING: payload["name"] = name - data = await self._state.http.edit_application_emoji( - self.application_id, self.id, payload=payload - ) + data = await self._state.http.edit_application_emoji(self.application_id, self.id, payload=payload) return self._state.maybe_store_app_emoji(self.application_id, data) diff --git a/discord/enums.py b/discord/enums.py index 01d10275c8..85408ab802 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -86,29 +86,15 @@ def _create_value_cls(name, comparable): cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" cls.__str__ = lambda self: f"{name}.{self.name}" if comparable: - cls.__le__ = ( - lambda self, other: isinstance(other, self.__class__) - and self.value <= other.value - ) - cls.__ge__ = ( - lambda self, other: isinstance(other, self.__class__) - and self.value >= other.value - ) - cls.__lt__ = ( - lambda self, other: isinstance(other, self.__class__) - and self.value < other.value - ) - cls.__gt__ = ( - lambda self, other: isinstance(other, self.__class__) - and self.value > other.value - ) + cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value + cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value + cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value + cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value return cls def _is_descriptor(obj): - return ( - hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") - ) + return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") class EnumMeta(type): @@ -160,9 +146,7 @@ def __iter__(cls): return (cls._enum_member_map_[name] for name in cls._enum_member_names_) def __reversed__(cls): - return ( - cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_) - ) + return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) def __len__(cls): return len(cls._enum_member_names_) @@ -503,9 +487,7 @@ def category(self) -> AuditLogActionCategory | None: AuditLogAction.thread_create: AuditLogActionCategory.create, AuditLogAction.thread_update: AuditLogActionCategory.update, AuditLogAction.thread_delete: AuditLogActionCategory.delete, - AuditLogAction.application_command_permission_update: ( - AuditLogActionCategory.update - ), + AuditLogAction.application_command_permission_update: (AuditLogActionCategory.update), AuditLogAction.auto_moderation_rule_create: AuditLogActionCategory.create, AuditLogAction.auto_moderation_rule_update: AuditLogActionCategory.update, AuditLogAction.auto_moderation_rule_delete: AuditLogActionCategory.delete, @@ -809,9 +791,7 @@ def from_datatype(cls, datatype): else: raise TypeError("Invalid usage of typing.Union") - py_3_10_union_type = hasattr(types, "UnionType") and isinstance( - datatype, types.UnionType - ) + py_3_10_union_type = hasattr(types, "UnionType") and isinstance(datatype, types.UnionType) if py_3_10_union_type or getattr(datatype, "__origin__", None) is Union: # Python 3.10+ "|" operator or typing.Union has been used. The __args__ attribute is a tuple of the types. diff --git a/discord/errors.py b/discord/errors.py index 589f4f2014..0946ae221c 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -298,9 +298,7 @@ def __init__(self, message: str | None = None, *args: Any, name: str) -> None: self.name: str = name message = message or f"Extension {name!r} had an error." # clean-up @everyone and @here mentions - m = message.replace("@everyone", "@\u200beveryone").replace( - "@here", "@\u200bhere" - ) + m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") super().__init__(m, *args) @@ -350,10 +348,7 @@ class ExtensionFailed(ExtensionError): def __init__(self, name: str, original: Exception) -> None: self.original: Exception = original - msg = ( - f"Extension {name!r} raised an error: {original.__class__.__name__}:" - f" {original}" - ) + msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}" super().__init__(msg, name=name) @@ -408,6 +403,4 @@ class ApplicationCommandInvokeError(ApplicationCommandError): def __init__(self, e: Exception) -> None: self.original: Exception = e - super().__init__( - f"Application Command raised an exception: {e.__class__.__name__}: {e}" - ) + super().__init__(f"Application Command raised an exception: {e.__class__.__name__}: {e}") diff --git a/discord/ext/bridge/bot.py b/discord/ext/bridge/bot.py index 9e2c619fa6..b54ef805cb 100644 --- a/discord/ext/bridge/bot.py +++ b/discord/ext/bridge/bot.py @@ -77,9 +77,7 @@ def walk_bridge_commands( if isinstance(cmd, BridgeCommandGroup): yield from cmd.walk_commands() - async def get_application_context( - self, interaction: Interaction, cls=None - ) -> BridgeApplicationContext: + async def get_application_context(self, interaction: Interaction, cls=None) -> BridgeApplicationContext: cls = cls if cls is not None else BridgeApplicationContext # Ignore the type hinting error here. BridgeApplicationContext is a subclass of ApplicationContext, and since # we gave it cls, it will be used instead. @@ -156,9 +154,7 @@ async def invoke(self, ctx: ExtContext | BridgeExtContext): if isinstance(ctx.command, BridgeExtCommand): self.dispatch("bridge_command_error", ctx, exc) - async def invoke_application_command( - self, ctx: ApplicationContext | BridgeApplicationContext - ) -> None: + async def invoke_application_command(self, ctx: ApplicationContext | BridgeApplicationContext) -> None: """|coro| Invokes the application command given under the invocation diff --git a/discord/ext/bridge/context.py b/discord/ext/bridge/context.py index 8e7f9414f6..fe1cb79305 100644 --- a/discord/ext/bridge/context.py +++ b/discord/ext/bridge/context.py @@ -67,9 +67,7 @@ async def example(ctx: BridgeContext): """ @abstractmethod - async def _respond( - self, *args, **kwargs - ) -> Interaction | WebhookMessage | Message: ... + async def _respond(self, *args, **kwargs) -> Interaction | WebhookMessage | Message: ... @abstractmethod async def _defer(self, *args, **kwargs) -> None: ... @@ -78,9 +76,7 @@ async def _defer(self, *args, **kwargs) -> None: ... async def _edit(self, *args, **kwargs) -> InteractionMessage | Message: ... @overload - async def invoke( - self, command: BridgeSlashCommand | BridgeExtCommand, *args, **kwargs - ) -> None: ... + async def invoke(self, command: BridgeSlashCommand | BridgeExtCommand, *args, **kwargs) -> None: ... async def respond(self, *args, **kwargs) -> Interaction | WebhookMessage | Message: """|coro| @@ -178,9 +174,7 @@ async def _edit(self, *args, **kwargs) -> Message | None: if self._original_response_message: return await self._original_response_message.edit(*args, **kwargs) - async def delete( - self, *, delay: float | None = None, reason: str | None = None - ) -> None: + async def delete(self, *, delay: float | None = None, reason: str | None = None) -> None: """|coro| Deletes the original response message, if it exists. diff --git a/discord/ext/bridge/core.py b/discord/ext/bridge/core.py index 9dc58fb009..345db8b6e2 100644 --- a/discord/ext/bridge/core.py +++ b/discord/ext/bridge/core.py @@ -85,9 +85,7 @@ def __init__(self, func, **kwargs): self.brief = kwargs.pop("brief", None) super().__init__(func, **kwargs) - async def dispatch_error( - self, ctx: BridgeApplicationContext, error: Exception - ) -> None: + async def dispatch_error(self, ctx: BridgeApplicationContext, error: Exception) -> None: await super().dispatch_error(ctx, error) ctx.bot.dispatch("bridge_command_error", ctx, error) @@ -100,9 +98,7 @@ def __init__(self, func, **kwargs): # TODO: v2.7: Remove backwards support for Option in bridge commands. for name, option in self.params.items(): - if isinstance(option.annotation, Option) and not isinstance( - option.annotation, BridgeOption - ): + if isinstance(option.annotation, Option) and not isinstance(option.annotation, BridgeOption): # Warn not to do this warn_deprecated( "Using Option for bridge commands", @@ -116,9 +112,7 @@ def __init__(self, func, **kwargs): # We can use the convert method from BridgeOption, and bind "self" # using a manual invocation of the descriptor protocol. # Definitely not a good approach, but gets the job done until removal. - self.params[name].annotation.convert = BridgeOption.convert.__get__( - self.params[name].annotation - ) + self.params[name].annotation.convert = BridgeOption.convert.__get__(self.params[name].annotation) async def dispatch_error(self, ctx: BridgeExtContext, error: Exception) -> None: await super().dispatch_error(ctx, error) @@ -188,12 +182,10 @@ class BridgeCommand: def __init__(self, callback, **kwargs): self.parent = kwargs.pop("parent", None) - self.slash_variant: BridgeSlashCommand = kwargs.pop( - "slash_variant", None - ) or BridgeSlashCommand(callback, **kwargs) - self.ext_variant: BridgeExtCommand = kwargs.pop( - "ext_variant", None - ) or BridgeExtCommand(callback, **kwargs) + self.slash_variant: BridgeSlashCommand = kwargs.pop("slash_variant", None) or BridgeSlashCommand( + callback, **kwargs + ) + self.ext_variant: BridgeExtCommand = kwargs.pop("ext_variant", None) or BridgeExtCommand(callback, **kwargs) @property def name_localizations(self) -> dict[str, str] | None: @@ -243,9 +235,7 @@ def __getattribute__(self, name): return getattr(self.ext_variant, name) return result except AttributeError: - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") def __setattr__(self, name, value) -> None: if name not in self.__special_attrs__: @@ -265,9 +255,7 @@ def add_to(self, bot: ExtBot) -> None: bot.add_application_command(self.slash_variant) bot.add_command(self.ext_variant) - async def invoke( - self, ctx: BridgeExtContext | BridgeApplicationContext, /, *args, **kwargs - ): + async def invoke(self, ctx: BridgeExtContext | BridgeApplicationContext, /, *args, **kwargs): if ctx.is_app: return await self.slash_variant.invoke(ctx) return await self.ext_variant.invoke(ctx) @@ -428,9 +416,7 @@ def wrap(callback): **kwargs, cls=BridgeExtCommand, )(callback) - command = BridgeCommand( - callback, parent=self, slash_variant=slash, ext_variant=ext - ) + command = BridgeCommand(callback, parent=self, slash_variant=slash, ext_variant=ext) self.subcommands.append(command) return command @@ -484,12 +470,11 @@ def map_to(name, description=None): @bot.bridge_group() @bridge.map_to("show") - async def config(ctx: BridgeContext): - ... + async def config(ctx: BridgeContext): ... + @config.command() - async def toggle(ctx: BridgeContext): - ... + async def toggle(ctx: BridgeContext): ... Prefixed commands will not be affected, but slash commands will appear as: @@ -634,19 +619,14 @@ async def convert(self, ctx, argument: str) -> Any: converted = converter(argument) if self.choices: - choices_names: list[str | int | float] = [ - choice.name for choice in self.choices - ] - if converted in choices_names and ( - choice := get(self.choices, name=converted) - ): + choices_names: list[str | int | float] = [choice.name for choice in self.choices] + if converted in choices_names and (choice := get(self.choices, name=converted)): converted = choice.value else: choices = [choice.value for choice in self.choices] if converted not in choices: raise ValueError( - f"{argument} is not a valid choice. Valid choices:" - f" {list(set(choices_names + choices))}" + f"{argument} is not a valid choice. Valid choices: {list(set(choices_names + choices))}" ) return converted @@ -668,11 +648,7 @@ def bridge_option(name, input_type=None, **kwargs): def decorator(func): resolved_name = kwargs.pop("parameter_name", None) or name - itype = ( - kwargs.pop("type", None) - or input_type - or func.__annotations__.get(resolved_name, str) - ) + itype = kwargs.pop("type", None) or input_type or func.__annotations__.get(resolved_name, str) func.__annotations__[resolved_name] = BridgeOption(itype, name=name, **kwargs) return func diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index d3f0336471..c0b8f8c36d 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -40,9 +40,7 @@ Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]], ] -Hook = Union[ - Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]] -] +Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]] Error = Union[ Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]], diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 149d054b78..2e038fbc92 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -84,7 +84,7 @@ def when_mentioned_or( .. code-block:: python3 - bot = commands.Bot(command_prefix=commands.when_mentioned_or('!')) + bot = commands.Bot(command_prefix=commands.when_mentioned_or("!")) .. note:: @@ -94,7 +94,7 @@ def when_mentioned_or( .. code-block:: python3 async def get_prefix(bot, message): - extras = await prefixes_for(message.guild) # returns a list + extras = await prefixes_for(message.guild) # returns a list return commands.when_mentioned_or(*extras)(bot, message) """ @@ -129,9 +129,7 @@ def __init__( ): super().__init__(**options) self.command_prefix = command_prefix - self.help_command = ( - DefaultHelpCommand() if help_command is MISSING else help_command - ) + self.help_command = DefaultHelpCommand() if help_command is MISSING else help_command self.strip_after_prefix = options.get("strip_after_prefix", False) @discord.utils.copy_doc(discord.Client.close) @@ -150,9 +148,7 @@ async def close(self) -> None: await super().close() # type: ignore - async def on_command_error( - self, context: Context, exception: errors.CommandError - ) -> None: + async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: """|coro| The default command error handler provided by the bot. @@ -174,9 +170,7 @@ async def on_command_error( return print(f"Ignoring exception in command {context.command}:", file=sys.stderr) - traceback.print_exception( - type(exception), exception, exception.__traceback__, file=sys.stderr - ) + traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: data = self._check_once if call_once else self._checks @@ -246,9 +240,7 @@ async def get_prefix(self, message: Message) -> list[str] | str: ) if not ret: - raise ValueError( - "Iterable command_prefix must contain at least one prefix" - ) + raise ValueError("Iterable command_prefix must contain at least one prefix") return ret @@ -306,8 +298,7 @@ class be provided, it must be similar enough to :class:`.Context`\'s except TypeError: if not isinstance(prefix, list): raise TypeError( - "get_prefix must return either a string or a list of string, " - f"not {prefix.__class__.__name__}" + f"get_prefix must return either a string or a list of string, not {prefix.__class__.__name__}" ) # It's possible a bad command_prefix got us here. diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 58bdc88d94..dd6a7c51a4 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -151,9 +151,7 @@ def __init__( self.current_parameter: inspect.Parameter | None = current_parameter self._state: ConnectionState = self.message._state - async def invoke( - self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs - ) -> T: + async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: r"""|coro| Calls a command with the arguments given. @@ -405,7 +403,5 @@ async def reply(self, content: str | None = None, **kwargs: Any) -> Message: return await self.message.reply(content, **kwargs) @discord.utils.copy_doc(Message.forward_to) - async def forward_to( - self, channel: discord.abc.Messageable, **kwargs: Any - ) -> Message: + async def forward_to(self, channel: discord.abc.Messageable, **kwargs: Any) -> Message: return await self.message.forward_to(channel, **kwargs) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index e60ac89b34..e9352e4fce 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -157,9 +157,7 @@ class ObjectConverter(IDConverter[discord.Object]): """ async def convert(self, ctx: Context, argument: str) -> discord.Object: - match = self._get_id_match(argument) or re.match( - r"<(?:@[!&]?|#)([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<(?:@[!&]?|#)([0-9]{15,20})>$", argument) if match is None: raise ObjectNotFound(argument) @@ -196,9 +194,7 @@ async def query_member_named(self, guild, argument): if len(argument) > 5 and argument[-5] == "#": username, _, discriminator = argument.rpartition("#") members = await guild.query_members(username, limit=100, cache=cache) - return discord.utils.get( - members, name=username, discriminator=discriminator - ) + return discord.utils.get(members, name=username, discriminator=discriminator) members = await guild.query_members(argument, limit=100, cache=cache) return discord.utils.find( lambda m: argument in (m.nick, m.name, m.global_name), @@ -228,9 +224,7 @@ async def query_member_by_id(self, bot, guild, user_id): async def convert(self, ctx: Context, argument: str) -> discord.Member: bot = ctx.bot - match = self._get_id_match(argument) or re.match( - r"<@!?([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) guild = ctx.guild result = None user_id = None @@ -285,9 +279,7 @@ class UserConverter(IDConverter[discord.User]): """ async def convert(self, ctx: Context, argument: str) -> discord.User: - match = self._get_id_match(argument) or re.match( - r"<@!?([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) result = None state = ctx._state @@ -343,9 +335,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _get_id_matches(ctx, argument): - id_regex = re.compile( - r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$" - ) + id_regex = re.compile(r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$") link_regex = re.compile( r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/" r"(?P[0-9]{15,20}|@me)" @@ -406,9 +396,7 @@ class MessageConverter(IDConverter[discord.Message]): """ async def convert(self, ctx: Context, argument: str) -> discord.Message: - guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches( - ctx, argument - ) + guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) message = ctx.bot._connection._get_message(message_id) if message: return message @@ -439,19 +427,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: - return self._resolve_channel( - ctx, argument, "channels", discord.abc.GuildChannel - ) + return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel) @staticmethod - def _resolve_channel( - ctx: Context, argument: str, attribute: str, type: type[CT] - ) -> CT: + def _resolve_channel(ctx: Context, argument: str, attribute: str, type: type[CT]) -> CT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match( - r"<#([0-9]{15,20})>$", argument - ) + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild @@ -479,12 +461,8 @@ def check(c): return result @staticmethod - def _resolve_thread( - ctx: Context, argument: str, attribute: str, type: type[TT] - ) -> TT: - match = IDConverter._get_id_match(argument) or re.match( - r"<#([0-9]{15,20})>$", argument - ) + def _resolve_thread(ctx: Context, argument: str, attribute: str, type: type[TT]) -> TT: + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild @@ -521,9 +499,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "text_channels", discord.TextChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel) class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): @@ -543,9 +519,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "voice_channels", discord.VoiceChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel) class StageChannelConverter(IDConverter[discord.StageChannel]): @@ -564,9 +538,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "stage_channels", discord.StageChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel) class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): @@ -586,9 +558,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "categories", discord.CategoryChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel) class ForumChannelConverter(IDConverter[discord.ForumChannel]): @@ -607,9 +577,7 @@ class ForumChannelConverter(IDConverter[discord.ForumChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.ForumChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "forum_channels", discord.ForumChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "forum_channels", discord.ForumChannel) class ThreadConverter(IDConverter[discord.Thread]): @@ -627,9 +595,7 @@ class ThreadConverter(IDConverter[discord.Thread]): """ async def convert(self, ctx: Context, argument: str) -> discord.Thread: - return GuildChannelConverter._resolve_thread( - ctx, argument, "threads", discord.Thread - ) + return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread) class ColourConverter(Converter[discord.Colour]): @@ -658,9 +624,7 @@ class ColourConverter(Converter[discord.Colour]): Added support for ``rgb`` function and 3-digit hex shortcuts """ - RGB_REGEX = re.compile( - r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)" - ) + RGB_REGEX = re.compile(r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)") def parse_hex_number(self, argument): arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument @@ -741,9 +705,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Role: if not guild: raise NoPrivateMessage() - match = self._get_id_match(argument) or re.match( - r"<@&([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument) if match: result = guild.get_role(int(match.group(1))) else: @@ -822,9 +784,7 @@ class EmojiConverter(IDConverter[discord.GuildEmoji]): """ async def convert(self, ctx: Context, argument: str) -> discord.GuildEmoji: - match = self._get_id_match(argument) or re.match( - r"$", argument - ) + match = self._get_id_match(argument) or re.match(r"$", argument) result = None bot = ctx.bot guild = ctx.guild @@ -953,27 +913,17 @@ async def convert(self, ctx: Context, argument: str) -> str: if ctx.guild: def resolve_member(id: int) -> str: - m = ( - None if msg is None else _utils_get(msg.mentions, id=id) - ) or ctx.guild.get_member(id) - return ( - f"@{m.display_name if self.use_nicknames else m.name}" - if m - else "@deleted-user" - ) + m = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.guild.get_member(id) + return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: - r = ( - None if msg is None else _utils_get(msg.mentions, id=id) - ) or ctx.guild.get_role(id) + r = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.guild.get_role(id) return f"@{r.name}" if r else "@deleted-role" else: def resolve_member(id: int) -> str: - m = ( - None if msg is None else _utils_get(msg.mentions, id=id) - ) or ctx.bot.get_user(id) + m = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.bot.get_user(id) return f"@{m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: @@ -1054,11 +1004,7 @@ def __class_getitem__(cls, params: tuple[T] | T) -> Greedy[T]: origin = getattr(converter, "__origin__", None) args = getattr(converter, "__args__", ()) - if not ( - callable(converter) - or isinstance(converter, Converter) - or origin is not None - ): + if not (callable(converter) or isinstance(converter, Converter) or origin is not None): raise TypeError("Greedy[...] expects a type or a Converter instance.") if converter in (str, type(None)) or origin is Greedy: @@ -1121,9 +1067,7 @@ def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool: } -async def _actual_conversion( - ctx: Context, converter, argument: str, param: inspect.Parameter -): +async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): if converter is bool: return _convert_to_bool(argument) @@ -1132,9 +1076,7 @@ async def _actual_conversion( except AttributeError: pass else: - if module is not None and ( - module.startswith("discord.") and not module.endswith("converter") - ): + if module is not None and (module.startswith("discord.") and not module.endswith("converter")): converter = CONVERTER_MAPPING.get(converter, converter) try: @@ -1160,14 +1102,10 @@ async def _actual_conversion( except AttributeError: name = converter.__class__.__name__ - raise BadArgument( - f'Converting to "{name}" failed for parameter "{param.name}".' - ) from exc + raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc -async def run_converters( - ctx: Context, converter, argument: str | None, param: inspect.Parameter -): +async def run_converters(ctx: Context, converter, argument: str | None, param: inspect.Parameter): """|coro| Runs converters for a given converter, argument, and parameter. diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 6e58d37f7a..94aff45c77 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -72,8 +72,7 @@ def get_key(self, msg: Message) -> Any: elif self is BucketType.category: return ( msg.channel.category.id - if isinstance(msg.channel, discord.abc.GuildChannel) - and msg.channel.category + if isinstance(msg.channel, discord.abc.GuildChannel) and msg.channel.category else msg.channel.id ) elif self is BucketType.role: @@ -198,10 +197,7 @@ def copy(self) -> Cooldown: return Cooldown(self.rate, self.per) def __repr__(self) -> str: - return ( - f"" - ) + return f"" class CooldownMapping: @@ -264,17 +260,13 @@ def get_bucket(self, message: Message, current: float | None = None) -> Cooldown return bucket - def update_rate_limit( - self, message: Message, current: float | None = None - ) -> float | None: + def update_rate_limit(self, message: Message, current: float | None = None) -> float | None: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) class DynamicCooldownMapping(CooldownMapping): - def __init__( - self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any] - ) -> None: + def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None: super().__init__(None, type) self._factory: Callable[[Message], Cooldown] = factory @@ -364,17 +356,13 @@ def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: raise ValueError("max_concurrency 'number' cannot be less than 1") if not isinstance(per, BucketType): - raise TypeError( - f"max_concurrency 'per' must be of type BucketType not {type(per)!r}" - ) + raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}") def copy(self: MC) -> MC: return self.__class__(self.number, per=self.per, wait=self.wait) def __repr__(self) -> str: - return ( - f"" - ) + return f"" def get_key(self, message: Message) -> Any: return self.per.get_key(message) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 24f2973724..ac2c7a1073 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -134,9 +134,7 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: return function -def get_signature_parameters( - function: Callable[..., Any], globalns: dict[str, Any] -) -> dict[str, inspect.Parameter]: +def get_signature_parameters(function: Callable[..., Any], globalns: dict[str, Any]) -> dict[str, inspect.Parameter]: signature = inspect.signature(function) params = {} cache: dict[str, Any] = {} @@ -321,10 +319,7 @@ def __new__(cls: type[CommandT], *args: Any, **kwargs: Any) -> CommandT: def __init__( self, - func: ( - Callable[Concatenate[CogT, ContextT, P], Coro[T]] - | Callable[Concatenate[ContextT, P], Coro[T]] - ), + func: (Callable[Concatenate[CogT, ContextT, P], Coro[T]] | Callable[Concatenate[ContextT, P], Coro[T]]), **kwargs: Any, ): if not asyncio.iscoroutinefunction(func): @@ -355,9 +350,7 @@ def __init__( self.extras: dict[str, Any] = kwargs.get("extras", {}) if not isinstance(self.aliases, (list, tuple)): - raise TypeError( - "Aliases of a command must be a list or a tuple of strings." - ) + raise TypeError("Aliases of a command must be a list or a tuple of strings.") self.description: str = inspect.cleandoc(kwargs.get("description", "")) self.hidden: bool = kwargs.get("hidden", False) @@ -380,9 +373,7 @@ def __init__( elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError( - "Cooldown must be a an instance of CooldownMapping or None." - ) + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") self._buckets: CooldownMapping = buckets try: @@ -420,19 +411,13 @@ def __init__( @property def callback( self, - ) -> ( - Callable[Concatenate[CogT, Context, P], Coro[T]] - | Callable[Concatenate[Context, P], Coro[T]] - ): + ) -> Callable[Concatenate[CogT, Context, P], Coro[T]] | Callable[Concatenate[Context, P], Coro[T]]: return self._callback @callback.setter def callback( self, - function: ( - Callable[Concatenate[CogT, Context, P], Coro[T]] - | Callable[Concatenate[Context, P], Coro[T]] - ), + function: (Callable[Concatenate[CogT, Context, P], Coro[T]] | Callable[Concatenate[Context, P], Coro[T]]), ) -> None: self._callback = function unwrap = unwrap_function(function) @@ -575,9 +560,7 @@ async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: required = default is param.empty converter = get_converter(param) - consume_rest_is_special = ( - param.kind == param.KEYWORD_ONLY and not self.rest_is_raw - ) + consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw view = ctx.view view.skip_ws() @@ -585,13 +568,9 @@ async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: # it undoes the view ready for the next parameter to use instead if isinstance(converter, Greedy): if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): - return await self._transform_greedy_pos( - ctx, param, required, converter.converter - ) + return await self._transform_greedy_pos(ctx, param, required, converter.converter) elif param.kind == param.VAR_POSITIONAL: - return await self._transform_greedy_var_pos( - ctx, param, converter.converter - ) + return await self._transform_greedy_var_pos(ctx, param, converter.converter) else: # if we're here, then it's a KEYWORD_ONLY param type # since this is mostly useless, we'll helpfully transform Greedy[X] @@ -604,10 +583,7 @@ async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: if required: if self._is_typing_optional(param.annotation): return None - if ( - hasattr(converter, "__commands_is_flag__") - and converter._can_be_constructible() - ): + if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible(): return await converter._construct_default(ctx) raise MissingRequiredArgument(param) return default @@ -651,9 +627,7 @@ async def _transform_greedy_pos( return param.default return result - async def _transform_greedy_var_pos( - self, ctx: Context, param: inspect.Parameter, converter: Any - ) -> Any: + async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: view = ctx.view previous = view.index try: @@ -767,17 +741,13 @@ async def _parse_arguments(self, ctx: Context) -> None: try: next(iterator) except StopIteration: - raise discord.ClientException( - f'Callback for {self.name} command is missing "self" parameter.' - ) + raise discord.ClientException(f'Callback for {self.name} command is missing "self" parameter.') # next we have the 'ctx' as the next parameter try: next(iterator) except StopIteration: - raise discord.ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) + raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') for name, param in iterator: ctx.current_parameter = param @@ -804,9 +774,7 @@ async def _parse_arguments(self, ctx: Context) -> None: break if not self.ignore_extra and not view.eof: - raise TooManyArguments( - f"Too many arguments passed to {self.qualified_name}" - ) + raise TooManyArguments(f"Too many arguments passed to {self.qualified_name}") async def call_before_hooks(self, ctx: Context) -> None: # now that we're done preparing we can call the pre-command hooks @@ -866,9 +834,7 @@ async def prepare(self, ctx: Context) -> None: ctx.command = self if not await self.can_run(ctx): - raise CheckFailure( - f"The check functions for command {self.qualified_name} failed." - ) + raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -1083,11 +1049,8 @@ def short_doc(self) -> str: def _is_typing_optional(self, annotation: T | T | None) -> TypeGuard[T | None]: return ( - getattr(annotation, "__origin__", None) is Union - or type(annotation) is getattr(types, "UnionType", Union) - ) and type( - None - ) in annotation.__args__ # type: ignore + getattr(annotation, "__origin__", None) is Union or type(annotation) is getattr(types, "UnionType", Union) + ) and type(None) in annotation.__args__ # type: ignore @property def signature(self) -> str: @@ -1117,24 +1080,13 @@ def signature(self) -> str: origin = getattr(annotation, "__origin__", None) if origin is Literal: - name = "|".join( - f'"{v}"' if isinstance(v, str) else str(v) - for v in annotation.__args__ - ) + name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case, and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = ( - param.default - if isinstance(param.default, str) - else param.default is not None - ) + should_print = param.default if isinstance(param.default, str) else param.default is not None if should_print: - result.append( - f"[{name}={param.default}]" - if not greedy - else f"[{name}={param.default}]..." - ) + result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...") continue else: result.append(f"[{name}]") @@ -1188,10 +1140,7 @@ async def can_run(self, ctx: Context) -> bool: try: if not await ctx.bot.can_run(ctx): - raise CheckFailure( - "The global check functions for command" - f" {self.qualified_name} failed." - ) + raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") cog = self.cog if cog is not None: @@ -1229,9 +1178,7 @@ class GroupMixin(Generic[CogT]): def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get("case_insensitive", False) - self.prefixed_commands: dict[str, Command[CogT, Any, Any]] = ( - _CaseInsensitiveDict() if case_insensitive else {} - ) + self.prefixed_commands: dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @@ -1392,12 +1339,7 @@ def command( *args: Any, **kwargs: Any, ) -> Callable[ - [ - ( - Callable[Concatenate[CogT, ContextT, P], Coro[T]] - | Callable[Concatenate[ContextT, P], Coro[T]] - ) - ], + [(Callable[Concatenate[CogT, ContextT, P], Coro[T]] | Callable[Concatenate[ContextT, P], Coro[T]])], Command[CogT, P, T], ]: ... @@ -1442,12 +1384,7 @@ def group( *args: Any, **kwargs: Any, ) -> Callable[ - [ - ( - Callable[Concatenate[CogT, ContextT, P], Coro[T]] - | Callable[Concatenate[ContextT, P], Coro[T]] - ) - ], + [(Callable[Concatenate[CogT, ContextT, P], Coro[T]] | Callable[Concatenate[ContextT, P], Coro[T]])], Group[CogT, P, T], ]: ... @@ -1606,14 +1543,7 @@ def command( cls: type[Command[CogT, P, T]] = ..., **attrs: Any, ) -> Callable[ - [ - ( - Callable[Concatenate[CogT, ContextT, P]] - | Coro[T] - | Callable[Concatenate[ContextT, P]] - | Coro[T] - ) - ], + [(Callable[Concatenate[CogT, ContextT, P]] | Coro[T] | Callable[Concatenate[ContextT, P]] | Coro[T])], Command[CogT, P, T], ]: ... @@ -1624,12 +1554,7 @@ def command( cls: type[Command[CogT, P, T]] = ..., **attrs: Any, ) -> Callable[ - [ - ( - Callable[Concatenate[CogT, ContextT, P], Coro[T]] - | Callable[Concatenate[ContextT, P], Coro[T]] - ) - ], + [(Callable[Concatenate[CogT, ContextT, P], Coro[T]] | Callable[Concatenate[ContextT, P], Coro[T]])], Command[CogT, P, T], ]: ... @@ -1640,12 +1565,7 @@ def command( cls: type[CommandT] = ..., **attrs: Any, ) -> Callable[ - [ - ( - Callable[Concatenate[CogT, ContextT, P], Coro[Any]] - | Callable[Concatenate[ContextT, P], Coro[Any]] - ) - ], + [(Callable[Concatenate[CogT, ContextT, P], Coro[Any]] | Callable[Concatenate[ContextT, P], Coro[Any]])], CommandT, ]: ... @@ -1653,12 +1573,7 @@ def command( def command( name: str | utils.Undefined = MISSING, cls: type[CommandT] | utils.Undefined = MISSING, **attrs: Any ) -> Callable[ - [ - ( - Callable[Concatenate[ContextT, P], Coro[Any]] - | Callable[Concatenate[CogT, ContextT, P], Coro[T]] - ) - ], + [(Callable[Concatenate[ContextT, P], Coro[Any]] | Callable[Concatenate[CogT, ContextT, P], Coro[T]])], Command[CogT, P, T] | CommandT, ]: """A decorator that transforms a function into a :class:`.Command` @@ -1694,10 +1609,7 @@ def command( cls = Command # type: ignore def decorator( - func: ( - Callable[Concatenate[ContextT, P], Coro[Any]] - | Callable[Concatenate[CogT, ContextT, P], Coro[Any]] - ), + func: (Callable[Concatenate[ContextT, P], Coro[Any]] | Callable[Concatenate[CogT, ContextT, P], Coro[Any]]), ) -> CommandT: if isinstance(func, Command): raise TypeError("Callback is already a command.") @@ -1712,12 +1624,7 @@ def group( cls: type[Group[CogT, P, T]] = ..., **attrs: Any, ) -> Callable[ - [ - ( - Callable[Concatenate[CogT, ContextT, P], Coro[T]] - | Callable[Concatenate[ContextT, P], Coro[T]] - ) - ], + [(Callable[Concatenate[CogT, ContextT, P], Coro[T]] | Callable[Concatenate[ContextT, P], Coro[T]])], Group[CogT, P, T], ]: ... @@ -1728,12 +1635,7 @@ def group( cls: type[GroupT] = ..., **attrs: Any, ) -> Callable[ - [ - ( - Callable[Concatenate[CogT, ContextT, P], Coro[Any]] - | Callable[Concatenate[ContextT, P], Coro[Any]] - ) - ], + [(Callable[Concatenate[CogT, ContextT, P], Coro[Any]] | Callable[Concatenate[ContextT, P], Coro[Any]])], GroupT, ]: ... @@ -1743,12 +1645,7 @@ def group( cls: type[GroupT] | utils.Undefined = MISSING, **attrs: Any, ) -> Callable[ - [ - ( - Callable[Concatenate[ContextT, P], Coro[Any]] - | Callable[Concatenate[CogT, ContextT, P], Coro[T]] - ) - ], + [(Callable[Concatenate[ContextT, P], Coro[Any]] | Callable[Concatenate[CogT, ContextT, P], Coro[T]])], Group[CogT, P, T] | GroupT, ]: """A decorator that transforms a function into a :class:`.Group`. @@ -1786,10 +1683,12 @@ def check(predicate: Check) -> Callable[[T], T]: def owner_or_permissions(**perms): original = commands.has_permissions(**perms).predicate + async def extended_check(ctx): if ctx.guild is None: return False return ctx.guild.owner_id == ctx.author.id or await original(ctx) + return commands.check(extended_check) .. note:: @@ -1810,10 +1709,11 @@ async def extended_check(ctx): def check_if_it_is_me(ctx): return ctx.message.author.id == 85309593344815104 + @bot.command() @commands.check(check_if_it_is_me) async def only_for_me(ctx): - await ctx.send('I know you!') + await ctx.send("I know you!") Transforming common checks into its own decorator: @@ -1822,12 +1722,14 @@ async def only_for_me(ctx): def is_me(): def predicate(ctx): return ctx.message.author.id == 85309593344815104 + return commands.check(predicate) + @bot.command() @is_me() async def only_me(ctx): - await ctx.send('Only you!') + await ctx.send("Only you!") Parameters ----------- @@ -1895,12 +1797,14 @@ def check_any(*checks: Check) -> Callable[[T], T]: def is_guild_owner(): def predicate(ctx): return ctx.guild is not None and ctx.guild.owner_id == ctx.author.id + return commands.check(predicate) + @bot.command() @commands.check_any(commands.is_owner(), is_guild_owner()) async def only_for_owners(ctx): - await ctx.send('Hello mister owner!') + await ctx.send("Hello mister owner!") """ unwrapped = [] @@ -1908,9 +1812,7 @@ async def only_for_owners(ctx): try: pred = wrapped.predicate except AttributeError: - raise TypeError( - f"{wrapped!r} must be wrapped by commands.check decorator" - ) from None + raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None else: unwrapped.append(pred) @@ -2000,9 +1902,9 @@ def has_any_role(*items: int | str) -> Callable[[T], T]: .. code-block:: python3 @bot.command() - @commands.has_any_role('Library Devs', 'Moderators', 492212595072434186) + @commands.has_any_role("Library Devs", "Moderators", 492212595072434186) async def cool(ctx): - await ctx.send('You are cool indeed') + await ctx.send("You are cool indeed") """ def predicate(ctx): @@ -2012,12 +1914,7 @@ def predicate(ctx): # ctx.guild is None doesn't narrow ctx.author to Member getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore if any( - ( - getter(id=item) is not None - if isinstance(item, int) - else getter(name=item) is not None - ) - for item in items + (getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None) for item in items ): return True raise MissingAnyRole(list(items)) @@ -2076,12 +1973,7 @@ def predicate(ctx): me = ctx.me getter = functools.partial(discord.utils.get, me.roles) if any( - ( - getter(id=item) is not None - if isinstance(item, int) - else getter(name=item) is not None - ) - for item in items + (getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None) for item in items ): return True raise BotMissingAnyRole(list(items)) @@ -2117,7 +2009,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: @bot.command() @commands.has_permissions(manage_messages=True) async def test(ctx): - await ctx.send('You can manage messages.') + await ctx.send("You can manage messages.") """ @@ -2130,9 +2022,7 @@ def predicate(ctx: Context) -> bool: return True permissions = ctx.channel.permissions_for(ctx.author) # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2165,9 +2055,7 @@ def predicate(ctx: Context) -> bool: else: permissions = ctx.channel.permissions_for(me) # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2196,9 +2084,7 @@ def predicate(ctx: Context) -> bool: raise NoPrivateMessage permissions = ctx.author.guild_permissions # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2224,9 +2110,7 @@ def predicate(ctx: Context) -> bool: raise NoPrivateMessage permissions = ctx.me.guild_permissions # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2304,9 +2188,7 @@ def is_nsfw() -> Callable[[T], T]: def pred(ctx: Context) -> bool: ch = ctx.channel - if ctx.guild is None or ( - isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw() - ): + if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): return True raise NSFWChannelRequired(ch) # type: ignore @@ -2399,9 +2281,7 @@ def decorator(func: Command | CoroFunc) -> Command | CoroFunc: return decorator # type: ignore -def max_concurrency( - number: int, per: BucketType = BucketType.default, *, wait: bool = False -) -> Callable[[T], T]: +def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a command This enables you to only allow a certain number of command invocations at the same time, @@ -2450,27 +2330,29 @@ def before_invoke(coro) -> Callable[[T], T]: .. code-block:: python3 async def record_usage(ctx): - print(ctx.author, 'used', ctx.command, 'at', ctx.message.created_at) + print(ctx.author, "used", ctx.command, "at", ctx.message.created_at) + @bot.command() @commands.before_invoke(record_usage) - async def who(ctx): # Output: used who at