class Package(BaseModel):
    """Pydantic model for one package record queried from the database.

    Every field without a default must be present in the validated payload;
    the ``... | None`` fields accept an explicit ``None``.
    """

    # Identity of the scanned release.
    scan_id: str
    name: str
    version: str

    # Scan outcome.
    status: ScanStatus | None
    score: int | None
    inspector_url: str | None
    rules: list[str] = []
    download_urls: list[str] = []

    # Per-stage timestamp / actor pairs (semantics presumed from the field names -- TODO confirm against the API).
    queued_at: datetime | None
    queued_by: str | None
    reported_at: datetime | None
    reported_by: str | None
    pending_at: datetime | None
    pending_by: str | None
    finished_at: datetime | None
    finished_by: str | None

    commit_hash: str | None

    def __str__(self) -> str:
        """Return the package name and version separated by a single space."""
        return " ".join((self.name, self.version))
discord.ext import commands from bot.bot import Bot -from bot.dragonfly_services import PackageScanResult +from bot.dragonfly_services import Package class PaginatorView(ui.View): @@ -20,7 +20,7 @@ def __init__( self: Self, *, member: discord.Member | discord.User, - packages: list[PackageScanResult], + packages: list[Package], per: int = 15, ) -> None: """Initialize the paginator view.""" @@ -70,7 +70,7 @@ async def interaction_check(self: Self, interaction: discord.Interaction) -> boo await interaction.response.send_message("This paginator is not for you!", ephemeral=True) return False - def _build_embed(self: Self, packages: list[PackageScanResult], page: int, total: int) -> discord.Embed: + def _build_embed(self: Self, packages: list[Package], page: int, total: int) -> discord.Embed: """Build an embed for the given packages.""" embed = discord.Embed( title="Package Audit", diff --git a/src/bot/exts/dragonfly/dragonfly.py b/src/bot/exts/dragonfly/dragonfly.py index 45ce50b..ad9152c 100644 --- a/src/bot/exts/dragonfly/dragonfly.py +++ b/src/bot/exts/dragonfly/dragonfly.py @@ -9,10 +9,11 @@ import discord import sentry_sdk from discord.ext import commands, tasks +from discord.utils import format_dt from bot.bot import Bot from bot.constants import Channels, DragonflyConfig, Roles -from bot.dragonfly_services import DragonflyServices, PackageReport, PackageScanResult +from bot.dragonfly_services import DragonflyServices, Package, PackageReport log = getLogger(__name__) log.setLevel(logging.INFO) @@ -107,7 +108,7 @@ class ConfirmEmailReportModal(discord.ui.Modal): style=discord.TextStyle.short, ) - def __init__(self: Self, *, package: PackageScanResult, bot: Bot) -> None: + def __init__(self: Self, *, package: Package, bot: Bot) -> None: """Initialize the modal.""" self.package = package self.bot = bot @@ -163,7 +164,7 @@ class ConfirmReportModal(discord.ui.Modal): style=discord.TextStyle.short, ) - def __init__(self: Self, *, package: PackageScanResult, bot: Bot) 
class NoteModal(discord.ui.Modal, title="Add a note"):
    """A modal that allows users to add a note to a package.

    The modal mutates the embed it was constructed with: submitted notes are
    appended to a "Notes" field, which is always placed before the
    "Event log" field (if any).
    """

    _interaction: discord.Interaction | None = None
    note_content = discord.ui.TextInput(
        label="Content",
        placeholder="Enter the note content here",
        min_length=1,
        max_length=1000,  # Don't want to overfill the embed
    )

    def __init__(self, embed: discord.Embed, view: discord.ui.View) -> None:
        """Remember the embed to annotate and the view to re-attach on edit."""
        super().__init__()

        self.embed = embed
        self.view = view

    async def on_submit(self, interaction: discord.Interaction) -> None:
        """Append the submitted note to the embed and edit the triage message."""
        if not interaction.response.is_done():
            await interaction.response.defer()
        self._interaction = interaction

        content = f"{self.note_content.value} • {interaction.user.mention}"

        # Look the two known fields up by name instead of by position/count.
        # The previous count-based match left `event_log` unbound when the only
        # field was "Notes" (crashing on the second note) and left `notes`
        # unbound for any embed with more than two fields.
        existing = {field.name: field.value for field in self.embed.fields}
        notes = [existing["Notes"], content] if "Notes" in existing else [content]
        event_log = existing.get("Event log")

        # Rebuild the fields so that notes always come before the event log.
        self.embed.clear_fields()
        # NOTE(review): Discord caps an embed field value at 1024 characters, so
        # accumulated notes can still exceed it; the resulting HTTP error ends
        # up in on_error below.
        self.embed.add_field(name="Notes", value="\n".join(notes), inline=False)

        if event_log:
            self.embed.add_field(name="Event log", value=event_log, inline=False)

        await interaction.message.edit(embed=self.embed, view=self.view)

    async def on_error(
        self,
        interaction: discord.Interaction,
        error: Exception,
    ) -> None:
        """Tell the user something went wrong, then re-raise the error."""
        # on_submit defers the interaction first, so by the time an error
        # surfaces the initial response is normally used up; answering through
        # `interaction.response` again would raise and hide the real error.
        if interaction.response.is_done():
            await interaction.followup.send(
                "An unexpected error occured.",
                ephemeral=True,
            )
        else:
            await interaction.response.send_message(
                "An unexpected error occured.",
                ephemeral=True,
            )
        raise error


