Skip to content

Commit

Permalink
Update dragonfly.py
Browse files Browse the repository at this point in the history
Updated view for package alerts, adding a triage system #242 

Signed-off-by: Jayy001 <[email protected]>
  • Loading branch information
Jayy001 authored May 30, 2024
1 parent 5c2f54f commit 59dd2a8
Showing 1 changed file with 275 additions and 14 deletions.
289 changes: 275 additions & 14 deletions src/bot/exts/dragonfly/dragonfly.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Download the most recent packages from PyPI and use Dragonfly to check them for malware."""

from __future__ import annotations

import logging
from datetime import UTC, datetime, timedelta
from logging import getLogger
Expand Down Expand Up @@ -127,9 +129,13 @@ async def on_error(self: Self, interaction: discord.Interaction, error: Exceptio
f"Retry using Observation API instead?"
)
view = ReportMethodSwitchConfirmationView(previous_modal=self)
return await interaction.response.send_message(message, view=view, ephemeral=True)
return await interaction.response.send_message(
message, view=view, ephemeral=True
)

await interaction.response.send_message("An unexpected error occured.", ephemeral=True)
await interaction.response.send_message(
"An unexpected error occured.", ephemeral=True
)
raise error

async def on_submit(self: Self, interaction: discord.Interaction) -> None:
Expand All @@ -143,7 +149,11 @@ async def on_submit(self: Self, interaction: discord.Interaction) -> None:
use_email=True,
)

await handle_submit(report=report, interaction=interaction, dragonfly_services=self.bot.dragonfly_services)
await handle_submit(
report=report,
interaction=interaction,
dragonfly_services=self.bot.dragonfly_services,
)


class ConfirmReportModal(discord.ui.Modal):
Expand Down Expand Up @@ -175,14 +185,20 @@ def __init__(self: Self, *, package: PackageScanResult, bot: Bot) -> None:

super().__init__()

async def on_error(self: Self, interaction: discord.Interaction, error: Exception) -> None:
async def on_error(
self: Self, interaction: discord.Interaction, error: Exception
) -> None:
"""Handle errors that occur in the modal."""
if isinstance(error, aiohttp.ClientResponseError):
message = f"Error from upstream: {error.status}\n```{error.message}```\nRetry using email instead?"
view = ReportMethodSwitchConfirmationView(previous_modal=self)
return await interaction.response.send_message(message, view=view, ephemeral=True)
return await interaction.response.send_message(
message, view=view, ephemeral=True
)

await interaction.response.send_message("An unexpected error occured.", ephemeral=True)
await interaction.response.send_message(
"An unexpected error occured.", ephemeral=True
)
raise error

async def on_submit(self: Self, interaction: discord.Interaction) -> None:
Expand All @@ -196,7 +212,11 @@ async def on_submit(self: Self, interaction: discord.Interaction) -> None:
use_email=False,
)

await handle_submit(report=report, interaction=interaction, dragonfly_services=self.bot.dragonfly_services)
await handle_submit(
report=report,
interaction=interaction,
dragonfly_services=self.bot.dragonfly_services,
)


class ReportMethodSwitchConfirmationView(discord.ui.View):
Expand All @@ -206,14 +226,18 @@ class ReportMethodSwitchConfirmationView(discord.ui.View):
user if they want to switch to another method of sending reports.
"""

def __init__(self: Self, previous_modal: ConfirmReportModal | ConfirmEmailReportModal) -> None:
def __init__(
self: Self, previous_modal: ConfirmReportModal | ConfirmEmailReportModal
) -> None:
super().__init__()
self.previous_modal = previous_modal
self.package = previous_modal.package
self.bot = previous_modal.bot

