From c268012cfa9c255c448c4a67b346175e39a84375 Mon Sep 17 00:00:00 2001 From: AbdelbassitAb Date: Tue, 21 Jan 2025 17:38:08 +0100 Subject: [PATCH 1/2] change the full outpot view from link to button --- bot/exts/utils/snekbox/_cog.py | 151 +++++++++++++++++++++++---------- 1 file changed, 108 insertions(+), 43 deletions(-) diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py index 39f61c6e26..1e38dbf20d 100644 --- a/bot/exts/utils/snekbox/_cog.py +++ b/bot/exts/utils/snekbox/_cog.py @@ -8,8 +8,19 @@ from textwrap import dedent from typing import Literal, NamedTuple, TYPE_CHECKING, get_args -from discord import AllowedMentions, HTTPException, Interaction, Message, NotFound, Reaction, User, enums, ui +from discord import ( + AllowedMentions, + HTTPException, + Interaction, + Message, + NotFound, + Reaction, + User, + enums, + ui, +) from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only +from discord.ui import Button from pydis_core.utils import interactions, paste_service from pydis_core.utils.paste_service import PasteFile, send_to_paste_service from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX @@ -82,13 +93,21 @@ def print_last_line(): # The Snekbox commands' whitelists and blacklists. NO_SNEKBOX_CHANNELS = (Channels.python_general,) NO_SNEKBOX_CATEGORIES = () -SNEKBOX_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) +SNEKBOX_ROLES = ( + Roles.helpers, + Roles.moderators, + Roles.admins, + Roles.owners, + Roles.python_community, + Roles.partners, +) REDO_EMOJI = "\U0001f501" # :repeat: REDO_TIMEOUT = 30 SupportedPythonVersions = Literal["3.12", "3.13", "3.13t"] + class FilteredFiles(NamedTuple): allowed: list[FileAttachment] blocked: list[FileAttachment] @@ -119,7 +138,9 @@ async def convert(cls, ctx: Context, code: str) -> list[str]: code, block, lang, delim = match.group("code", "block", "lang", "delim") codeblocks = [dedent(code)] if block: - info = (f"'{lang}' highlighted" if lang else "plain") + " code block" + info = ( + f"'{lang}' highlighted" if lang else "plain" + ) + " code block" else: info = f"{delim}-enclosed inline code" else: @@ -142,7 +163,9 @@ def __init__( job: EvalJob, ) -> None: self.version_to_run = version_to_run - super().__init__(label=f"Run in {self.version_to_run}", style=enums.ButtonStyle.primary) + super().__init__( + label=f"Run in {self.version_to_run}", style=enums.ButtonStyle.primary + ) self.snekbox_cog = snekbox_cog self.ctx = ctx @@ -163,7 +186,9 @@ async def callback(self, interaction: Interaction) -> None: # The log arg on send_job will stop the actual job from running. await interaction.message.delete() - await self.snekbox_cog.run_job(self.ctx, self.job.as_version(self.version_to_run)) + await self.snekbox_cog.run_job( + self.ctx, self.job.as_version(self.version_to_run) + ) class Snekbox(Cog): @@ -197,7 +222,9 @@ async def post_job(self, job: EvalJob) -> EvalResult: """Send a POST request to the Snekbox API to evaluate code and return the results.""" data = job.to_dict() - async with self.bot.http_session.post(URLs.snekbox_eval_api, json=data, raise_for_status=True) as resp: + async with self.bot.http_session.post( + URLs.snekbox_eval_api, json=data, raise_for_status=True + ) as resp: return EvalResult.from_dict(await resp.json()) async def upload_output(self, output: str) -> str | None: @@ -257,7 +284,10 @@ async def format_output( if ESCAPE_REGEX.findall(output): paste_link = await self.upload_output(original_output) - return "Code block escape attempt detected; will not output result", paste_link + return ( + "Code block escape attempt detected; will not output result", + paste_link, + ) truncated = False lines = output.splitlines() @@ -269,12 +299,14 @@ async def format_output( if len(lines) > max_lines: truncated = True if len(lines) == max_lines + 1: - lines = lines[:max_lines - 1] + lines = lines[: max_lines - 1] else: lines = lines[:max_lines] output = "\n".join(lines) if len(output) >= max_chars: - output = f"{output[:max_chars]}\n... (truncated - too long, too many lines)" + output = ( + f"{output[:max_chars]}\n... (truncated - too long, too many lines)" + ) else: output = f"{output}\n... (truncated - too many lines)" elif len(output) >= max_chars: @@ -292,7 +324,9 @@ async def format_output( return output, paste_link - async def format_file_text(self, text_files: list[FileAttachment], output: str) -> str: + async def format_file_text( + self, text_files: list[FileAttachment], output: str + ) -> str: # Inline until budget, then upload to paste service # Budget is shared with stdout, so subtract what we've already used budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) @@ -311,7 +345,7 @@ async def format_file_text(self, text_files: list[FileAttachment], output: str) budget_lines, budget_chars, line_nums=False, - output_default="[Empty]" + output_default="[Empty]", ) # With any link, use it (don't use budget) if link_text: @@ -325,24 +359,30 @@ async def format_file_text(self, text_files: list[FileAttachment], output: str) def format_blocked_extensions(self, blocked: list[FileAttachment]) -> str: # Sort by length and then lexicographically to fit as many as possible before truncating. - blocked_sorted = sorted(set(f.suffix for f in blocked), key=lambda e: (len(e), e)) + blocked_sorted = sorted( + set(f.suffix for f in blocked), key=lambda e: (len(e), e) + ) # Only no extension if len(blocked_sorted) == 1 and blocked_sorted[0] == "": blocked_msg = "Files with no extension can't be uploaded." # Both elif "" in blocked_sorted: - blocked_str = self.join_blocked_extensions(ext for ext in blocked_sorted if ext) - blocked_msg = ( - f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" + blocked_str = self.join_blocked_extensions( + ext for ext in blocked_sorted if ext ) + blocked_msg = f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" else: blocked_str = self.join_blocked_extensions(blocked_sorted) - blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + blocked_msg = ( + f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + ) return f"\n{Emojis.failed_file} {blocked_msg}" - def join_blocked_extensions(self, extensions: Iterable[str], delimiter: str = ", ", char_limit: int = 100) -> str: + def join_blocked_extensions( + self, extensions: Iterable[str], delimiter: str = ", ", char_limit: int = 100 + ) -> str: joined = "" for ext in extensions: cur_delimiter = delimiter if joined else "" @@ -354,8 +394,9 @@ def join_blocked_extensions(self, extensions: Iterable[str], delimiter: str = ", return joined - - def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str]) -> FilteredFiles: + def _filter_files( + self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str] + ) -> FilteredFiles: """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" # Filter files into allowed and blocked blocked = [] @@ -370,7 +411,7 @@ def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts: blocked_str = ", ".join(f.suffix for f in blocked) log.info( f"User '{ctx.author}' ({ctx.author.id}) uploaded blacklisted file(s) in eval: {blocked_str}", - extra={"attachment_list": [f.filename for f in files]} + extra={"attachment_list": [f.filename for f in files]}, ) return FilteredFiles(allowed, blocked) @@ -399,16 +440,18 @@ async def send_job(self, ctx: Context, job: EvalJob) -> Message: # This is done to make sure the last line of output contains the error # and the error is not manually printed by the author with a syntax error. - if result.stdout.rstrip().endswith("EOFError: EOF when reading a line") and result.returncode == 1: - msg += "\n:warning: Note: `input` is not supported by the bot :warning:\n" + if ( + result.stdout.rstrip().endswith("EOFError: EOF when reading a line") + and result.returncode == 1 + ): + msg += ( + "\n:warning: Note: `input` is not supported by the bot :warning:\n" + ) # Skip output if it's empty and there are file uploads if result.stdout or not result.has_files: msg += f"\n```ansi\n{output}\n```" - if paste_link: - msg += f"\nFull output: {paste_link}" - # Additional files error message after output if files_error := result.files_error_message: msg += f"\n{files_error}" @@ -423,9 +466,13 @@ async def send_job(self, ctx: Context, job: EvalJob) -> Message: failed_files = [FileAttachment(name, b"") for name in result.failed_files] total_files = result.files + failed_files if filter_cog: - block_output, blocked_exts = await filter_cog.filter_snekbox_output(msg, total_files, ctx.message) + block_output, blocked_exts = await filter_cog.filter_snekbox_output( + msg, total_files, ctx.message + ) if block_output: - return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + return await ctx.send( + "Attempt to circumvent filter detected. Moderator team has been alerted." + ) # Filter file extensions allowed, blocked = self._filter_files(ctx, result.files, blocked_exts) @@ -435,8 +482,18 @@ async def send_job(self, ctx: Context, job: EvalJob) -> Message: # Upload remaining non-text files files = [f.to_file() for f in allowed if f not in text_files] - allowed_mentions = AllowedMentions(everyone=False, roles=False, users=[ctx.author]) + allowed_mentions = AllowedMentions( + everyone=False, roles=False, users=[ctx.author] + ) view = self.build_python_version_switcher_view(job.version, ctx, job) + if paste_link: + # Create a button + button = Button( + label="View Full Output", # Button text + url=paste_link, # The URL the button links to + ) + + view.add_item(button) if ctx.message.channel == ctx.channel: # Don't fail if the command invoking message was deleted. @@ -446,15 +503,19 @@ async def send_job(self, ctx: Context, job: EvalJob) -> Message: allowed_mentions=allowed_mentions, view=view, files=files, - reference=message + reference=message, ) else: # The command was redirected so a reply wont work, send a normal message with a mention. msg = f"{ctx.author.mention} {msg}" - response = await ctx.send(msg, allowed_mentions=allowed_mentions, view=view, files=files) + response = await ctx.send( + msg, allowed_mentions=allowed_mentions, view=view, files=files + ) view.message = response - log.info(f"{ctx.author}'s {job.name} job had a return code of {result.returncode}") + log.info( + f"{ctx.author}'s {job.name} job had a return code of {result.returncode}" + ) return response async def continue_job( @@ -472,15 +533,11 @@ async def continue_job( with contextlib.suppress(NotFound): try: _, new_message = await self.bot.wait_for( - "message_edit", - check=_predicate_message_edit, - timeout=REDO_TIMEOUT + "message_edit", check=_predicate_message_edit, timeout=REDO_TIMEOUT ) await ctx.message.add_reaction(REDO_EMOJI) await self.bot.wait_for( - "reaction_add", - check=_predicate_emoji_reaction, - timeout=10 + "reaction_add", check=_predicate_emoji_reaction, timeout=10 ) # Ensure the response that's about to be edited is still the most recent. @@ -576,14 +633,14 @@ async def run_job( bypass_roles=SNEKBOX_ROLES, categories=NO_SNEKBOX_CATEGORIES, channels=NO_SNEKBOX_CHANNELS, - ping_user=False + ping_user=False, ) async def eval_command( self, ctx: Context, python_version: SupportedPythonVersions | None, *, - code: CodeblockConverter + code: CodeblockConverter, ) -> None: """ Run Python code and get the results. @@ -608,21 +665,25 @@ async def eval_command( job = EvalJob.from_code("\n".join(code)).as_version(python_version) await self.run_job(ctx, job) - @command(name="timeit", aliases=("ti",), usage="[python_version] [setup_code] ") + @command( + name="timeit", + aliases=("ti",), + usage="[python_version] [setup_code] ", + ) @guild_only() @redirect_output( destination_channel=Channels.bot_commands, bypass_roles=SNEKBOX_ROLES, categories=NO_SNEKBOX_CATEGORIES, channels=NO_SNEKBOX_CHANNELS, - ping_user=False + ping_user=False, ) async def timeit_command( self, ctx: Context, python_version: SupportedPythonVersions | None, *, - code: CodeblockConverter + code: CodeblockConverter, ) -> None: """ Profile Python Code to find execution time. @@ -654,4 +715,8 @@ def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" - return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI + return ( + reaction.message.id == ctx.message.id + and user.id == ctx.author.id + and str(reaction) == REDO_EMOJI + ) From 4bcbd50bcbb36bba812ba2bfdc555b5c24c5b707 Mon Sep 17 00:00:00 2001 From: AbdelbassitAb Date: Mon, 27 Jan 2025 23:21:26 +0100 Subject: [PATCH 2/2] fix test file for the output button --- tests/bot/exts/utils/snekbox/test_snekbox.py | 269 +++++++++++++------ 1 file changed, 191 insertions(+), 78 deletions(-) diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 9cfd75df8b..fe53a5c895 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -3,7 +3,7 @@ from base64 import b64encode from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch -from discord import AllowedMentions +from discord import AllowedMentions, ui from discord.ext import commands from pydis_core.utils.paste_service import MAX_PASTE_SIZE @@ -12,7 +12,14 @@ from bot.exts.utils import snekbox from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox from bot.exts.utils.snekbox._io import FileAttachment -from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser +from tests.helpers import ( + MockBot, + MockContext, + MockMember, + MockMessage, + MockReaction, + MockUser, +) class SnekboxTests(unittest.IsolatedAsyncioTestCase): @@ -25,12 +32,14 @@ def setUp(self): @staticmethod def code_args(code: str) -> tuple[EvalJob]: """Converts code to a tuple of arguments expected.""" - return EvalJob.from_code(code), + return (EvalJob.from_code(code),) async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() - resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137, "files": []}) + resp.json = AsyncMock( + return_value={"stdout": "Hi", "returncode": 137, "files": []} + ) context_manager = MagicMock() context_manager.__aenter__.return_value = resp @@ -50,9 +59,7 @@ async def test_post_job(self): "executable_path": f"/snekbin/python/{py_version}/bin/python", } self.bot.http_session.post.assert_called_with( - constants.URLs.snekbox_eval_api, - json=expected, - raise_for_status=True + constants.URLs.snekbox_eval_api, json=expected, raise_for_status=True ) resp.json.assert_awaited_once() @@ -70,20 +77,42 @@ async def test_codeblock_converter(self): cases = ( ('print("Hello world!")', 'print("Hello world!")', "non-formatted"), ('`print("Hello world!")`', 'print("Hello world!")', "one line code block"), - ('```\nprint("Hello world!")```', 'print("Hello world!")', "multiline code block"), - ('```py\nprint("Hello world!")```', 'print("Hello world!")', "multiline python code block"), - ('text```print("Hello world!")```text', 'print("Hello world!")', "code block surrounded by text"), - ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', - 'print("Hello world!")\nprint("Hello world!")', "two code blocks with text in-between"), - ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```', - 'print("How\'s it going?")', "code block preceded by inline code"), - ('`print("Hello world!")`\ntext\n`print("Hello world!")`', - 'print("Hello world!")', "one inline code block of two") + ( + '```\nprint("Hello world!")```', + 'print("Hello world!")', + "multiline code block", + ), + ( + '```py\nprint("Hello world!")```', + 'print("Hello world!")', + "multiline python code block", + ), + ( + 'text```print("Hello world!")```text', + 'print("Hello world!")', + "code block surrounded by text", + ), + ( + '```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', + 'print("Hello world!")\nprint("Hello world!")', + "two code blocks with text in-between", + ), + ( + '`print("Hello world!")`\ntext\n```print("How\'s it going?")```', + 'print("How\'s it going?")', + "code block preceded by inline code", + ), + ( + '`print("Hello world!")`\ntext\n`print("Hello world!")`', + 'print("Hello world!")', + "one inline code block of two", + ), ) for case, expected, testname in cases: with self.subTest(msg=f"Extract code from {testname}."): self.assertEqual( - "\n".join(await snekbox.CodeblockConverter.convert(ctx, case)), expected + "\n".join(await snekbox.CodeblockConverter.convert(ctx, case)), + expected, ) def test_prepare_timeit_input(self): @@ -92,21 +121,35 @@ def test_prepare_timeit_input(self): cases = ( (['print("Hello World")'], "", "single block of code"), (["x = 1", "print(x)"], "x = 1", "two blocks of code"), - (["x = 1", "print(x)", 'print("Some other code.")'], "x = 1", "three blocks of code") + ( + ["x = 1", "print(x)", 'print("Some other code.")'], + "x = 1", + "three blocks of code", + ), ) for case, setup_code, test_name in cases: setup = snekbox._cog.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) expected = [*base_args, setup, "\n".join(case[1:] if setup_code else case)] - with self.subTest(msg=f"Test with {test_name} and expected return {expected}"): + with self.subTest( + msg=f"Test with {test_name} and expected return {expected}" + ): self.assertEqual(self.cog.prepare_timeit_input(case), expected) def test_eval_result_message(self): """EvalResult.get_message(), should return message.""" cases = ( ("ERROR", None, ("Your 3.12 eval job has failed", "ERROR", "")), - ("", 128 + snekbox._eval.SIGKILL, ("Your 3.12 eval job timed out or ran out of memory", "", "")), - ("", 255, ("Your 3.12 eval job has failed", "A fatal NsJail error occurred", "")) + ( + "", + 128 + snekbox._eval.SIGKILL, + ("Your 3.12 eval job timed out or ran out of memory", "", ""), + ), + ( + "", + 255, + ("Your 3.12 eval job has failed", "A fatal NsJail error occurred", ""), + ), ) for stdout, returncode, expected in cases: exp_msg, exp_err, exp_files_err = expected @@ -125,21 +168,33 @@ def test_eval_result_message(self): def test_eval_result_files_error_message(self): """EvalResult.files_error_message, should return files error message.""" cases = [ - ([], ["abc"], ( - "1 file upload (abc) failed because its file size exceeds 8 MiB." - )), - ([], ["file1.bin", "f2.bin"], ( - "2 file uploads (file1.bin, f2.bin) failed because each file's size exceeds 8 MiB." - )), - (["a", "b"], ["c"], ( - "1 file upload (c) failed as it exceeded the 2 file limit." - )), - (["a"], ["b", "c"], ( - "2 file uploads (b, c) failed as they exceeded the 2 file limit." - )), + ( + [], + ["abc"], + ("1 file upload (abc) failed because its file size exceeds 8 MiB."), + ), + ( + [], + ["file1.bin", "f2.bin"], + ( + "2 file uploads (file1.bin, f2.bin) failed because each file's size exceeds 8 MiB." + ), + ), + ( + ["a", "b"], + ["c"], + ("1 file upload (c) failed as it exceeded the 2 file limit."), + ), + ( + ["a"], + ["b", "c"], + ("2 file uploads (b, c) failed as they exceeded the 2 file limit."), + ), ] for files, failed_files, expected_msg in cases: - with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg): + with self.subTest( + files=files, failed_files=failed_files, expected_msg=expected_msg + ): result = EvalResult("", 0, files, failed_files) msg = result.files_error_message self.assertIn(expected_msg, msg) @@ -168,7 +223,7 @@ def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( result.get_status_message(EvalJob([], version="3.10")), - "Your 3.10 eval job has completed with return code 127" + "Your 3.10 eval job has completed with return code 127", ) self.assertEqual(result.error_message, "") self.assertEqual(result.files_error_message, "") @@ -179,7 +234,7 @@ def test_eval_result_message_valid_signal(self, mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( result.get_status_message(EvalJob([], version="3.12")), - "Your 3.12 eval job has completed with return code 127 (SIGTEST)" + "Your 3.12 eval job has completed with return code 127 (SIGTEST)", ) def test_eval_result_status_emoji(self): @@ -187,7 +242,7 @@ def test_eval_result_status_emoji(self): cases = ( (" ", -1, ":warning:"), ("Hello world!", 0, ":white_check_mark:"), - ("Invalid beard size", -1, ":x:") + ("Invalid beard size", -1, ":x:"), ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): @@ -204,8 +259,10 @@ async def test_format_output(self): ) too_long_too_many_lines = ( "\n".join( - f"{i:03d} | {line}" for i, line in enumerate(["verylongbeard" * 10] * 15, 1) - )[:1000] + "\n... (truncated - too long, too many lines)" + f"{i:03d} | {line}" + for i, line in enumerate(["verylongbeard" * 10] * 15, 1) + )[:1000] + + "\n... (truncated - too long, too many lines)" ) cases = ( @@ -215,29 +272,38 @@ async def test_format_output(self): ("