class MalwareView(discord.ui.View):
    """View for the malware triage system.

    Holds the alert embed plus an in-memory event log, and offers three
    buttons: report the package, approve it, or attach a note.
    """

    message: discord.Message | None = None

    def __init__(
        self: Self,
        embed: discord.Embed,
        bot: Bot,
        payload: Package,
    ) -> None:
        """Store the embed, bot and package this view operates on."""
        self.embed = embed
        self.bot = bot
        self.payload = payload
        self.event_log: list[str] = []

        # No timeout: alerts are triaged long after they are posted, and the
        # library default (180 seconds) would silently disable the buttons.
        # This matches ReportView, which this view replaces for alerts.
        super().__init__(timeout=None)

    async def add_event(self, message: str) -> None:
        """Add an event to the event log and refresh the "Event log" field."""
        # Drop only the stale "Event log" field, wherever it sits, so that an
        # existing "Notes" field is never discarded along with it.
        for index, field in enumerate(self.embed.fields):
            if field.name == "Event log":
                self.embed.remove_field(index)
                break

        # The full log is kept on the view so the field can be rebuilt each time.
        self.event_log.append(message)
        self.embed.add_field(
            name="Event log",
            value="\n".join(self.event_log),
            inline=False,
        )

    async def update_status(self, status: str) -> None:
        """Update the status of the package in the embed footer."""
        self.embed.set_footer(text=status)

    def get_timestamp(self) -> str:
        """Return the current timestamp, formatted in Discord's relative style."""
        return format_dt(datetime.now(UTC), style="R")

    @discord.ui.button(
        label="Report",
        style=discord.ButtonStyle.red,
    )
    async def report(
        self,
        interaction: discord.Interaction,
        button: discord.ui.Button,
    ) -> None:
        """Ask for confirmation, then mark the package as reported in the embed."""
        modal = ConfirmReportModal(package=self.payload, bot=self.bot)
        await interaction.response.send_modal(modal)

        timed_out = await modal.wait()
        if timed_out:
            return

        # Only touch the shared embed once the modal was actually submitted;
        # mutating it beforehand left "Reported by ..." in the event log even
        # when the user never confirmed the report.
        self.approve.disabled = False
        button.disabled = True
        await self.add_event(
            f"Reported by {interaction.user.mention} • {self.get_timestamp()}",
        )
        await self.update_status("Flagged as malicious")
        self.embed.color = discord.Color.red()

        await interaction.edit_original_response(view=self, embed=self.embed)

    @discord.ui.button(
        label="Approve",
        style=discord.ButtonStyle.green,
    )
    async def approve(
        self,
        interaction: discord.Interaction,
        button: discord.ui.Button,
    ) -> None:
        """Approve package and update the embed."""
        self.report.disabled = False
        await self.add_event(
            f"Approved by {interaction.user.mention} • {self.get_timestamp()}",
        )
        await self.update_status("Flagged as benign")

        button.disabled = True

        self.embed.color = discord.Color.green()
        await interaction.response.edit_message(view=self, embed=self.embed)

    @discord.ui.button(
        label="Add note",
        style=discord.ButtonStyle.grey,
    )
    async def add_note(
        self,
        interaction: discord.Interaction,
        _button: discord.ui.Button,
    ) -> None:
        """Open the modal that adds a note to the embed."""
        await interaction.response.send_modal(NoteModal(embed=self.embed, view=self))

    async def on_error(
        self,
        interaction: discord.Interaction[discord.Client],
        error: Exception,
        _item: discord.ui.Item,
    ) -> None:
        """Tell the user something went wrong, then re-raise the error."""
        # A callback may fail after it already responded (e.g. `report` fails
        # after sending its modal); fall back to a follow-up in that case so
        # the original error is not replaced by a "already responded" one.
        if interaction.response.is_done():
            await interaction.followup.send(
                "An unexpected error occured.",
                ephemeral=True,
            )
        else:
            await interaction.response.send_message(
                "An unexpected error occured.",
                ephemeral=True,
            )
        raise error