@discord.ui.button(label="Yes", style=discord.ButtonStyle.green)
async def confirm(self: Self, interaction: discord.Interaction, _button: discord.ui.Button) -> None:
async def confirm(
self: Self, interaction: discord.Interaction, _button: discord.ui.Button
) -> None:
"""Confirm button callback."""
if isinstance(self.previous_modal, ConfirmReportModal):
modal = ConfirmEmailReportModal(package=self.package, bot=self.bot)
Expand All @@ -226,7 +250,9 @@ async def confirm(self: Self, interaction: discord.Interaction, _button: discord
await interaction.edit_original_response(view=self)

@discord.ui.button(label="No, retry the operation", style=discord.ButtonStyle.red)
async def cancel(self: Self, interaction: discord.Interaction, _button: discord.ui.Button) -> None:
async def cancel(
self: Self, interaction: discord.Interaction, _button: discord.ui.Button
) -> None:
"""Cancel button callback."""
modal = type(self.previous_modal)(package=self.package, bot=self.bot)

Expand Down Expand Up @@ -261,6 +287,201 @@ async def report(self: Self, interaction: discord.Interaction, button: discord.u
await interaction.edit_original_response(view=self)


class NoteModal(discord.ui.Modal, title="Add a note"):
"""A modal that allows users to add a note to a package"""

_interaction: discord.Interaction | None = None
note_content = discord.ui.TextInput(
label="Content",
placeholder="Enter the note content here",
min_length=1,
max_length=1000, # Don't want to overfill the embed
)

def __init__(self, embed: discord.Embed, view: discord.ui.View):
super().__init__()

self.embed = embed
self.view = view

async def on_submit(self, interaction: discord.Interaction) -> None:
if not interaction.response.is_done():
await interaction.response.defer()
self._interaction = interaction

content = f"{self.note_content.value}{interaction.user.mention}"

# We need to check what fields the embed has to determine where to add the note
# If the embed has no fields, we add the note and return
# Otherwise, we need to make sure the note is added after the event log
# This involves clearing the fields and re-adding them in the correct order
# Which is why we save the event log in a variable

match len(self.embed.fields):
case 0: # Package is awaiting triage, no notes or event log
notes = [content]
event_log = None
case 1: # Package either has notes or event log
if self.embed.fields[0].name == "Notes":
notes = [self.embed.fields[0].value, content]
else:
event_log = self.embed.fields[0].value
notes = [content]
self.embed.clear_fields()
case 2: # Package has both notes and event log
if self.embed.fields[0].name == "Notes":
notes = [self.embed.fields[0].value, content]
event_log = self.embed.fields[1].value
else:
notes = [self.embed.fields[1].value, content]
event_log = self.embed.fields[0].value
self.embed.clear_fields()

self.embed.add_field(name="Notes", value="\n".join(notes), inline=False)

if event_log:
self.embed.add_field(name="Event log", value=event_log, inline=False)

await interaction.message.edit(embed=self.embed, view=self.view)

async def on_error(
self, interaction: discord.Interaction, error: Exception
) -> None:

await interaction.response.send_message(
"An unexpected error occured.", ephemeral=True
)
raise error

@property
def interaction(self) -> discord.Interaction | None:
return self._interaction


class MalwareView(discord.ui.View):
"""View for the malware triage system"""

message: discord.Message | None = None

def __init__(
self: Self, embed: discord.Embed, bot: Bot, payload: PackageScanResult
) -> None:
self.embed = embed
self.bot = bot
self.payload = payload
self.event_log = []

super().__init__()

async def enable_button(self, button_label: str) -> None:
for button in self.children:
if button.label == button_label:
button.disabled = False

async def add_event(self, message: str) -> None:
# Much like earlier, we need to check the fields of the embed to determine where to add the event log
match len(self.embed.fields):
case 0:
pass
case 1:
if self.embed.fields[0].name == "Event log":
self.embed.clear_fields()
case 2:
if self.embed.fields[0].name == "Event log":
self.embed.clear_fields()
elif self.embed.fields[1].name == "Event log":
self.embed.remove_field(1)

self.event_log.append(
message
) # For future reference, we save the event log in a variable
self.embed.add_field(
name="Event log", value="\n".join(self.event_log), inline=False
)

async def update_status(self, status: str) -> None:
self.embed.set_footer(text=status)

def get_timestamp(
self,
) -> (
int
): # This function returns the current timestamp in Discord's timestamp format
return f"<t:{int(datetime.now().timestamp())}:R>"

@discord.ui.button(label="Report", style=discord.ButtonStyle.red)
async def report(self: Self, interaction: discord.Interaction, button: discord.ui.Button) -> None: # type: ignore[type-arg]
"""Report a package."""
modal = ConfirmReportModal(package=self.payload, bot=self.bot)
await interaction.response.send_modal(modal)

timed_out = await modal.wait()
if not timed_out:
button.disabled = True
await interaction.edit_original_response(view=self)

@discord.ui.button(
label="Report",
style=discord.ButtonStyle.red,
)
async def report(
self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView]
) -> None:
await self.enable_button("Approve")
await self.add_event(
f"Reported by {interaction.user.mention}{self.get_timestamp()}"
)
await self.update_status("Flagged as malicious")

self.embed.color = discord.Color.red()

modal = ConfirmReportModal(package=self.payload, bot=self.bot)
await interaction.response.send_modal(modal)

timed_out = await modal.wait()
if not timed_out:
button.disabled = True
await interaction.edit_original_response(view=self, embed=self.embed)

@discord.ui.button(
label="Approve",
style=discord.ButtonStyle.green,
)
async def approve(
self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView]
) -> None:
await self.enable_button("Report")
await self.add_event(
f"Approved by {interaction.user.mention}{self.get_timestamp()}"
)
await self.update_status("Flagged as benign")

