From 5c2f54feb5147e2d919e2e8e3ec6fb36f086e5b9 Mon Sep 17 00:00:00 2001 From: Bradley Reynolds Date: Wed, 27 Mar 2024 21:23:59 -0500 Subject: [PATCH 01/16] Set global HTTP session timeout to 30s (#230) Hopefully this helps fix https://vipyrsec.sentry.io/issues/5052320225 Signed-off-by: Bradley Reynolds --- src/bot/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bot/__main__.py b/src/bot/__main__.py index 2cd4468..9e92a9e 100644 --- a/src/bot/__main__.py +++ b/src/bot/__main__.py @@ -27,7 +27,7 @@ def get_prefix(bot_: Bot, message_: discord.Message) -> Callable[[Bot, discord.M async def main() -> None: """Run the bot.""" - async with ClientSession(headers={"Content-Type": "application/json"}, timeout=ClientTimeout(total=10)) as session: + async with ClientSession(headers={"Content-Type": "application/json"}, timeout=ClientTimeout(total=30)) as session: dragonfly_services = DragonflyServices( session=session, base_url=constants.Dragonfly.base_url, From b8493891d2c4fc6df4d38b4ac3de4b2e683e72fa Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Fri, 24 May 2024 00:53:18 -0500 Subject: [PATCH 02/16] Add configuration options --- src/bot/constants.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/bot/constants.py b/src/bot/constants.py index 44312d5..b7e12af 100644 --- a/src/bot/constants.py +++ b/src/bot/constants.py @@ -33,6 +33,18 @@ class _Miscellaneous(EnvConfig): Miscellaneous = _Miscellaneous() + +class _ThreatIntelFeed(EnvConfig, env_prefix="tif_"): + """Threat Intelligence Feed Configuration.""" + + repository: str = "pypi/pypi-observation-reports-private" + interval: int = 10 * 60 # 10 minutes + access_token: str = "" + channel_id: int = 1121471544355455058 + + +ThreatIntelFeed = _ThreatIntelFeed() + FILE_LOGS = Miscellaneous.file_logs DEBUG_MODE = Miscellaneous.debug From 5056d5826069e53a3a0d6662380bb31c10a0ac0a Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Fri, 24 May 2024 00:53:30 -0500 Subject: [PATCH 03/16] Add threat intelligence feed cog --- src/bot/exts/dragonfly/threat_intel_feed.py | 166 ++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 src/bot/exts/dragonfly/threat_intel_feed.py diff --git a/src/bot/exts/dragonfly/threat_intel_feed.py b/src/bot/exts/dragonfly/threat_intel_feed.py new file mode 100644 index 0000000..4f90cea --- /dev/null +++ b/src/bot/exts/dragonfly/threat_intel_feed.py @@ -0,0 +1,166 @@ +"""Threat Intelligence Feed Cog.""" + +import json +import logging +import re +from io import BytesIO +from logging import getLogger +from typing import Any +from zipfile import ZipFile + +import aiohttp +import discord +from discord.ext import commands, tasks + +from bot import constants +from bot.bot import Bot +from bot.dragonfly_services import PackageScanResult + +log = getLogger(__name__) +log.setLevel(logging.INFO) + +_p = re.compile(r"https://inspector.pypi.io/project/(?P\w+)/(?P[\w.]+)/.*") + + +def build_github_link_from_path(path: str) -> str: + """Build a GitHub link to the given path.""" + segments = path.split("/") + path = "/".join(segments[1:]) + + return f"https://github.com/{constants.ThreatIntelFeed.repository}/blob/main/{path}" + + +def parse_package_info_from_inspector_url(inspector_url: str) -> tuple[str, str] | None: + """Return a tuple of package name and version, parsed from the inspector URL. None if it couldn't be parsed.""" + if g := _p.match(inspector_url): + name = g.group("name") + version = g.group("version") + + return name, version + + return None + + +def search(d: dict, key: Any) -> Any | None: # noqa: ANN401 - we can't know the type of the dict ahead of time + """Recursively search for the first occurence of a key in a dict. None if not found.""" + for k, v in d.items(): + if k == key: + return v + + if isinstance(v, dict) and (val := search(v, key)): + return val + + return None + + +def build_embed(package: PackageScanResult, path: str, inspector_url: str) -> discord.Embed: + """Return the embed to be sent in the threat intelligence feed.""" + if package.reported_at: + ts = discord.utils.format_dt(package.reported_at, style="F") + description = f"We already reported this package at {ts}" + color = discord.Color.green() + else: + description = f"We didn't catch this package! Here are our matched rules: ```{', '.join(package.rules)}```" + color = discord.Colour.red() + + embed = discord.Embed( + title=f"New Report: {package.name} v{package.version}", + description=description, + color=color, + url=build_github_link_from_path(path), + ) + + embed.add_field(name="Inspector URL", value=f"[Inspector URL]({inspector_url})") + + return embed + + +def build_package_not_found_embed(name: str, version: str, path: str) -> discord.Embed: + """Return the embed for when a report was filed for a package which we don't have records for.""" + return discord.Embed( + title="Package not found!", + description=( + f"A report was filed for {name} v{version}, " + "however we don't have any records for this package in our database. " + "This means that we are missing packages, please investigate this!" + ), + color=discord.Color.red(), + url=build_github_link_from_path(path), + ) + + +async def fetch_zipfile(http_client: aiohttp.ClientSession) -> ZipFile: + """Download the source zipfile from GitHub for the feed source repository.""" + url = f"https://api.github.com/repos/{constants.ThreatIntelFeed.repository}/zipball" + headers = {"Authorization": f"Bearer {constants.ThreatIntelFeed.access_token}"} + + async with http_client.get(url, headers=headers) as res: + res.raise_for_status() + b = await res.content.read() + + buffer = BytesIO(b) + return ZipFile(buffer) + + +class ThreatIntelFeed(commands.Cog): + """Threat Intelligence Feed Cog.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.reports_seen: set[str] = set() + + @tasks.loop(seconds=constants.ThreatIntelFeed.interval) + async def watcher(self) -> None: + """Watch the GitHub repository for changes.""" + zipfile = await fetch_zipfile(self.bot.http_session) + paths = {path for path in zipfile.namelist() if path.endswith(".json")} + + channel = self.bot.get_channel(constants.ThreatIntelFeed.channel_id) + if not isinstance(channel, discord.abc.Messageable): + log.error("Threat intel feed channel is not messageable") + return + + # The first time around, just add all the reports to our "seen reports" set + if len(self.reports_seen) == 0: + self.reports_seen |= paths + return + + for path in paths: + if path in self.reports_seen: + continue + + content = json.loads(zipfile.read(path).decode()) + inspector_url: str | None = search(content, "inspector_url") + if not inspector_url: + log.error("Inspector URL not found in %s, skipping", path) + continue + + match parse_package_info_from_inspector_url(inspector_url): + case name, version: + results = await self.bot.dragonfly_services.get_scanned_packages(name=name, version=version) + package = results[0] if results else None + + if package: + embed = build_embed(package, path, inspector_url) + else: + embed = build_package_not_found_embed(name, version, path) + + await channel.send(embed=embed) + + case None: + log.error('Unable to parse inspector URL: "%s" in %s, skipping', inspector_url, path) + continue + + @watcher.before_loop + async def before_watcher(self) -> None: + """Before first task run hook.""" + await self.bot.wait_until_ready() + + +async def setup(bot: Bot) -> None: + """Extension setup.""" + cog = ThreatIntelFeed(bot) + task = cog.watcher + if not task.is_running: + task.start() + await bot.add_cog(cog) From b897f4a8585f89dc86fea2a0afdcb4a6f9c889db Mon Sep 17 00:00:00 2001 From: Robin <74519799+Robin5605@users.noreply.github.com> Date: Fri, 24 May 2024 19:58:36 -0500 Subject: [PATCH 04/16] Change default interval length Change the default interval to fetch the repository contents to one hour Co-authored-by: Rem <128343390+import-pandas-as-numpy@users.noreply.github.com> Signed-off-by: Robin <74519799+Robin5605@users.noreply.github.com> --- src/bot/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bot/constants.py b/src/bot/constants.py index b7e12af..052928e 100644 --- a/src/bot/constants.py +++ b/src/bot/constants.py @@ -38,7 +38,7 @@ class _ThreatIntelFeed(EnvConfig, env_prefix="tif_"): """Threat Intelligence Feed Configuration.""" repository: str = "pypi/pypi-observation-reports-private" - interval: int = 10 * 60 # 10 minutes + interval: int = 60 * 60 # 1 hour access_token: str = "" channel_id: int = 1121471544355455058 From 59dd2a86064db54687aa2eff9e2ca042d03a92c3 Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Thu, 30 May 2024 22:27:15 +0100 Subject: [PATCH 05/16] Update dragonfly.py Updated view for package alerts, adding a triage system #242 Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 289 ++++++++++++++++++++++++++-- 1 file changed, 275 insertions(+), 14 deletions(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 45ce50b..3ad2551 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -1,5 +1,7 @@ """Download the most recent packages from PyPI and use Dragonfly to check them for malware.""" +from __future__ import annotations + import logging from datetime import UTC, datetime, timedelta from logging import getLogger @@ -127,9 +129,13 @@ async def on_error(self: Self, interaction: discord.Interaction, error: Exceptio f"Retry using Observation API instead?" ) view = ReportMethodSwitchConfirmationView(previous_modal=self) - return await interaction.response.send_message(message, view=view, ephemeral=True) + return await interaction.response.send_message( + message, view=view, ephemeral=True + ) - await interaction.response.send_message("An unexpected error occured.", ephemeral=True) + await interaction.response.send_message( + "An unexpected error occured.", ephemeral=True + ) raise error async def on_submit(self: Self, interaction: discord.Interaction) -> None: @@ -143,7 +149,11 @@ async def on_submit(self: Self, interaction: discord.Interaction) -> None: use_email=True, ) - await handle_submit(report=report, interaction=interaction, dragonfly_services=self.bot.dragonfly_services) + await handle_submit( + report=report, + interaction=interaction, + dragonfly_services=self.bot.dragonfly_services, + ) class ConfirmReportModal(discord.ui.Modal): @@ -175,14 +185,20 @@ def __init__(self: Self, *, package: PackageScanResult, bot: Bot) -> None: super().__init__() - async def on_error(self: Self, interaction: discord.Interaction, error: Exception) -> None: + async def on_error( + self: Self, interaction: discord.Interaction, error: Exception + ) -> None: """Handle errors that occur in the modal.""" if isinstance(error, aiohttp.ClientResponseError): message = f"Error from upstream: {error.status}\n```{error.message}```\nRetry using email instead?" view = ReportMethodSwitchConfirmationView(previous_modal=self) - return await interaction.response.send_message(message, view=view, ephemeral=True) + return await interaction.response.send_message( + message, view=view, ephemeral=True + ) - await interaction.response.send_message("An unexpected error occured.", ephemeral=True) + await interaction.response.send_message( + "An unexpected error occured.", ephemeral=True + ) raise error async def on_submit(self: Self, interaction: discord.Interaction) -> None: @@ -196,7 +212,11 @@ async def on_submit(self: Self, interaction: discord.Interaction) -> None: use_email=False, ) - await handle_submit(report=report, interaction=interaction, dragonfly_services=self.bot.dragonfly_services) + await handle_submit( + report=report, + interaction=interaction, + dragonfly_services=self.bot.dragonfly_services, + ) class ReportMethodSwitchConfirmationView(discord.ui.View): @@ -206,14 +226,18 @@ class ReportMethodSwitchConfirmationView(discord.ui.View): user if they want to switch to another method of sending reports. """ - def __init__(self: Self, previous_modal: ConfirmReportModal | ConfirmEmailReportModal) -> None: + def __init__( + self: Self, previous_modal: ConfirmReportModal | ConfirmEmailReportModal + ) -> None: super().__init__() self.previous_modal = previous_modal self.package = previous_modal.package self.bot = previous_modal.bot @discord.ui.button(label="Yes", style=discord.ButtonStyle.green) - async def confirm(self: Self, interaction: discord.Interaction, _button: discord.ui.Button) -> None: + async def confirm( + self: Self, interaction: discord.Interaction, _button: discord.ui.Button + ) -> None: """Confirm button callback.""" if isinstance(self.previous_modal, ConfirmReportModal): modal = ConfirmEmailReportModal(package=self.package, bot=self.bot) @@ -226,7 +250,9 @@ async def confirm(self: Self, interaction: discord.Interaction, _button: discord await interaction.edit_original_response(view=self) @discord.ui.button(label="No, retry the operation", style=discord.ButtonStyle.red) - async def cancel(self: Self, interaction: discord.Interaction, _button: discord.ui.Button) -> None: + async def cancel( + self: Self, interaction: discord.Interaction, _button: discord.ui.Button + ) -> None: """Cancel button callback.""" modal = type(self.previous_modal)(package=self.package, bot=self.bot) @@ -261,6 +287,201 @@ async def report(self: Self, interaction: discord.Interaction, button: discord.u await interaction.edit_original_response(view=self) +class NoteModal(discord.ui.Modal, title="Add a note"): + """A modal that allows users to add a note to a package""" + + _interaction: discord.Interaction | None = None + note_content = discord.ui.TextInput( + label="Content", + placeholder="Enter the note content here", + min_length=1, + max_length=1000, # Don't want to overfill the embed + ) + + def __init__(self, embed: discord.Embed, view: discord.ui.View): + super().__init__() + + self.embed = embed + self.view = view + + async def on_submit(self, interaction: discord.Interaction) -> None: + if not interaction.response.is_done(): + await interaction.response.defer() + self._interaction = interaction + + content = f"{self.note_content.value} • {interaction.user.mention}" + + # We need to check what fields the embed has to determine where to add the note + # If the embed has no fields, we add the note and return + # Otherwise, we need to make sure the note is added after the event log + # This involves clearing the fields and re-adding them in the correct order + # Which is why we save the event log in a variable + + match len(self.embed.fields): + case 0: # Package is awaiting triage, no notes or event log + notes = [content] + event_log = None + case 1: # Package either has notes or event log + if self.embed.fields[0].name == "Notes": + notes = [self.embed.fields[0].value, content] + else: + event_log = self.embed.fields[0].value + notes = [content] + self.embed.clear_fields() + case 2: # Package has both notes and event log + if self.embed.fields[0].name == "Notes": + notes = [self.embed.fields[0].value, content] + event_log = self.embed.fields[1].value + else: + notes = [self.embed.fields[1].value, content] + event_log = self.embed.fields[0].value + self.embed.clear_fields() + + self.embed.add_field(name="Notes", value="\n".join(notes), inline=False) + + if event_log: + self.embed.add_field(name="Event log", value=event_log, inline=False) + + await interaction.message.edit(embed=self.embed, view=self.view) + + async def on_error( + self, interaction: discord.Interaction, error: Exception + ) -> None: + + await interaction.response.send_message( + "An unexpected error occured.", ephemeral=True + ) + raise error + + @property + def interaction(self) -> discord.Interaction | None: + return self._interaction + + +class MalwareView(discord.ui.View): + """View for the malware triage system""" + + message: discord.Message | None = None + + def __init__( + self: Self, embed: discord.Embed, bot: Bot, payload: PackageScanResult + ) -> None: + self.embed = embed + self.bot = bot + self.payload = payload + self.event_log = [] + + super().__init__() + + async def enable_button(self, button_label: str) -> None: + for button in self.children: + if button.label == button_label: + button.disabled = False + + async def add_event(self, message: str) -> None: + # Much like earlier, we need to check the fields of the embed to determine where to add the event log + match len(self.embed.fields): + case 0: + pass + case 1: + if self.embed.fields[0].name == "Event log": + self.embed.clear_fields() + case 2: + if self.embed.fields[0].name == "Event log": + self.embed.clear_fields() + elif self.embed.fields[1].name == "Event log": + self.embed.remove_field(1) + + self.event_log.append( + message + ) # For future reference, we save the event log in a variable + self.embed.add_field( + name="Event log", value="\n".join(self.event_log), inline=False + ) + + async def update_status(self, status: str) -> None: + self.embed.set_footer(text=status) + + def get_timestamp( + self, + ) -> ( + int + ): # This function returns the current timestamp in Discord's timestamp format + return f"" + + @discord.ui.button(label="Report", style=discord.ButtonStyle.red) + async def report(self: Self, interaction: discord.Interaction, button: discord.ui.Button) -> None: # type: ignore[type-arg] + """Report a package.""" + modal = ConfirmReportModal(package=self.payload, bot=self.bot) + await interaction.response.send_modal(modal) + + timed_out = await modal.wait() + if not timed_out: + button.disabled = True + await interaction.edit_original_response(view=self) + + @discord.ui.button( + label="Report", + style=discord.ButtonStyle.red, + ) + async def report( + self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView] + ) -> None: + await self.enable_button("Approve") + await self.add_event( + f"Reported by {interaction.user.mention} • {self.get_timestamp()}" + ) + await self.update_status("Flagged as malicious") + + self.embed.color = discord.Color.red() + + modal = ConfirmReportModal(package=self.payload, bot=self.bot) + await interaction.response.send_modal(modal) + + timed_out = await modal.wait() + if not timed_out: + button.disabled = True + await interaction.edit_original_response(view=self, embed=self.embed) + + @discord.ui.button( + label="Approve", + style=discord.ButtonStyle.green, + ) + async def approve( + self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView] + ) -> None: + await self.enable_button("Report") + await self.add_event( + f"Approved by {interaction.user.mention} • {self.get_timestamp()}" + ) + await self.update_status("Flagged as benign") + + button.disabled = True + + self.embed.color = discord.Color.green() + await interaction.response.edit_message(view=self, embed=self.embed) + + @discord.ui.button( + label="Add note", + style=discord.ButtonStyle.grey, + ) + async def add_note( + self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView] + ) -> None: + await interaction.response.send_modal(NoteModal(embed=self.embed, view=self)) + + async def on_error( + self, + interaction: discord.Interaction[discord.Client], + error: Exception, + ) -> None: + + await interaction.response.send_message( + "An unexpected error occured.", ephemeral=True + ) + raise error + + def _build_package_scan_result_embed(scan_result: PackageScanResult) -> discord.Embed: """Build the embed that shows the results of a package scan.""" condition = scan_result.score >= DragonflyConfig.threshold @@ -287,7 +508,31 @@ def _build_package_scan_result_embed(scan_result: PackageScanResult) -> discord. return embed -def _build_all_packages_scanned_embed(scan_results: list[PackageScanResult]) -> discord.Embed: +def _build_package_scan_result_triage_embed( + scan_result: PackageScanResult, +) -> discord.Embed: + """Build the embed for the malware triage system""" + + embed = discord.Embed( + title="View on Inspector", + description="\n".join(scan_result.rules), + url=scan_result.inspector_url, + color=discord.Color.orange(), + timestamp=datetime.now(UTC), + ) + embed.set_author( + name=f"{scan_result.name}@{scan_result.version}", + url=f"https://pypi.org/project/{scan_result.name}/{scan_result.version}", + icon_url="https://seeklogo.com/images/P/pypi-logo-5B953CE804-seeklogo.com.png", + ) + embed.set_footer(text="Awaiting triage") + + return embed + + +def _build_all_packages_scanned_embed( + scan_results: list[PackageScanResult], +) -> discord.Embed: """Build the embed that shows a list of all packages scanned.""" if scan_results: description = "\n".join(map(str, scan_results)) @@ -307,12 +552,22 @@ async def run( scan_results = await bot.dragonfly_services.get_scanned_packages(since=since) for result in scan_results: if result.score >= score: + """ embed = _build_package_scan_result_embed(result) await alerts_channel.send( f"<@&{DragonflyConfig.alerts_role_id}>", embed=embed, view=ReportView(bot, result), ) + """ + embed = _build_package_scan_result_triage_embed(result) + view = MalwareView(embed=embed, bot=bot, payload=result) + + view.message = await alerts_channel.send( + f"<@&{DragonflyConfig.alerts_role_id}>", + embed=embed, + view=view, + ) await logs_channel.send(embed=_build_all_packages_scanned_embed(scan_results)) @@ -330,7 +585,9 @@ def __init__(self: Self, bot: Bot) -> None: @commands.hybrid_command(name="username") # type: ignore [arg-type] async def get_username_command(self, ctx: commands.Context[Bot]) -> None: """Get the username of the currently logged in user to the PyPI Observation API.""" - async with ctx.bot.http_session.get(DragonflyConfig.reporter_url + "/echo") as res: + async with ctx.bot.http_session.get( + DragonflyConfig.reporter_url + "/echo" + ) as res: json = await res.json() username = json["username"] @@ -392,12 +649,16 @@ async def stop(self: Self, ctx: commands.Context, force: bool = False) -> None: @discord.app_commands.command(name="lookup", description="Scans a package") async def lookup(self: Self, interaction: discord.Interaction, name: str, version: str | None = None) -> None: # type: ignore[type-arg] """Pull the scan results for a package.""" - scan_results = await self.bot.dragonfly_services.get_scanned_packages(name=name, version=version) + scan_results = await self.bot.dragonfly_services.get_scanned_packages( + name=name, version=version + ) if scan_results: embed = _build_package_scan_result_embed(scan_results[0]) await interaction.response.send_message(embed=embed) else: - await interaction.response.send_message("No entries were found with the specified filters.") + await interaction.response.send_message( + "No entries were found with the specified filters." + ) @commands.group() async def threshold(self: Self, ctx: commands.Context) -> None: # type: ignore[type-arg] From c6e355e2479b45d449fdb328dcb6b180c53beced Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Fri, 31 May 2024 16:30:00 +0100 Subject: [PATCH 06/16] Linted with ruff * Removed type annotation * Removed extra remove method * Removed commented out method Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 100 ++++++++++++---------------- 1 file changed, 44 insertions(+), 56 deletions(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 3ad2551..41cae32 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -1,20 +1,19 @@ """Download the most recent packages from PyPI and use Dragonfly to check them for malware.""" -from __future__ import annotations - import logging from datetime import UTC, datetime, timedelta from logging import getLogger -from typing import Self +from typing import TYPE_CHECKING, Self import aiohttp import discord import sentry_sdk -from discord.ext import commands, tasks - -from bot.bot import Bot from bot.constants import Channels, DragonflyConfig, Roles from bot.dragonfly_services import DragonflyServices, PackageReport, PackageScanResult +from discord.ext import commands, tasks + +if TYPE_CHECKING: + from bot.bot import Bot log = getLogger(__name__) log.setLevel(logging.INFO) @@ -130,11 +129,11 @@ async def on_error(self: Self, interaction: discord.Interaction, error: Exceptio ) view = ReportMethodSwitchConfirmationView(previous_modal=self) return await interaction.response.send_message( - message, view=view, ephemeral=True + message, view=view, ephemeral=True, ) await interaction.response.send_message( - "An unexpected error occured.", ephemeral=True + "An unexpected error occured.", ephemeral=True, ) raise error @@ -186,18 +185,18 @@ def __init__(self: Self, *, package: PackageScanResult, bot: Bot) -> None: super().__init__() async def on_error( - self: Self, interaction: discord.Interaction, error: Exception + self: Self, interaction: discord.Interaction, error: Exception, ) -> None: """Handle errors that occur in the modal.""" if isinstance(error, aiohttp.ClientResponseError): message = f"Error from upstream: {error.status}\n```{error.message}```\nRetry using email instead?" view = ReportMethodSwitchConfirmationView(previous_modal=self) return await interaction.response.send_message( - message, view=view, ephemeral=True + message, view=view, ephemeral=True, ) await interaction.response.send_message( - "An unexpected error occured.", ephemeral=True + "An unexpected error occured.", ephemeral=True, ) raise error @@ -227,7 +226,7 @@ class ReportMethodSwitchConfirmationView(discord.ui.View): """ def __init__( - self: Self, previous_modal: ConfirmReportModal | ConfirmEmailReportModal + self: Self, previous_modal: ConfirmReportModal | ConfirmEmailReportModal, ) -> None: super().__init__() self.previous_modal = previous_modal @@ -236,7 +235,7 @@ def __init__( @discord.ui.button(label="Yes", style=discord.ButtonStyle.green) async def confirm( - self: Self, interaction: discord.Interaction, _button: discord.ui.Button + self: Self, interaction: discord.Interaction, _button: discord.ui.Button, ) -> None: """Confirm button callback.""" if isinstance(self.previous_modal, ConfirmReportModal): @@ -251,7 +250,7 @@ async def confirm( @discord.ui.button(label="No, retry the operation", style=discord.ButtonStyle.red) async def cancel( - self: Self, interaction: discord.Interaction, _button: discord.ui.Button + self: Self, interaction: discord.Interaction, _button: discord.ui.Button, ) -> None: """Cancel button callback.""" modal = type(self.previous_modal)(package=self.package, bot=self.bot) @@ -288,7 +287,7 @@ async def report(self: Self, interaction: discord.Interaction, button: discord.u class NoteModal(discord.ui.Modal, title="Add a note"): - """A modal that allows users to add a note to a package""" + """A modal that allows users to add a note to a package.""" _interaction: discord.Interaction | None = None note_content = discord.ui.TextInput( @@ -298,13 +297,14 @@ class NoteModal(discord.ui.Modal, title="Add a note"): max_length=1000, # Don't want to overfill the embed ) - def __init__(self, embed: discord.Embed, view: discord.ui.View): + def __init__(self, embed: discord.Embed, view: discord.ui.View) -> None: super().__init__() self.embed = embed self.view = view async def on_submit(self, interaction: discord.Interaction) -> None: + """Modal submit callback.""" if not interaction.response.is_done(): await interaction.response.defer() self._interaction = interaction @@ -345,26 +345,27 @@ async def on_submit(self, interaction: discord.Interaction) -> None: await interaction.message.edit(embed=self.embed, view=self.view) async def on_error( - self, interaction: discord.Interaction, error: Exception + self, interaction: discord.Interaction, error: Exception, ) -> None: - + """Handle errors that occur in the modal.""" await interaction.response.send_message( - "An unexpected error occured.", ephemeral=True + "An unexpected error occured.", ephemeral=True, ) raise error @property def interaction(self) -> discord.Interaction | None: + """Get the interaction that triggered the modal.""" return self._interaction class MalwareView(discord.ui.View): - """View for the malware triage system""" + """View for the malware triage system.""" message: discord.Message | None = None def __init__( - self: Self, embed: discord.Embed, bot: Bot, payload: PackageScanResult + self: Self, embed: discord.Embed, bot: Bot, payload: PackageScanResult, ) -> None: self.embed = embed self.bot = bot @@ -374,11 +375,13 @@ def __init__( super().__init__() async def enable_button(self, button_label: str) -> None: + """Enables a button by its label.""" for button in self.children: if button.label == button_label: button.disabled = False async def add_event(self, message: str) -> None: + """Add an event to the event log.""" # Much like earlier, we need to check the fields of the embed to determine where to add the event log match len(self.embed.fields): case 0: @@ -393,43 +396,35 @@ async def add_event(self, message: str) -> None: self.embed.remove_field(1) self.event_log.append( - message + message, ) # For future reference, we save the event log in a variable self.embed.add_field( - name="Event log", value="\n".join(self.event_log), inline=False + name="Event log", value="\n".join(self.event_log), inline=False, ) async def update_status(self, status: str) -> None: + """Update the status of the package in the embed.""" self.embed.set_footer(text=status) def get_timestamp( self, ) -> ( int - ): # This function returns the current timestamp in Discord's timestamp format - return f"" - - @discord.ui.button(label="Report", style=discord.ButtonStyle.red) - async def report(self: Self, interaction: discord.Interaction, button: discord.ui.Button) -> None: # type: ignore[type-arg] - """Report a package.""" - modal = ConfirmReportModal(package=self.payload, bot=self.bot) - await interaction.response.send_modal(modal) - - timed_out = await modal.wait() - if not timed_out: - button.disabled = True - await interaction.edit_original_response(view=self) + ): + """Returns the current timestamp in seconds.""" + return f"" @discord.ui.button( label="Report", style=discord.ButtonStyle.red, ) async def report( - self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView] + self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView], ) -> None: + """Report package and update the embed.""" await self.enable_button("Approve") await self.add_event( - f"Reported by {interaction.user.mention} • {self.get_timestamp()}" + f"Reported by {interaction.user.mention} • {self.get_timestamp()}", ) await self.update_status("Flagged as malicious") @@ -448,11 +443,12 @@ async def report( style=discord.ButtonStyle.green, ) async def approve( - self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView] + self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView], ) -> None: + """Approve package and update the embed.""" await self.enable_button("Report") await self.add_event( - f"Approved by {interaction.user.mention} • {self.get_timestamp()}" + f"Approved by {interaction.user.mention} • {self.get_timestamp()}", ) await self.update_status("Flagged as benign") @@ -466,8 +462,9 @@ async def approve( style=discord.ButtonStyle.grey, ) async def add_note( - self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView] + self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView], ) -> None: + """Add note to the embed.""" await interaction.response.send_modal(NoteModal(embed=self.embed, view=self)) async def on_error( @@ -475,9 +472,9 @@ async def on_error( interaction: discord.Interaction[discord.Client], error: Exception, ) -> None: - + """Handle errors that occur in the view.""" await interaction.response.send_message( - "An unexpected error occured.", ephemeral=True + "An unexpected error occured.", ephemeral=True, ) raise error @@ -511,8 +508,7 @@ def _build_package_scan_result_embed(scan_result: PackageScanResult) -> discord. def _build_package_scan_result_triage_embed( scan_result: PackageScanResult, ) -> discord.Embed: - """Build the embed for the malware triage system""" - + """Build the embed for the malware triage system.""" embed = discord.Embed( title="View on Inspector", description="\n".join(scan_result.rules), @@ -552,14 +548,6 @@ async def run( scan_results = await bot.dragonfly_services.get_scanned_packages(since=since) for result in scan_results: if result.score >= score: - """ - embed = _build_package_scan_result_embed(result) - await alerts_channel.send( - f"<@&{DragonflyConfig.alerts_role_id}>", - embed=embed, - view=ReportView(bot, result), - ) - """ embed = _build_package_scan_result_triage_embed(result) view = MalwareView(embed=embed, bot=bot, payload=result) @@ -586,7 +574,7 @@ def __init__(self: Self, bot: Bot) -> None: async def get_username_command(self, ctx: commands.Context[Bot]) -> None: """Get the username of the currently logged in user to the PyPI Observation API.""" async with ctx.bot.http_session.get( - DragonflyConfig.reporter_url + "/echo" + DragonflyConfig.reporter_url + "/echo", ) as res: json = await res.json() username = json["username"] @@ -650,14 +638,14 @@ async def stop(self: Self, ctx: commands.Context, force: bool = False) -> None: async def lookup(self: Self, interaction: discord.Interaction, name: str, version: str | None = None) -> None: # type: ignore[type-arg] """Pull the scan results for a package.""" scan_results = await self.bot.dragonfly_services.get_scanned_packages( - name=name, version=version + name=name, version=version, ) if scan_results: embed = _build_package_scan_result_embed(scan_results[0]) await interaction.response.send_message(embed=embed) else: await interaction.response.send_message( - "No entries were found with the specified filters." + "No entries were found with the specified filters.", ) @commands.group() From ce891c59d921d5270546208d68012ecb7c125057 Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Fri, 31 May 2024 16:39:25 +0100 Subject: [PATCH 07/16] Fixed line length changes Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 46 ++++++++--------------------- 1 file changed, 12 insertions(+), 34 deletions(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 41cae32..165b0ee 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -1,5 +1,7 @@ """Download the most recent packages from PyPI and use Dragonfly to check them for malware.""" +from __future__ import annotations + import logging from datetime import UTC, datetime, timedelta from logging import getLogger @@ -128,9 +130,7 @@ async def on_error(self: Self, interaction: discord.Interaction, error: Exceptio f"Retry using Observation API instead?" ) view = ReportMethodSwitchConfirmationView(previous_modal=self) - return await interaction.response.send_message( - message, view=view, ephemeral=True, - ) + return await interaction.response.send_message(message, view=view, ephemeral=True) await interaction.response.send_message( "An unexpected error occured.", ephemeral=True, @@ -148,11 +148,7 @@ async def on_submit(self: Self, interaction: discord.Interaction) -> None: use_email=True, ) - await handle_submit( - report=report, - interaction=interaction, - dragonfly_services=self.bot.dragonfly_services, - ) + await handle_submit(report=report, interaction=interaction, dragonfly_services=self.bot.dragonfly_services) class ConfirmReportModal(discord.ui.Modal): @@ -184,20 +180,14 @@ def __init__(self: Self, *, package: PackageScanResult, bot: Bot) -> None: super().__init__() - async def on_error( - self: Self, interaction: discord.Interaction, error: Exception, - ) -> None: + async def on_error(self: Self, interaction: discord.Interaction, error: Exception) -> None: """Handle errors that occur in the modal.""" if isinstance(error, aiohttp.ClientResponseError): message = f"Error from upstream: {error.status}\n```{error.message}```\nRetry using email instead?" view = ReportMethodSwitchConfirmationView(previous_modal=self) - return await interaction.response.send_message( - message, view=view, ephemeral=True, - ) + return await interaction.response.send_message(message, view=view, ephemeral=True) - await interaction.response.send_message( - "An unexpected error occured.", ephemeral=True, - ) + await interaction.response.send_message("An unexpected error occured.", ephemeral=True) raise error async def on_submit(self: Self, interaction: discord.Interaction) -> None: @@ -211,11 +201,7 @@ async def on_submit(self: Self, interaction: discord.Interaction) -> None: use_email=False, ) - await handle_submit( - report=report, - interaction=interaction, - dragonfly_services=self.bot.dragonfly_services, - ) + await handle_submit(report=report, interaction=interaction, dragonfly_services=self.bot.dragonfly_services) class ReportMethodSwitchConfirmationView(discord.ui.View): @@ -225,18 +211,14 @@ class ReportMethodSwitchConfirmationView(discord.ui.View): user if they want to switch to another method of sending reports. """ - def __init__( - self: Self, previous_modal: ConfirmReportModal | ConfirmEmailReportModal, - ) -> None: + def __init__(self: Self, previous_modal: ConfirmReportModal | ConfirmEmailReportModal) -> None: super().__init__() self.previous_modal = previous_modal self.package = previous_modal.package self.bot = previous_modal.bot @discord.ui.button(label="Yes", style=discord.ButtonStyle.green) - async def confirm( - self: Self, interaction: discord.Interaction, _button: discord.ui.Button, - ) -> None: + async def confirm(self: Self, interaction: discord.Interaction, _button: discord.ui.Button) -> None: """Confirm button callback.""" if isinstance(self.previous_modal, ConfirmReportModal): modal = ConfirmEmailReportModal(package=self.package, bot=self.bot) @@ -249,9 +231,7 @@ async def confirm( await interaction.edit_original_response(view=self) @discord.ui.button(label="No, retry the operation", style=discord.ButtonStyle.red) - async def cancel( - self: Self, interaction: discord.Interaction, _button: discord.ui.Button, - ) -> None: + async def cancel(self: Self, interaction: discord.Interaction, _button: discord.ui.Button) -> None: """Cancel button callback.""" modal = type(self.previous_modal)(package=self.package, bot=self.bot) @@ -526,9 +506,7 @@ def _build_package_scan_result_triage_embed( return embed -def _build_all_packages_scanned_embed( - scan_results: list[PackageScanResult], -) -> discord.Embed: +def _build_all_packages_scanned_embed(scan_results: list[PackageScanResult]) -> discord.Embed: """Build the embed that shows a list of all packages scanned.""" if scan_results: description = "\n".join(map(str, scan_results)) From 2d45e06cd43906ad785ce166575a6ff1ef44e678 Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Fri, 31 May 2024 16:40:59 +0100 Subject: [PATCH 08/16] Missed a couple Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 165b0ee..2801c4f 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -132,9 +132,7 @@ async def on_error(self: Self, interaction: discord.Interaction, error: Exceptio view = ReportMethodSwitchConfirmationView(previous_modal=self) return await interaction.response.send_message(message, view=view, ephemeral=True) - await interaction.response.send_message( - "An unexpected error occured.", ephemeral=True, - ) + await interaction.response.send_message("An unexpected error occured.", ephemeral=True) raise error async def on_submit(self: Self, interaction: discord.Interaction) -> None: @@ -551,9 +549,7 @@ def __init__(self: Self, bot: Bot) -> None: @commands.hybrid_command(name="username") # type: ignore [arg-type] async def get_username_command(self, ctx: commands.Context[Bot]) -> None: """Get the username of the currently logged in user to the PyPI Observation API.""" - async with ctx.bot.http_session.get( - DragonflyConfig.reporter_url + "/echo", - ) as res: + async with ctx.bot.http_session.get(DragonflyConfig.reporter_url + "/echo") as res: json = await res.json() username = json["username"] @@ -615,16 +611,12 @@ async def stop(self: Self, ctx: commands.Context, force: bool = False) -> None: @discord.app_commands.command(name="lookup", description="Scans a package") async def lookup(self: Self, interaction: discord.Interaction, name: str, version: str | None = None) -> None: # type: ignore[type-arg] """Pull the scan results for a package.""" - scan_results = await self.bot.dragonfly_services.get_scanned_packages( - name=name, version=version, - ) + scan_results = await self.bot.dragonfly_services.get_scanned_packages(name=name, version=version) if scan_results: embed = _build_package_scan_result_embed(scan_results[0]) await interaction.response.send_message(embed=embed) else: - await interaction.response.send_message( - "No entries were found with the specified filters.", - ) + await interaction.response.send_message("No entries were found with the specified filters.") @commands.group() async def threshold(self: Self, ctx: commands.Context) -> None: # type: ignore[type-arg] From 6cf3f1da84bc2338585bd7aa2e04c8903c8b621b Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Fri, 31 May 2024 17:07:39 +0100 Subject: [PATCH 09/16] Removed annotations Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 2801c4f..b8ce5d4 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -1,7 +1,5 @@ """Download the most recent packages from PyPI and use Dragonfly to check them for malware.""" -from __future__ import annotations - import logging from datetime import UTC, datetime, timedelta from logging import getLogger From 4de511862b8c8d349d891cb96e394c29ac1a9f2a Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Sun, 2 Jun 2024 23:29:51 +0100 Subject: [PATCH 10/16] Update src/bot/exts/dragonfly/dragonfly.py Co-authored-by: Robin <74519799+Robin5605@users.noreply.github.com> Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index b8ce5d4..10c088f 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -387,7 +387,7 @@ def get_timestamp( ) -> ( int ): - """Returns the current timestamp in seconds.""" + """Returns the current timestamp, formatted in Discord's relative style""" return f"" @discord.ui.button( From c21d60f9053bb597f0ca5f445f78310b28db4382 Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Sun, 2 Jun 2024 23:31:47 +0100 Subject: [PATCH 11/16] Update src/bot/exts/dragonfly/dragonfly.py Co-authored-by: Robin <74519799+Robin5605@users.noreply.github.com> Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 10c088f..0773542 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -398,7 +398,7 @@ async def report( self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView], ) -> None: """Report package and update the embed.""" - await self.enable_button("Approve") + self.approve.disabled = False await self.add_event( f"Reported by {interaction.user.mention} • {self.get_timestamp()}", ) From ea8e5841623ee835d482617bc5811f307a69f39e Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Sun, 2 Jun 2024 23:33:33 +0100 Subject: [PATCH 12/16] Updated to remove enable button function Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 0773542..5b34314 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -329,11 +329,6 @@ async def on_error( ) raise error - @property - def interaction(self) -> discord.Interaction | None: - """Get the interaction that triggered the modal.""" - return self._interaction - class MalwareView(discord.ui.View): """View for the malware triage system.""" @@ -422,7 +417,7 @@ async def approve( self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView], ) -> None: """Approve package and update the embed.""" - await self.enable_button("Report") + self.report.disabled = False await self.add_event( f"Approved by {interaction.user.mention} • {self.get_timestamp()}", ) From 4f2a63617eb6d5d43e1c213d5c9199b27cbb87cc Mon Sep 17 00:00:00 2001 From: Jayy001 Date: Sun, 2 Jun 2024 23:36:12 +0100 Subject: [PATCH 13/16] Added timestamp Signed-off-by: Jayy001 --- src/bot/exts/dragonfly/dragonfly.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 5b34314..9611927 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -11,6 +11,7 @@ from bot.constants import Channels, DragonflyConfig, Roles from bot.dragonfly_services import DragonflyServices, PackageReport, PackageScanResult from discord.ext import commands, tasks +from discord.utils import format_dt if TYPE_CHECKING: from bot.bot import Bot @@ -383,7 +384,7 @@ def get_timestamp( int ): """Returns the current timestamp, formatted in Discord's relative style""" - return f"" + return format_dt(datetime.now(UTC), style="R") @discord.ui.button( label="Report", From e5451d0c11ba270409aa8e42ecf120a04e864bd3 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sun, 2 Jun 2024 17:52:14 -0500 Subject: [PATCH 14/16] Fix incorrect return type for method --- src/bot/exts/dragonfly/dragonfly.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 9611927..dd44016 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -378,11 +378,7 @@ async def update_status(self, status: str) -> None: """Update the status of the package in the embed.""" self.embed.set_footer(text=status) - def get_timestamp( - self, - ) -> ( - int - ): + def get_timestamp(self) -> str: """Returns the current timestamp, formatted in Discord's relative style""" return format_dt(datetime.now(UTC), style="R") From 285700f75b7f0fa59f1cbdbbb843dc5d0688af32 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sun, 2 Jun 2024 18:57:04 -0500 Subject: [PATCH 15/16] Lint --- src/bot/exts/dragonfly/dragonfly.py | 49 +++++++++++++++++------------ 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index dd44016..6921478 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -3,18 +3,17 @@ import logging from datetime import UTC, datetime, timedelta from logging import getLogger -from typing import TYPE_CHECKING, Self +from typing import Self import aiohttp import discord import sentry_sdk -from bot.constants import Channels, DragonflyConfig, Roles -from bot.dragonfly_services import DragonflyServices, PackageReport, PackageScanResult from discord.ext import commands, tasks from discord.utils import format_dt -if TYPE_CHECKING: - from bot.bot import Bot +from bot.bot import Bot +from bot.constants import Channels, DragonflyConfig, Roles +from bot.dragonfly_services import DragonflyServices, PackageReport, PackageScanResult log = getLogger(__name__) log.setLevel(logging.INFO) @@ -322,11 +321,14 @@ async def on_submit(self, interaction: discord.Interaction) -> None: await interaction.message.edit(embed=self.embed, view=self.view) async def on_error( - self, interaction: discord.Interaction, error: Exception, + self, + interaction: discord.Interaction, + error: Exception, ) -> None: """Handle errors that occur in the modal.""" await interaction.response.send_message( - "An unexpected error occured.", ephemeral=True, + "An unexpected error occured.", + ephemeral=True, ) raise error @@ -337,7 +339,10 @@ class MalwareView(discord.ui.View): message: discord.Message | None = None def __init__( - self: Self, embed: discord.Embed, bot: Bot, payload: PackageScanResult, + self: Self, + embed: discord.Embed, + bot: Bot, + payload: PackageScanResult, ) -> None: self.embed = embed self.bot = bot @@ -346,12 +351,6 @@ def __init__( super().__init__() - async def enable_button(self, button_label: str) -> None: - """Enables a button by its label.""" - for button in self.children: - if button.label == button_label: - button.disabled = False - async def add_event(self, message: str) -> None: """Add an event to the event log.""" # Much like earlier, we need to check the fields of the embed to determine where to add the event log @@ -371,7 +370,9 @@ async def add_event(self, message: str) -> None: message, ) # For future reference, we save the event log in a variable self.embed.add_field( - name="Event log", value="\n".join(self.event_log), inline=False, + name="Event log", + value="\n".join(self.event_log), + inline=False, ) async def update_status(self, status: str) -> None: @@ -379,7 +380,7 @@ async def update_status(self, status: str) -> None: self.embed.set_footer(text=status) def get_timestamp(self) -> str: - """Returns the current timestamp, formatted in Discord's relative style""" + """Return the current timestamp, formatted in Discord's relative style.""" return format_dt(datetime.now(UTC), style="R") @discord.ui.button( @@ -387,7 +388,9 @@ def get_timestamp(self) -> str: style=discord.ButtonStyle.red, ) async def report( - self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView], + self, + interaction: discord.Interaction, + button: discord.ui.Button, ) -> None: """Report package and update the embed.""" self.approve.disabled = False @@ -411,7 +414,9 @@ async def report( style=discord.ButtonStyle.green, ) async def approve( - self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView], + self, + interaction: discord.Interaction, + button: discord.ui.Button, ) -> None: """Approve package and update the embed.""" self.report.disabled = False @@ -430,7 +435,9 @@ async def approve( style=discord.ButtonStyle.grey, ) async def add_note( - self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView], + self, + interaction: discord.Interaction, + _button: discord.ui.Button, ) -> None: """Add note to the embed.""" await interaction.response.send_modal(NoteModal(embed=self.embed, view=self)) @@ -439,10 +446,12 @@ async def on_error( self, interaction: discord.Interaction[discord.Client], error: Exception, + _item: discord.ui.Item, ) -> None: """Handle errors that occur in the view.""" await interaction.response.send_message( - "An unexpected error occured.", ephemeral=True, + "An unexpected error occured.", + ephemeral=True, ) raise error From c16095720300e2fdae61aa6716b16e051bcbf008 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Tue, 4 Jun 2024 20:58:33 -0500 Subject: [PATCH 16/16] Conform to new API response model Upstream has changed their response model for the `GET /package` endpoint - this PR aims to be compliant with that new contract. --- src/bot/dragonfly_services.py | 57 +++++++++------------ src/bot/exts/audit.py | 6 +-- src/bot/exts/dragonfly/dragonfly.py | 18 +++---- src/bot/exts/dragonfly/threat_intel_feed.py | 4 +- 4 files changed, 37 insertions(+), 48 deletions(-) diff --git a/src/bot/dragonfly_services.py b/src/bot/dragonfly_services.py index 5d2e6a4..259b80b 100644 --- a/src/bot/dragonfly_services.py +++ b/src/bot/dragonfly_services.py @@ -7,6 +7,7 @@ from typing import Any, Self from aiohttp import ClientSession +from pydantic import BaseModel class ScanStatus(Enum): @@ -18,41 +19,29 @@ class ScanStatus(Enum): FAILED = "failed" -@dataclass -class PackageScanResult: - """A package scan result.""" +class Package(BaseModel): + """Model representing a package queried from the database.""" - status: ScanStatus - inspector_url: str - queued_at: datetime + scan_id: str + name: str + version: str + status: ScanStatus | None + score: int | None + inspector_url: str | None + rules: list[str] = [] + download_urls: list[str] = [] + queued_at: datetime | None + queued_by: str | None + reported_at: datetime | None + reported_by: str | None pending_at: datetime | None + pending_by: str | None finished_at: datetime | None - reported_at: datetime | None - version: str - name: str - package_id: str - rules: list[str] - score: int - - @classmethod - def from_dict(cls: type[Self], data: dict) -> Self: # type: ignore[type-arg] - """Create a PackageScanResult from a dictionary.""" - return cls( - status=ScanStatus(data["status"]), - inspector_url=data["inspector_url"], - queued_at=datetime.fromisoformat(data["queued_at"]), - pending_at=datetime.fromisoformat(p) if (p := data["pending_at"]) else None, - finished_at=datetime.fromisoformat(p) if (p := data["finished_at"]) else None, - reported_at=datetime.fromisoformat(p) if (p := data["reported_at"]) else None, - version=data["version"], - name=data["name"], - package_id=data["scan_id"], - rules=[d["name"] for d in data["rules"]], - score=int(data["score"]), - ) - - def __str__(self: Self) -> str: - """Return a string representation of the package scan result.""" + finished_by: str | None + commit_hash: str | None + + def __str__(self) -> str: + """Return package name and version.""" return f"{self.name} {self.version}" @@ -146,7 +135,7 @@ async def get_scanned_packages( name: str | None = None, version: str | None = None, since: datetime | None = None, - ) -> list[PackageScanResult]: + ) -> list[Package]: """Get a list of scanned packages.""" params = {} if name: @@ -159,7 +148,7 @@ async def get_scanned_packages( params["since"] = int(since.timestamp()) # type: ignore[assignment] data = await self.make_request("GET", "/package", params=params) - return [PackageScanResult.from_dict(dct) for dct in data] + return list(map(Package.model_validate, data)) async def report_package( self: Self, diff --git a/src/bot/exts/audit.py b/src/bot/exts/audit.py index 39f83f9..a47b635 100644 --- a/src/bot/exts/audit.py +++ b/src/bot/exts/audit.py @@ -10,7 +10,7 @@ from discord.ext import commands from bot.bot import Bot -from bot.dragonfly_services import PackageScanResult +from bot.dragonfly_services import Package class PaginatorView(ui.View): @@ -20,7 +20,7 @@ def __init__( self: Self, *, member: discord.Member | discord.User, - packages: list[PackageScanResult], + packages: list[Package], per: int = 15, ) -> None: """Initialize the paginator view.""" @@ -70,7 +70,7 @@ async def interaction_check(self: Self, interaction: discord.Interaction) -> boo await interaction.response.send_message("This paginator is not for you!", ephemeral=True) return False - def _build_embed(self: Self, packages: list[PackageScanResult], page: int, total: int) -> discord.Embed: + def _build_embed(self: Self, packages: list[Package], page: int, total: int) -> discord.Embed: """Build an embed for the given packages.""" embed = discord.Embed( title="Package Audit", diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 6921478..ad9152c 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -13,7 +13,7 @@ from bot.bot import Bot from bot.constants import Channels, DragonflyConfig, Roles -from bot.dragonfly_services import DragonflyServices, PackageReport, PackageScanResult +from bot.dragonfly_services import DragonflyServices, Package, PackageReport log = getLogger(__name__) log.setLevel(logging.INFO) @@ -108,7 +108,7 @@ class ConfirmEmailReportModal(discord.ui.Modal): style=discord.TextStyle.short, ) - def __init__(self: Self, *, package: PackageScanResult, bot: Bot) -> None: + def __init__(self: Self, *, package: Package, bot: Bot) -> None: """Initialize the modal.""" self.package = package self.bot = bot @@ -164,7 +164,7 @@ class ConfirmReportModal(discord.ui.Modal): style=discord.TextStyle.short, ) - def __init__(self: Self, *, package: PackageScanResult, bot: Bot) -> None: + def __init__(self: Self, *, package: Package, bot: Bot) -> None: """Initialize the modal.""" self.package = package self.bot = bot @@ -245,7 +245,7 @@ def disable_all(self: Self) -> None: class ReportView(discord.ui.View): """Report view.""" - def __init__(self: Self, bot: Bot, payload: PackageScanResult) -> None: + def __init__(self: Self, bot: Bot, payload: Package) -> None: self.bot = bot self.payload = payload super().__init__(timeout=None) @@ -342,7 +342,7 @@ def __init__( self: Self, embed: discord.Embed, bot: Bot, - payload: PackageScanResult, + payload: Package, ) -> None: self.embed = embed self.bot = bot @@ -456,9 +456,9 @@ async def on_error( raise error -def _build_package_scan_result_embed(scan_result: PackageScanResult) -> discord.Embed: +def _build_package_scan_result_embed(scan_result: Package) -> discord.Embed: """Build the embed that shows the results of a package scan.""" - condition = scan_result.score >= DragonflyConfig.threshold + condition = (scan_result.score or 0) >= DragonflyConfig.threshold title, color = ("Malicious", 0xF70606) if condition else ("Benign", 0x4CBB17) embed = discord.Embed( @@ -483,7 +483,7 @@ def _build_package_scan_result_embed(scan_result: PackageScanResult) -> discord. def _build_package_scan_result_triage_embed( - scan_result: PackageScanResult, + scan_result: Package, ) -> discord.Embed: """Build the embed for the malware triage system.""" embed = discord.Embed( @@ -503,7 +503,7 @@ def _build_package_scan_result_triage_embed( return embed -def _build_all_packages_scanned_embed(scan_results: list[PackageScanResult]) -> discord.Embed: +def _build_all_packages_scanned_embed(scan_results: list[Package]) -> discord.Embed: """Build the embed that shows a list of all packages scanned.""" if scan_results: description = "\n".join(map(str, scan_results)) diff --git a/src/bot/exts/dragonfly/threat_intel_feed.py b/src/bot/exts/dragonfly/threat_intel_feed.py index 4f90cea..9f44d7f 100644 --- a/src/bot/exts/dragonfly/threat_intel_feed.py +++ b/src/bot/exts/dragonfly/threat_intel_feed.py @@ -14,7 +14,7 @@ from bot import constants from bot.bot import Bot -from bot.dragonfly_services import PackageScanResult +from bot.dragonfly_services import Package log = getLogger(__name__) log.setLevel(logging.INFO) @@ -53,7 +53,7 @@ def search(d: dict, key: Any) -> Any | None: # noqa: ANN401 - we can't know the return None -def build_embed(package: PackageScanResult, path: str, inspector_url: str) -> discord.Embed: +def build_embed(package: Package, path: str, inspector_url: str) -> discord.Embed: """Return the embed to be sent in the threat intelligence feed.""" if package.reported_at: ts = discord.utils.format_dt(package.reported_at, style="F")