diff --git a/appservice/main.py b/appservice/main.py index 4bfd5a2..0b40183 100644 --- a/appservice/main.py +++ b/appservice/main.py @@ -6,7 +6,7 @@ import sys import threading import urllib.parse -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import markdown import urllib3 @@ -18,6 +18,7 @@ from db import DataBase from errors import RequestError from gateway import Gateway +from message_parser import MatrixParser, escape_markdown from misc import dict_cls, except_deleted, hash_str @@ -86,6 +87,17 @@ def on_member(self, event: matrix.Event) -> None: self.logger.info(f"Joining direct message room '{event.room_id}'.") self.join_room(event.room_id) + def append_replied_to_msg(self, message: matrix.Event) -> str: + def escape_urls(message_: str): + return re.sub(r"(?>", message_) + if message.reply and message.reply.get("event_id"): + replied_to_body: Optional[matrix.Event] = except_deleted(self.get_event)(message.reply["event_id"], message.room_id) + if replied_to_body and not replied_to_body.redacted_because: + return "> " + escape_urls(self.parse_message(replied_to_body, limit=600, generate_link=False).replace("\n", "\n> ").strip()) + "\n" + else: + return "> šŸ—‘ļøšŸ’¬\n" # I really don't want to add translatable strings to this project + return "" + def on_message(self, message: matrix.Event) -> None: if ( message.sender.startswith((f"@{self.format}", self.user_id)) @@ -107,37 +119,79 @@ def on_message(self, message: matrix.Event) -> None: channel_id, self.discord.webhook_name ) + # Let's take a few scenarios that can happen. 
We should handle at least 2 special message cases: replies and edits + # Replies should ask for replied to message event, parse that event, limit output to maybe like 500 characters + # and prepend it to main message in form of a quote, we can't just use Discord's reply because Discord being dumb + # https://github.com/discord/discord-api-docs/discussions/3282 + # Edits should look at previously edited message, if it was a reply they need to handle all that reply logic again + # However edits lose replied to field so we have to fetch original message (wooho?) and get it from that instead + + content = "" + if message.relates_to and message.reltype == "m.replace": with Cache.lock: message_id = Cache.cache["m_messages"].get(message.relates_to) - # TODO validate if the original author sent the edit. + original_message: Optional[matrix.Event] = except_deleted(self.get_event)(message.relates_to, message.room_id) if not message_id or not message.new_body: return - message.new_body = self.process_message(message) + if original_message: + if message.sender != original_message.sender: + return + content += self.append_replied_to_msg(original_message) + # If new body has formatted form, use that + message.body = message.new_body.get("body", "") + message.formatted_body = message.new_body.get("formatted_body", "") + content += self.parse_message(message) except_deleted(self.discord.edit_webhook)( - message.new_body, message_id, webhook + content[:discord.MESSAGE_LIMIT], message_id, webhook ) else: - message.body = ( + content += self.append_replied_to_msg(message) + content += ( f"`{message.body}`: {self.mxc_url(message.attachment)}" if message.attachment - else self.process_message(message) + else self.parse_message(message) ) - + if not content or content.isspace(): + return message_id = self.discord.send_webhook( webhook, self.mxc_url(author.avatar_url) if author.avatar_url else None, - message.body, + content[:discord.MESSAGE_LIMIT], author.display_name if 
author.display_name else message.sender, ).id with Cache.lock: Cache.cache["m_messages"][message.id] = message_id + @staticmethod + def create_msg_link(room_id: str, event: str) -> str: + return f"[[ā€¦]]()" + + def parse_message(self, message: matrix.Event, limit: int = discord.MESSAGE_LIMIT, generate_link: bool = True): + if message.formatted_body: + msg_link = self.create_msg_link(message.room_id, message.id) if generate_link else "" + parser = MatrixParser(self.db, self.mention_regex(False, True), self.mxc_url, limit=limit-len(msg_link)) + try: + parser.feed(message.formatted_body) + except StopIteration: + self.logger.debug("Message has exceeded maximum allowed character limit, processing what we already have") + message.body = parser.message + # Create a link to message for Discord side to spread the word about Matrix superior character limit + if generate_link: + message.body += msg_link + else: + message.body = parser.message + else: + # if we escape : in protocol prefix of a link is going to be plaintext on Discord, we don't want that + # but we still have to escape : for emojis so this is a measure for that + message.body = escape_markdown(message.body).replace("\\://", "://") + return message.body + def on_redaction(self, event: matrix.Event) -> None: with Cache.lock: message_id = Cache.cache["m_messages"].get(event.redacts) @@ -345,53 +399,6 @@ def mention_regex(self, encode: bool, id_as_group: bool) -> str: return f"{mention}{self.format}{snowflake}{hashed}{colon}{re.escape(self.server_name)}" - def process_message(self, event: matrix.Event) -> str: - message = event.new_body if event.new_body else event.body - - emotes = re.findall(r":(\w*):", message) - - mentions = list( - re.finditer( - self.mention_regex(encode=False, id_as_group=True), - event.formatted_body, - ) - ) - # For clients that properly encode mentions. 
- # 'https://matrix.to/#/%40_discord_...%3Adomain.tld' - mentions.extend( - re.finditer( - self.mention_regex(encode=True, id_as_group=True), - event.formatted_body, - ) - ) - - with Cache.lock: - for emote in set(emotes): - emote_ = Cache.cache["d_emotes"].get(emote) - if emote_: - message = message.replace(f":{emote}:", emote_) - - for mention in set(mentions): - # Unquote just in-case we matched an encoded username. - username = self.db.fetch_user( - urllib.parse.unquote(mention.group(0)) - ).get("username") - if username: - if mention.group(2): - # Replace mention with plain text for hashed users (webhooks) - message = message.replace(mention.group(0), f"@{username}") - else: - # Replace the 'mention' so that the user is tagged - # in the case of replies aswell. - # '> <@_discord_1234:localhost> Message' - for replace in (mention.group(0), username): - message = message.replace( - replace, f"<@{mention.group(1)}>" - ) - - # We trim the message later as emotes take up extra characters too. - return message[: discord.MESSAGE_LIMIT] - def upload_emote(self, emote_name: str, emote_id: str) -> None: # There won't be a race condition here, since only a unique # set of emotes are uploaded at a time. 
diff --git a/appservice/matrix.py b/appservice/matrix.py
index ab10124..43dd903 100644
--- a/appservice/matrix.py
+++ b/appservice/matrix.py
@@ -20,9 +20,10 @@ def __init__(self, event: dict):
         self.room_id = event["room_id"]
         self.sender = event["sender"]
         self.state_key = event.get("state_key", "")