button.disabled = True

self.embed.color = discord.Color.green()
await interaction.response.edit_message(view=self, embed=self.embed)

@discord.ui.button(
label="Add note",
style=discord.ButtonStyle.grey,
)
async def add_note(
self, interaction: discord.Interaction, button: discord.ui.Button[MalwareView]
) -> None:
await interaction.response.send_modal(NoteModal(embed=self.embed, view=self))

async def on_error(
self,
interaction: discord.Interaction[discord.Client],
error: Exception,
) -> None:

await interaction.response.send_message(
"An unexpected error occured.", ephemeral=True
)
raise error


def _build_package_scan_result_embed(scan_result: PackageScanResult) -> discord.Embed:
"""Build the embed that shows the results of a package scan."""
condition = scan_result.score >= DragonflyConfig.threshold
Expand All @@ -287,7 +508,31 @@ def _build_package_scan_result_embed(scan_result: PackageScanResult) -> discord.
return embed


def _build_all_packages_scanned_embed(scan_results: list[PackageScanResult]) -> discord.Embed:
def _build_package_scan_result_triage_embed(
scan_result: PackageScanResult,
) -> discord.Embed:
"""Build the embed for the malware triage system"""

embed = discord.Embed(
title="View on Inspector",
description="\n".join(scan_result.rules),
url=scan_result.inspector_url,
color=discord.Color.orange(),
timestamp=datetime.now(UTC),
)
embed.set_author(
name=f"{scan_result.name}@{scan_result.version}",
url=f"https://pypi.org/project/{scan_result.name}/{scan_result.version}",
icon_url="https://seeklogo.com/images/P/pypi-logo-5B953CE804-seeklogo.com.png",
)
embed.set_footer(text="Awaiting triage")

return embed


def _build_all_packages_scanned_embed(
scan_results: list[PackageScanResult],
) -> discord.Embed:
"""Build the embed that shows a list of all packages scanned."""
if scan_results:
description = "\n".join(map(str, scan_results))
Expand All @@ -307,12 +552,22 @@ async def run(
scan_results = await bot.dragonfly_services.get_scanned_packages(since=since)
for result in scan_results:
if result.score >= score:
"""
embed = _build_package_scan_result_embed(result)
await alerts_channel.send(
f"<@&{DragonflyConfig.alerts_role_id}>",
embed=embed,
view=ReportView(bot, result),
)
"""
embed = _build_package_scan_result_triage_embed(result)
view = MalwareView(embed=embed, bot=bot, payload=result)

view.message = await alerts_channel.send(
f"<@&{DragonflyConfig.alerts_role_id}>",
embed=embed,
view=view,
)

await logs_channel.send(embed=_build_all_packages_scanned_embed(scan_results))

Expand All @@ -330,7 +585,9 @@ def __init__(self: Self, bot: Bot) -> None:
@commands.hybrid_command(name="username") # type: ignore [arg-type]
async def get_username_command(self, ctx: commands.Context[Bot]) -> None:
"""Get the username of the currently logged in user to the PyPI Observation API."""
async with ctx.bot.http_session.get(DragonflyConfig.reporter_url + "/echo") as res:
async with ctx.bot.http_session.get(
DragonflyConfig.reporter_url + "/echo"
) as res:
json = await res.json()
username = json["username"]

Expand Down Expand Up @@ -392,12 +649,16 @@ async def stop(self: Self, ctx: commands.Context, force: bool = False) -> None:
@discord.app_commands.command(name="lookup", description="Scans a package")
async def lookup(self: Self, interaction: discord.Interaction, name: str, version: str | None = None) -> None: # type: ignore[type-arg]
"""Pull the scan results for a package."""
scan_results = await self.bot.dragonfly_services.get_scanned_packages(name=name, version=version)
scan_results = await self.bot.dragonfly_services.get_scanned_packages(
name=name, version=version
)
if scan_results:
embed = _build_package_scan_result_embed(scan_results[0])
await interaction.response.send_message(embed=embed)
else:
await interaction.response.send_message("No entries were found with the specified filters.")
await interaction.response.send_message(
"No entries were found with the specified filters."
)

@commands.group()
async def threshold(self: Self, ctx: commands.Context) -> None: # type: ignore[type-arg]
Expand Down

0 comments on commit 59dd2a8

Please sign in to comment.