_build_package_scan_result_triage_embed(result) + view = MalwareView(embed=embed, bot=bot, payload=result) + + view.message = await alerts_channel.send( f"<@&{DragonflyConfig.alerts_role_id}>", embed=embed, - view=ReportView(bot, result), + view=view, ) await logs_channel.send(embed=_build_all_packages_scanned_embed(scan_results)) diff --git a/src/bot/exts/dragonfly/threat_intel_feed.py b/src/bot/exts/dragonfly/threat_intel_feed.py new file mode 100644 index 0000000..9f44d7f --- /dev/null +++ b/src/bot/exts/dragonfly/threat_intel_feed.py @@ -0,0 +1,166 @@ +"""Threat Intelligence Feed Cog.""" + +import json +import logging +import re +from io import BytesIO +from logging import getLogger +from typing import Any +from zipfile import ZipFile + +import aiohttp +import discord +from discord.ext import commands, tasks + +from bot import constants +from bot.bot import Bot +from bot.dragonfly_services import Package + +log = getLogger(__name__) +log.setLevel(logging.INFO) + +_p = re.compile(r"https://inspector.pypi.io/project/(?P\w+)/(?P[\w.]+)/.*") + + +def build_github_link_from_path(path: str) -> str: + """Build a GitHub link to the given path.""" + segments = path.split("/") + path = "/".join(segments[1:]) + + return f"https://github.com/{constants.ThreatIntelFeed.repository}/blob/main/{path}" + + +def parse_package_info_from_inspector_url(inspector_url: str) -> tuple[str, str] | None: + """Return a tuple of package name and version, parsed from the inspector URL. None if it couldn't be parsed.""" + if g := _p.match(inspector_url): + name = g.group("name") + version = g.group("version") + + return name, version + + return None + + +def search(d: dict, key: Any) -> Any | None: # noqa: ANN401 - we can't know the type of the dict ahead of time + """Recursively search for the first occurence of a key in a dict. 
def build_embed(package: Package, path: str, inspector_url: str) -> discord.Embed:
    """Build the threat-intelligence-feed embed for a package we have a record of.

    The embed is green when the package already carries a ``reported_at``
    timestamp and red (listing the matched rules) when it does not. Its title
    links to the report on GitHub and a field links to the Inspector URL.
    """
    reported_at = package.reported_at
    if not reported_at:
        # No report timestamp on record: show which rules matched.
        color = discord.Colour.red()
        description = f"We didn't catch this package! Here are our matched rules: ```{', '.join(package.rules)}```"
    else:
        color = discord.Color.green()
        ts = discord.utils.format_dt(reported_at, style="F")
        description = f"We already reported this package at {ts}"

    report_link = build_github_link_from_path(path)
    embed = discord.Embed(
        title=f"New Report: {package.name} v{package.version}",
        description=description,
        color=color,
        url=report_link,
    )
    embed.add_field(name="Inspector URL", value=f"[Inspector URL]({inspector_url})")

    return embed