-
+        self.redacted_because = event.get("redacted_because", {})
         rel = content.get("m.relates_to", {})
 
         self.relates_to = rel.get("event_id")
         self.reltype = rel.get("rel_type")
-        self.new_body = content.get("m.new_content", {}).get("body", "")
+        self.reply: dict = rel.get("m.in_reply_to")
+        self.new_body = content.get("m.new_content", {})
diff --git a/appservice/message_parser.py b/appservice/message_parser.py
new file mode 100644
index 0000000..0bbdb64
--- /dev/null
+++ b/appservice/message_parser.py
@@ -0,0 +1,249 @@
+import re
+import logging
+from html.parser import HTMLParser
+from typing import Optional, Tuple, List, Callable
+
+from db import DataBase
+from cache import Cache
+
+# Inline HTML tags mapped to the markdown marker written when the tag opens and again when it closes.
+htmltomarkdown = {"strong": "**", "ins": "__", "u": "__", "b": "**", "em": "*", "i": "*", "del": "~~", "strike": "~~", "s": "~~"}
+# Marker written when a heading opens; handle_endtag writes the same string reversed.
+headers = {"h1": "***__", "h2": "**__", "h3": "**", "h4": "__", "h5": "*", "h6": ""}
+
+logger = logging.getLogger("message_parser")
+
+
+def search_attr(attrs: List[Tuple[str, Optional[str]]], searched: str) -> Optional[str]:
+    """Return the value of attribute `searched` ("" when it has no value), or None when it is absent."""
+    for attr in attrs:
+        if attr[0] == searched:
+            return attr[1] or ""
+    return None
+
+
+def escape_markdown(to_escape: str) -> str:
+    """Return `to_escape` with backslashes and the markdown control characters backslash-escaped."""
+    # str.replace returns a new string; without keeping it, literal backslashes were never escaped.
+    to_escape = to_escape.replace("\\", "\\\\")
+    return re.sub(r"([`_*~:<>{}@|(])", r"\\\1", to_escape)
+
+
+class Tags(object):
+    """Stack of open tags plus a running total of the marker lengths counted for them."""
+
+    def __init__(self):
+        self.c_tags = []
+        self.length = 0
+
+    @staticmethod
+    def _gauge_length(tag: str) -> int:
+        """Return the number of characters counted for `tag` (0 for tags without a marker)."""
+        if tag in htmltomarkdown:
+            return len(htmltomarkdown.get(tag))
+        elif tag == "spoiler":
+            return 2
+        elif tag == "pre":
+            return 3
+        elif tag == "code":
+            return 1
+        return 0
+
+    def append(self, tag: str):
+        self.c_tags.append(tag)
+        self.length += self._gauge_length(tag)
+
+    def pop(self) -> Optional[str]:
+        """Remove and return the most recently pushed tag, or None when the stack is empty."""
+        try:
+            last_tag = self.c_tags.pop()
+            self.length -= self._gauge_length(last_tag)
+            return last_tag
+        except IndexError:
+            return None
+
+    def get_last(self) -> Optional[str]:
+        """Return the most recently pushed tag without removing it, or None when the stack is empty."""
+        try:
+            return self.c_tags[-1]
+        except IndexError:
+            return None
+
+    def get_size(self) -> int:
+        return self.length
+
+    def __reversed__(self):
+        return iter(self.c_tags[::-1])
+
+    def __len__(self):
+        return len(self.c_tags)
+
+    def __iter__(self):
+        return iter(self.c_tags)
+
+    def __bool__(self):
+        return bool(self.c_tags)
+
+
+class MatrixParser(HTMLParser):
+    """HTMLParser subclass that builds a markdown string in `self.message` from the HTML fed to it."""
+
+    def __init__(self, db: DataBase, mention_regex: str, mxc_img: Callable, limit: int = 0):
+        super().__init__()
+        self.message: str = ""
+        # "" = no pending link; None = parse_user already emitted a mention, so handle_data skips the link text.
+        self.current_link: Optional[str] = ""
+        self.tags: Tags = Tags()
+        self.list_num: int = 1
+        self.db: DataBase = db
+        self.snowflake_regex: str = mention_regex
+        self.limit: int = limit
+        self.overflow: bool = False
+        self.mxc_to_img: Callable = mxc_img
+
+    def search_for_feature(self, acceptable_features: Tuple[str, ...]) -> Optional[str]:
+        """Searches for certain feature in opened HTML tags for given text, if found returns the tag, if not returns None"""
+        for tag in reversed(self.tags):
+            if tag in acceptable_features:
+                return tag
+        return None
+
+    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]):
+        """Push `tag` on the stack and write its opening markdown; everything inside mx-reply is ignored."""
+        if "mx-reply" in self.tags:
+            return
+        self.tags.append(tag)
+
+        if tag in htmltomarkdown:
+            self.expand_message(htmltomarkdown[tag])
+        elif tag == "code":
+            if self.search_for_feature(("pre",)):
+                # NOTE(review): [9:] presumably strips a "language-" prefix from the class value - confirm.
+                self.expand_message("```" + (search_attr(attrs, "class") or "")[9:] + "\n")
+            else:
+                self.expand_message("`")
+        elif tag == "span":
+            spoiler = search_attr(attrs, "data-mx-spoiler")
+            if spoiler is not None:
+                if spoiler:  # Spoilers can have a reason https://github.com/matrix-org/matrix-doc/pull/2010
+                    self.expand_message(f"({spoiler})")
+                self.expand_message("||")
+                self.tags.append("spoiler")  # Always after span tag
+        elif tag == "li":
+            list_type = self.search_for_feature(("ul", "ol"))
+            if list_type == "ol":
+                self.expand_message("\n{}. ".format(self.list_num))
+                self.list_num += 1
+            else:
+                self.expand_message("\n• ")
+        elif tag in ("br", "p"):
+            if not self.message.endswith('\n'):
+                self.expand_message("\n")
+            if self.search_for_feature(("blockquote",)):
+                self.expand_message("> ")
+        elif tag == "a":
+            self.parse_mentions(attrs)
+        elif tag == "mx-reply":  # we handle replies separately for best effect
+            return
+        elif tag == "img":
+            if search_attr(attrs, "data-mx-emoticon") is not None:
+                emote_name = search_attr(attrs, "title")
+                if emote_name is None:
+                    return
+                emote_ = Cache.cache["d_emotes"].get(emote_name.strip(":"))
+                if emote_:
+                    self.expand_message(emote_)
+                else:
+                    self.expand_message(emote_name)
+            else:
+                image_link = search_attr(attrs, "src")
+                if image_link and image_link.startswith("mxc://"):
+                    self.expand_message(f"[{search_attr(attrs, 'title') or image_link}]({self.mxc_to_img(image_link)})")
+        elif tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
+            if not self.message.endswith('\n'):
+                self.expand_message("\n")
+            self.expand_message(headers[tag])
+        elif tag == "hr":
+            self.expand_message("\n──────────\n")
+            # NOTE(review): only "hr" is popped here; "br"/"img" written without "/>" stay on the stack - confirm intended.
+            self.tags.pop()
+
+    def parse_mentions(self, attrs):
+        """Remember the href of an <a> tag and, for matrix.to user links, write the Discord mention."""
+        # An <a> without href yields None from search_attr: .startswith would raise, and a None
+        # current_link makes handle_data drop the anchor text, so fall back to "" (no pending link).
+        self.current_link = search_attr(attrs, "href") or ""
+        if self.current_link.startswith("https://matrix.to/#/"):
+            target = self.current_link[20:]
+            if target.startswith("@"):
+                self.expand_message(self.parse_user(target.split("?")[0]))
+        # Rooms will be handled by handle_data on data
+
+    def parse_user(self, target: str) -> str:
+        """Return "<@snowflake>" when `target` is a known Discord user matching the mention regex, else ""."""
+        if self.is_discord_user(target):
+            # re.search returns None when the pattern does not match; guard before reading the group.
+            match = re.search(self.snowflake_regex, target)
+            snowflake = match.group(1) if match else None
+            if snowflake:
+                self.current_link = None  # Meaning, skip adding text
+                return f"<@{snowflake}>"
+        # Matrix user, not Discord appservice account (or no snowflake found): always hand back a str
+        return ""
+
+    def close_tags(self):
+        """Call handle_endtag for every tag still on the stack, most recently opened first."""
+        for tag in reversed(self.tags):
+            self.handle_endtag(tag)
+
+    def expand_message(self, expansion: str):
+        """Append `expansion`; on the first overflow of `self.limit` close open tags and raise StopIteration."""
+        # This calculation is not ideal. self.limit is further restricted by message link length, so if a message
+        # doesn't really go over the limit but message + message link does it will still treat it as out of the limit
+        if len(self.message) + self.tags.get_size() + len(expansion) > self.limit and self.overflow is False:
+            # Lets close all of the tags to make sure we don't have display errors
+            self.overflow = True
+            self.close_tags()
+            raise StopIteration
+        self.message += expansion
+
+    def is_discord_user(self, target: str) -> bool:
+        return bool(self.db.fetch_user(target))
+
+    def handle_data(self, data):
+        """Write text content: as link text when a link is pending, dropped after a mention, else as is."""
+        if self.tags:
+            if self.tags.get_last() != "code":
+                data = escape_markdown(data.replace("\n", ""))
+        if "mx-reply" in self.tags:
+            return
+        if self.current_link:
+            self.expand_message(f"[{data}](<{self.current_link}>)")
+            self.current_link = ""
+        elif self.current_link is None:
+            self.current_link = ""
+        else:
+            self.expand_message(data)  # strip new lines, they will be mostly handled by parser
+
+    def handle_endtag(self, tag: str):
+        """Write the closing markdown for `tag` and pop the top of the stack (whatever tag that is)."""
+        if "mx-reply" in self.tags and tag != "mx-reply":
+            return
+        if tag in htmltomarkdown:
+            self.expand_message(htmltomarkdown[tag])
+        last_tag = self.tags.pop()
+        if last_tag is None:
+            logger.error("tried to pop {} from message tags but list is empty, current message {}".format(tag, self.message))
+            return
+        if last_tag == "spoiler":
+            self.expand_message("||")
+            self.tags.pop()  # guaranteed to be a span tag
+        if tag == "ol":
+            self.list_num = 1
+        elif tag == "code":
+            if self.search_for_feature(("pre",)):
+                self.expand_message("\n```")
+            else:
+                self.expand_message("`")
+        elif tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
+            self.expand_message(headers[tag][::-1])