async def fetch_zipfile(http_client: aiohttp.ClientSession) -> ZipFile:
    """Download the source zipfile from GitHub for the feed source repository.

    Raises whatever ``raise_for_status`` raises for a non-success response.
    The caller owns the returned ``ZipFile`` and should close it.
    """
    url = f"https://api.github.com/repos/{constants.ThreatIntelFeed.repository}/zipball"

    # The token defaults to an empty string; only send the header when a token
    # is actually configured instead of sending a blank "Bearer " credential.
    token = constants.ThreatIntelFeed.access_token
    headers = {"Authorization": f"Bearer {token}"} if token else {}

    async with http_client.get(url, headers=headers) as res:
        res.raise_for_status()
        b = await res.content.read()

    return ZipFile(BytesIO(b))


class ThreatIntelFeed(commands.Cog):
    """Threat Intelligence Feed Cog.

    Periodically downloads the report repository and posts an embed for every
    JSON report that was not present on a previous run.
    """

    def __init__(self, bot: Bot) -> None:
        """Create the cog with an empty set of seen reports."""
        self.bot = bot
        # Keys of reports already handled (see _report_key).
        self.reports_seen: set[str] = set()
        # Explicit flag for "baseline recorded": an empty `reports_seen` cannot
        # tell "first run" apart from "repository had no reports yet".
        self._baseline_recorded = False

    @staticmethod
    def _report_key(path: str) -> str:
        """Return the archive path without its top-level directory.

        build_github_link_from_path discards that first segment as well, so it
        is not part of a report's identity. GitHub's source archives are
        understood to name that directory after the commit, which would make
        every report look new after each push -- TODO confirm.
        """
        return path.partition("/")[2]

    async def _build_report_embed(self, archive: ZipFile, path: str) -> discord.Embed | None:
        """Return the embed for the report at *path*, or None if it is unusable."""
        try:
            content = json.loads(archive.read(path).decode())
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            log.exception("Unable to decode report %s, skipping", path)
            return None

        # `search` walks dicts only; anything else cannot contain the key.
        inspector_url = search(content, "inspector_url") if isinstance(content, dict) else None
        if not inspector_url or not isinstance(inspector_url, str):
            log.error("Inspector URL not found in %s, skipping", path)
            return None

        package_info = parse_package_info_from_inspector_url(inspector_url)
        if package_info is None:
            log.error('Unable to parse inspector URL: "%s" in %s, skipping', inspector_url, path)
            return None

        name, version = package_info
        results = await self.bot.dragonfly_services.get_scanned_packages(name=name, version=version)
        if results:
            return build_embed(results[0], path, inspector_url)

        return build_package_not_found_embed(name, version, path)

    @tasks.loop(seconds=constants.ThreatIntelFeed.interval)
    async def watcher(self) -> None:
        """Watch the GitHub repository for changes."""
        # Resolve the channel first so a misconfiguration costs no download.
        channel = self.bot.get_channel(constants.ThreatIntelFeed.channel_id)
        if not isinstance(channel, discord.abc.Messageable):
            log.error("Threat intel feed channel is not messageable")
            return

        with await fetch_zipfile(self.bot.http_session) as archive:
            # Map the stable key of each report to its path inside this archive.
            reports = {self._report_key(path): path for path in archive.namelist() if path.endswith(".json")}

            # The first time around, just add all the reports to our "seen reports" set
            if not self._baseline_recorded:
                self.reports_seen.update(reports)
                self._baseline_recorded = True
                return

            for key, path in reports.items():
                if key in self.reports_seen:
                    continue

                embed = await self._build_report_embed(archive, path)
                if embed is not None:
                    await channel.send(embed=embed)

                # Remember the report only once it has been handled, so it is
                # neither re-posted (or re-logged) on every run nor lost if
                # sending raised above.
                self.reports_seen.add(key)

    @watcher.before_loop
    async def before_watcher(self) -> None:
        """Before first task run hook."""
        await self.bot.wait_until_ready()

    async def cog_unload(self) -> None:
        """Stop the watcher so an unloaded cog does not keep polling."""
        self.watcher.cancel()


async def setup(bot: Bot) -> None:
    """Extension setup."""
    cog = ThreatIntelFeed(bot)
    task = cog.watcher
    # `is_running` is a method: without the call the bound method itself is
    # always truthy and the watcher would never be started.
    if not task.is_running():
        task.start()
    await bot.add_cog(cog)