diff --git a/easyllm/data/__init__.py b/easyllm/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easyllm/data/extractor/__init__.py b/easyllm/data/extractor/__init__.py new file mode 100644 index 0000000..3f25c04 --- /dev/null +++ b/easyllm/data/extractor/__init__.py @@ -0,0 +1 @@ +from easyllm.data.extractor.html_extractor import HtmlExtractor diff --git a/easyllm/data/extractor/html_extractor.py b/easyllm/data/extractor/html_extractor.py new file mode 100644 index 0000000..6a1d4a3 --- /dev/null +++ b/easyllm/data/extractor/html_extractor.py @@ -0,0 +1,23 @@ +# +from inscriptis import get_text +from inscriptis.css_profiles import CSS_PROFILES +from inscriptis.model.config import ParserConfig +from pydantic import BaseModel +from readability import Document + +INSCRIPTIS_CONFIG = ParserConfig(css=CSS_PROFILES["strict"]) + + +class HtmlExtractor(BaseModel): + """ + Desc: Extracts text from the HTML document using Mozilla's Readability and inscriptis. + """ + + name: str = "html_extractor" + min_doc_length: int = 25 + + def __call__(self, document: str) -> str: + parsed_doc = Document(document, min_text_length=self.min_doc_length) + clean_html = parsed_doc.summary(html_partial=True) + content = get_text(clean_html, INSCRIPTIS_CONFIG).strip() + return content diff --git a/easyllm/data/filters/__init__.py b/easyllm/data/filters/__init__.py new file mode 100644 index 0000000..d37d7b2 --- /dev/null +++ b/easyllm/data/filters/__init__.py @@ -0,0 +1,14 @@ +from easyllm.data.filters.bulletpoint_ratio import BulletpointRatioFilter +from easyllm.data.filters.common_word import CommonWordFilter +from easyllm.data.filters.digit_to_character import DigitToCharacter +from easyllm.data.filters.kenlm_ppl import PerplexityFilter +from easyllm.data.filters.length import LengthFilter +from easyllm.data.filters.longword import LongWordFilter +from easyllm.data.filters.n_gram import TopNGramsFilter +from easyllm.data.filters.non_alpha_numeric import NonAlphaNumericFilter +from easyllm.data.filters.parantheses_ration import ParenthesesRationFilter +from easyllm.data.filters.punctuation import EllipsisFilter, PunctuationFilter +from easyllm.data.filters.repeating import RepeatedLinesFilter, RepeatedParagraphFilter +from easyllm.data.filters.url_ratio import UrlRatioFilter +from easyllm.data.filters.whitespace_ration import WhitespaceRatioFilter +from easyllm.data.filters.words_to_symbol import SymbolToWordFilter diff --git a/easyllm/data/filters/bulletpoint_ratio.py b/easyllm/data/filters/bulletpoint_ratio.py new file mode 100644 index 0000000..566d554 --- /dev/null +++ b/easyllm/data/filters/bulletpoint_ratio.py @@ -0,0 +1,42 @@ +from typing import List + +from pydantic import BaseModel + + +class BulletpointRatioFilter(BaseModel): + """ + Ref: Gopher (Rae et al., 2021) + Desc: If more than 90% of the document are bulletpoints then remove + """ + + name: str = "bulletpoint_ratio" + potential_bullet_points: List[str] = [ + "•", + "‣", + "⁃", + "⁌", + "⁍", + "∙", + "○", + "●", + "◘", + "◦", + "⦾", + "⦿", + "-", + ] + remove_percentage: float = 0.9 + + def __call__(self, text): + # split text into lines + lines = text.split("\n") + num_bullet_points = 0 + for line in lines: + # check if the line is a bullet point + if line.startswith(tuple(self.potential_bullet_points)): + num_bullet_points += 1 + # check if the ratio of bullet points to lines is greater than the remove percentage + if num_bullet_points / len(lines) > self.remove_percentage: + return True + # otherwise keep + return False
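The `HtmlExtractor` added above is not exercised in the notebooks further down, so here is a minimal usage sketch. It is hedged: the HTML snippet and the expected output are illustrative assumptions, and `readability-lxml`, `inscriptis`, and `pydantic` need to be installed.

```python
from easyllm.data.extractor import HtmlExtractor

extractor = HtmlExtractor()

# Made-up HTML page for illustration; real inputs would be full crawled documents.
html = """
<html>
  <body>
    <nav>Home | About | Contact</nav>
    <article>
      <p>EasyLLM ships small, composable quality filters for pretraining data.</p>
    </article>
  </body>
</html>
"""

# Readability keeps the main article content, inscriptis renders it to plain text.
text = extractor(html)
print(text)  # expected: the article sentence, with navigation boilerplate stripped
```

The exact output depends on Readability's heuristics, but the call pattern (construct once, then call the instance on each raw HTML document) is the same for all extractors and filters in this PR.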
diff --git a/easyllm/data/filters/common_word.py b/easyllm/data/filters/common_word.py new file mode 100644 index 0000000..f562020 --- /dev/null +++ b/easyllm/data/filters/common_word.py @@ -0,0 +1,29 @@ +from typing import List + +from pydantic import BaseModel + +COMMON_WORDS_EN = ["the", "be", "to", "of", "and", "that", "have", "with", "this"] +COMMON_WORDS_DE = ["der", "die", "das", "er", "sein", "zu", "ist", "war", "von", "und", "haben", "mit"] + + +class CommonWordFilter(BaseModel): + """ + Ref: Gopher (Rae et al., 2021) + Desc: Makes sure that the document contains at least 2 common words, if not remove + """ + + name: str = "common_word" + common_words: List[str] = COMMON_WORDS_EN + n: int = 2 + + def __call__(self, text): + words = text.split() + common_word_counter = 0 + # count the number of common words + for word in words: + if word.lower() in self.common_words: + common_word_counter += 1 + if common_word_counter >= self.n: + return False + # otherwise remove + return True diff --git a/easyllm/data/filters/cookie_banner.py b/easyllm/data/filters/cookie_banner.py new file mode 100644 index 0000000..91ed1bb --- /dev/null +++ b/easyllm/data/filters/cookie_banner.py @@ -0,0 +1,53 @@ +import re + +from pydantic import BaseModel + +policy_substrings = [ + "terms of use", + "privacy policy", + "cookie policy", + "uses cookies", + "privacy overview", + "use of cookies", + "use cookies", + "privacy & cookies policy", + "privacy and cookies policy", + "This website uses cookies to improve your experience while you " + "navigate through the website. Out of these cookies, the cookies " + "that are categorized as necessary are stored on your browser as they " + "are essential for the working of basic functionalities of the website. " + "We also use third-party cookies that help us analyze and understand how " + "you use this website. These cookies will be stored in your browser only " + "with your consent. You also have the option to opt-out of these " + "cookies. But opting out of some of these cookies may have an effect " + "on your browsing experience.".lower(), + "Necessary cookies are absolutely essential for the website to " + "function properly. This category only includes cookies that " + "ensures basic functionalities and security features of the website. " + "These cookies do not store any personal information.".lower(), + "Any cookies that may not be particularly necessary for the website " + "to function and is used specifically to collect user personal data " + "via analytics, ads, other embedded contents are termed as non-necessary " + "cookies. It is mandatory to procure user consent prior to running these " + "cookies on your website.".lower(), + "This site uses cookies, including for analytics, personalization, and " + "advertising purposes. For more information or to change your " + "cookie settings, click here.".lower(), + "If you continue to browse this site without changing your cookie " + "settings, you agree to this use. AcceptRead More".lower(), +] + + +class CookieBannerFilter(BaseModel): + """ + Ref: C4 Raffel et al. + Desc: Removes documents if more than 40% of the document includes terms for cookies, tos, privacy policy, etc. Requires external list.
+ """ + + name: str = "cookie_banner" + regex: re.Pattern = re.compile(r"(terms of use|privacy policy|copyright|all rights reserved)", re.IGNORECASE) + remove_percentage: float = 0.4 + + def __call__(self, text): + # check if the regex matches + raise NotImplementedError("CookieBannerFilter not implemented yet") diff --git a/easyllm/data/filters/digit_to_character.py b/easyllm/data/filters/digit_to_character.py new file mode 100644 index 0000000..7916afc --- /dev/null +++ b/easyllm/data/filters/digit_to_character.py @@ -0,0 +1,22 @@ +import re + +from pydantic import BaseModel + + +class DigitToCharacter(BaseModel): + """ + Desc: If more than 20% of the document are digits then remove + """ + + name: str = "digit_to_character" + remove_percentage: float = 0.2 + + def __call__(self, text): + digits = re.findall(r"\d", text) + num_digits = len(digits) + total_chars = len(text) + # check if there are any characters in the text + if num_digits / total_chars > self.remove_percentage: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/kenlm_ppl.py b/easyllm/data/filters/kenlm_ppl.py new file mode 100644 index 0000000..8cb4e1b --- /dev/null +++ b/easyllm/data/filters/kenlm_ppl.py @@ -0,0 +1,200 @@ +import importlib.util +import re +import unicodedata +from typing import Dict + +from huggingface_hub import hf_hub_download +from pydantic import BaseModel, ConfigDict + +_kenlm = importlib.util.find_spec("kenlm") is not None +_sentencepiece = importlib.util.find_spec("sentencepiece") is not None + +if _kenlm and _sentencepiece: + import kenlm + import sentencepiece + + +class SentencePiece: + def __init__( + self, + model: str, + ): + super().__init__() + self.sp = sentencepiece.SentencePieceProcessor() + self.sp.load(str(model)) + + def do(self, text: str) -> str: + tokenized = self.sp.encode_as_pieces(text) + return " ".join(tokenized) + + +class KenlmModel: + digit_re: re.Pattern[str] = re.compile(r"\d") + unicode_punct: Dict[str, str] = { + "，": ",", + "。": ".", + "、": ",", + "„": '"', + "”": '"', + "“": '"', + "«": '"', + "»": '"', + "１": '"', + "」": '"', + "「": '"', + "《": '"', + "》": '"', + "´": "'", + "∶": ":", + "：": ":", + "？": "?", + "！": "!", + "（": "(", + "）": ")", + "；": ";", + "–": "-", + "—": " - ", + "．": ". ",
+ "～": "~", + "’": "'", + "…": "...", + "━": "-", + "〈": "<", + "〉": ">", + "【": "[", + "】": "]", + "％": "%", + "►": "-", + } + unicode_punct_re: re.Pattern = re.compile(f"[{''.join(unicode_punct.keys())}]") + non_printing_chars_re: re.Pattern = re.compile(f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]") + model: kenlm.Model = None + tokenizer: SentencePiece = None + accent: bool = False + case: bool = False + numbers: bool = True + punct: int = 1 + + def __init__( + self, + model_path: str, + tokenizer_path: str, + lower_case: bool = False, + remove_accents: bool = False, + normalize_numbers: bool = True, + punctuation: int = 1, + ): + self.model = kenlm.Model(model_path) + self.tokenizer = SentencePiece(tokenizer_path) + self.accent = remove_accents + self.case = lower_case + self.numbers = normalize_numbers + self.punct = punctuation + + @classmethod + def from_pretrained( + cls, + language_or_path: str, + ): + try: + model = hf_hub_download("philschmid/kenlm", filename=f"wikipedia/{language_or_path}.arpa.bin") + tokenizer = hf_hub_download("philschmid/kenlm", filename=f"wikipedia/{language_or_path}.sp.model") + except Exception: + raise ValueError( + f"KenLM model for {language_or_path} not found at https://huggingface.co/philschmid/kenlm. Please train your own model and upload it to the hub." + ) from None + + return cls( + model, + tokenizer, + False, + False, + True, + 1, + ) + + def pp(self, log_score, length): + return 10.0 ** (-log_score / length) + + def get_perplexity(self, doc: str, normalize_cc_net: bool = True): + if normalize_cc_net: + doc = self.normalize( + doc, + accent=self.accent, + case=self.case, + numbers=self.numbers, + punct=self.punct, + ) + # Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline + doc = self.tokenizer.do(doc) + doc_log_score, doc_length = 0, 0 + for line in doc.split("\n"): + log_score = self.model.score(line) + length = len(line.split()) + 1 + doc_log_score += log_score + doc_length += length + return round(self.pp(doc_log_score, doc_length), 1) + + def normalize( + self, + line: str, + accent: bool = True, + case: bool = True, + numbers: bool = True, + punct: int = 1, + ) -> str: + line = line.strip() + if not line: + return line + if case: + line = line.lower() + if accent: + line = self.strip_accents(line) + if numbers: + line = self.digit_re.sub("0", line) + if punct == 1: + line = self.replace_unicode_punct(line) + elif punct == 2: + line = self.remove_unicode_punct(line) + line = self.remove_non_printing_char(line) + return line + + def strip_accents(self, line: str) -> str: + """Strips accents from a piece of text.""" + nfd = unicodedata.normalize("NFD", line) + output = [c for c in nfd if unicodedata.category(c) != "Mn"] + if len(output) == len(line): + return line + return "".join(output) + + def replace_unicode_punct(self, text: str) -> str: + return "".join(self.unicode_punct.get(c, c) for c in text) + + def remove_unicode_punct(self, text: str) -> str: + """More aggressive version of replace_unicode_punct but also faster.""" + return self.unicode_punct_re.sub("", text) + + def remove_non_printing_char(self, text: str) -> str: + return self.non_printing_chars_re.sub("", text) + + +class PerplexityFilter(BaseModel): + model: KenlmModel = None + min_threshold: int = 0 + max_threshold: int = 1000 + model_config = ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, language: str, min_threshold: int = 0,
max_threshold: int = 1000): + super().__init__() + self.min_threshold = min_threshold + self.max_threshold = max_threshold + self.model = KenlmModel.from_pretrained(language) + + def __call__(self, doc: str) -> bool: + # returns True if the perplexity of the document is outside of the threshold, + # meaning smaller than min_threshold or larger than max_threshold + perplexity = self.model.get_perplexity(doc) + if perplexity < self.min_threshold or perplexity > self.max_threshold: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/length.py b/easyllm/data/filters/length.py new file mode 100644 index 0000000..51646f2 --- /dev/null +++ b/easyllm/data/filters/length.py @@ -0,0 +1,20 @@ + +from pydantic import BaseModel + + +class LengthFilter(BaseModel): + """ + Desc: Removes documents below or above a certain length of words + """ + + name: str = "length" + min_length: int = 10 + max_length: int = 1_000_000 + + def __call__(self, text): + num_words = len(text.split()) + + if num_words < self.min_length or num_words > self.max_length: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/longword.py b/easyllm/data/filters/longword.py new file mode 100644 index 0000000..f98ba38 --- /dev/null +++ b/easyllm/data/filters/longword.py @@ -0,0 +1,20 @@ + +from pydantic import BaseModel + + +class LongWordFilter(BaseModel): + """ + Ref: C4 Raffel et al. + Desc: Documents that include words with > 1000 characters are removed, e.g. js or minified files. + """ + + name: str = "long_word" + max_length: int = 1000 + + def __call__(self, text): + words = text.split() + max_len = max(len(word) for word in words) + if max_len > self.max_length: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/n_gram.py b/easyllm/data/filters/n_gram.py new file mode 100644 index 0000000..5523be3 --- /dev/null +++ b/easyllm/data/filters/n_gram.py @@ -0,0 +1,32 @@ +from collections import Counter +from itertools import chain + +from pydantic import BaseModel + + +def get_ngrams(input_list, n): + return list(zip(*[input_list[i:] for i in range(n)])) + + +class TopNGramsFilter(BaseModel): + """ + Ref: Gopher (Rae et al., 2021) + Desc: If the document shrinks by > 20% after removing top n-grams then remove + """ + + name: str = "top_n_grams" + remove_percentage: float = 0.2 + n: int = 2 + + def __call__(self, text): + words = text.split() + if len(words) <= self.n: + return True + ngrams = get_ngrams(words, self.n) + n_grams = Counter(chain(ngrams)) + most_common = n_grams.most_common(1)[0][0] + + if n_grams[most_common] / len(n_grams) > self.remove_percentage: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/non_alpha_numeric.py b/easyllm/data/filters/non_alpha_numeric.py new file mode 100644 index 0000000..81815f3 --- /dev/null +++ b/easyllm/data/filters/non_alpha_numeric.py @@ -0,0 +1,27 @@ +import re + +from pydantic import BaseModel + + +class NonAlphaNumericFilter(BaseModel): + """ + Ref: Gopher (Rae et al., 2021) + Desc: If more than 20% of the document is non-alphanumeric then remove + """ + + name: str = "non_alpha_numeric" + regex: re.Pattern = re.compile(r"[^a-zA-Z0-9\s]") + remove_percentage: float = 0.2 + + def __call__(self, text): + num_characters = len(text) + # check if there are any characters in the text + if num_characters == 0: + return True + # calculate the percentage of non-alphanumeric characters + percentage = 1 - ((num_characters - len(self.regex.findall(text))) / num_characters) + # if
the percentage is greater than the remove_percentage then remove + if percentage > self.remove_percentage: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/parantheses_ration.py b/easyllm/data/filters/parantheses_ration.py new file mode 100644 index 0000000..02c8e76 --- /dev/null +++ b/easyllm/data/filters/parantheses_ration.py @@ -0,0 +1,23 @@ +import re + +from pydantic import BaseModel + + +class ParenthesesRationFilter(BaseModel): + """ + Desc: If more than 10% of the document are Parentheses then remove + """ + + name: str = "parentheses_ratio" + regex: re.Pattern = re.compile(r"\[|\]|\(|\)|{|}|⟨|⟩") + remove_percentage: float = 0.1 + + def __call__(self, text): + # parentheses characters + parentheses_count = len(self.regex.findall(text)) + sentence_length = len(text) + # check if the ratio of parentheses to text is greater than the remove percentage + if parentheses_count / sentence_length > self.remove_percentage: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/punctuation.py b/easyllm/data/filters/punctuation.py new file mode 100644 index 0000000..17da9a2 --- /dev/null +++ b/easyllm/data/filters/punctuation.py @@ -0,0 +1,55 @@ +from typing import List + +from pydantic import BaseModel + + +class PunctuationFilter(BaseModel): + """ + Ref: C4 Raffel et al. + Desc: If less than 15% of the sentences end with a punctuation mark then remove + """ + + name: str = "punctuation" + punctuations: List[str] = [".", "!", "?"] + remove_percentage: float = 0.15 + + def __call__(self, text): + sentences = text.split("\n") + # count the number of sentences ending with a punctuation mark + punc_counter = 0 + for sentence in sentences: + for punc in self.punctuations: + if sentence.endswith(punc): + punc_counter += 1 + break + # check if the ratio of sentences not ending with a punctuation mark is greater than the remove percentage + if punc_counter / len(sentences) < self.remove_percentage: + return True + # otherwise keep + return False + + +class EllipsisFilter(BaseModel): + """ + Ref: C4 Raffel et al. 
+ Desc: If more than 30% of the sentences end with an ellipsis then remove + """ + + name: str = "ellipsis" + ellipsis: List[str] = ["...", "[...]", "…", "(...)", "[…]", "-»", "read more..", "read more"] + remove_percentage: float = 0.3 + + def __call__(self, text): + sentences = text.split("\n") + # count the number of sentences ending with an ellipsis + ellipsis_counter = 0 + for sentence in sentences: + for ellipsis in self.ellipsis: + if sentence.endswith(ellipsis): + ellipsis_counter += 1 + break + # check if the ratio of sentences ending with an ellipsis is greater than the remove percentage + if ellipsis_counter / len(sentences) > self.remove_percentage: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/repeating.py b/easyllm/data/filters/repeating.py new file mode 100644 index 0000000..a37f5ca --- /dev/null +++ b/easyllm/data/filters/repeating.py @@ -0,0 +1,51 @@ +from pydantic import BaseModel + + +class RepeatedLinesFilter(BaseModel): + """ + Ref: Gopher (Rae et al., 2021) + Desc: If the document shrinks by > 30% after removing repeated lines then remove + """ + + name: str = "repeated_lines" + remove_percentage: float = 0.3 + + def __call__(self, text): + # split the text into lines + lines = text.split("\n") + # remove empty lines + lines = [line for line in lines if line.strip()] + if len(lines) == 0: + return True + # remove repeated lines + unique_lines = list(set(lines)) + # calculate the percentage of lines removed + if len(unique_lines) / len(lines) < self.remove_percentage: + return True + # otherwise keep + return False + + +class RepeatedParagraphFilter(BaseModel): + """ + Ref: Gopher (Rae et al., 2021) + Desc: If the document shrinks by > 30% after removing repeated paragraphs then remove + """ + + name: str = "repeated_paragraph" + remove_percentage: float = 0.3 + + def __call__(self, text): + # split the text into paragraphs + paragraphes = text.split("\n\n") + # remove empty paragraphs + paragraphes = [p for p in paragraphes if p.strip()] + if len(paragraphes) == 0: + return True + # remove repeated paragraphs + unique_paragraphes = list(set(paragraphes)) + # calculate the percentage of paragraphs removed + if len(unique_paragraphes) / len(paragraphes) < self.remove_percentage: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/url_ratio.py b/easyllm/data/filters/url_ratio.py new file mode 100644 index 0000000..4571982 --- /dev/null +++ b/easyllm/data/filters/url_ratio.py @@ -0,0 +1,24 @@ +import re + +from pydantic import BaseModel + + +class UrlRatioFilter(BaseModel): + """ + Desc: If more than 20% of the document are urls then remove + """ + + name: str = "url_ratio" + regex: re.Pattern[ + str + ] = r"https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)" + remove_percentage: float = 0.2 + + def __call__(self, text): + # find all urls + urls = re.findall(self.regex, text) + # check if the ratio of urls to words is greater than the remove percentage + if len(urls) / len(text.split()) > self.remove_percentage: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/whitespace_ration.py b/easyllm/data/filters/whitespace_ration.py new file mode 100644 index 0000000..e9ff23a --- /dev/null +++ b/easyllm/data/filters/whitespace_ration.py @@ -0,0 +1,23 @@ +import re + +from pydantic import BaseModel + + +class WhitespaceRatioFilter(BaseModel): + """ + Desc: If more than 25% of the document is whitespace then remove + """ + + name: str =
"whitespace_ratio" + regex: re.Pattern = re.compile(r"\s") + remove_percentage: float = 0.25 + + def __call__(self, text): + # whitespace characters + whitespace_count = len(self.regex.findall(text)) + text_length = len(text) + # check if the ratio of whitespace to text is greater than the remove percentage + if whitespace_count / text_length > self.remove_percentage: + return True + # otherwise keep + return False diff --git a/easyllm/data/filters/words_to_symbol.py b/easyllm/data/filters/words_to_symbol.py new file mode 100644 index 0000000..7539dec --- /dev/null +++ b/easyllm/data/filters/words_to_symbol.py @@ -0,0 +1,33 @@ +import re + +from pydantic import BaseModel + + +class SymbolToWordFilter(BaseModel): + """ + Ref: Gopher (Rae et al., 2021) + Desc: If more than 10% of the document are symbols (hashes [#] or ellipsis (...)) then remove + """ + + name: str = "symbol_to_word" + regex: re.Pattern = r"(\#+|(\.{3,}))(?!\w)" + remove_percentage: float = 0.1 + + def __call__(self, text: str): + num_hashes = len(re.findall(r"\#+", text)) + num_ellipses = len(re.findall(r"\.{3,}", text)) + num_words = len(re.findall(r"\w+", text)) + + # check if there are any words in the text + if num_words == 0: + return True + + hash_ratio = num_hashes / num_words + ellipses_ratio = num_ellipses / num_words + + # if the percentage is greater than the remove_percentage then remove + if hash_ratio > self.remove_percentage or ellipses_ratio > self.remove_percentage: + return True + + # otherwise keep + return False diff --git a/notebooks/data-filter.ipynb b/notebooks/data-filter.ipynb new file mode 100644 index 0000000..ab940af --- /dev/null +++ b/notebooks/data-filter.ipynb @@ -0,0 +1,414 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to use EasyLLM Quality data filters\n", + "\n", + "EasyLLMs `data` package adds quality filters for preprocessing text data for improved pretraining. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"easyllm[data]\" --upgrade" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Perplexity filtering\n", + "\n", + "Perplexity filtering can be used to improve model quality, coherence, and training efficiency by removing confusing text segments and focusing model learning on more standard, comprehensible language.\n", + "Perplexity filtering is implemented using `KenLM` models trained on wikipedia. You just need to provide your language id, e.g. 
`de` and your perplexity `min_threshold` and `max_threshold` the filter will return `True` if the perplexity of the text outside of the threshold `False` otherwise.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "341.3\n", + "46793.5\n" + ] + } + ], + "source": [ + "from easyllm.data.filters import PerplexityFilter\n", + "\n", + "ppl = PerplexityFilter(\"en\",min_threshold=10,max_threshold=1000)\n", + "\n", + "# Get perplexity\n", + "print(ppl.model.get_perplexity(\"I am very perplexed\"))\n", + "# 341.3 (low perplexity, since sentence style is formal and with no grammar mistakes)\n", + "\n", + "print(ppl.model.get_perplexity(\"im hella trippin\"))\n", + "# 46793.5 (high perplexity, since the sentence is colloquial and contains grammar mistakes)\n", + "\n", + "# testing the filter\n", + "assert ppl(\"I am very perplexed\") == False\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NonAlphaNumericFilter\n", + "\n", + "The `NonAlphaNumericFilter` removes documents based on the number of non-alphanumeric characters in the document. Based on [Gopher (Rae et al., 2021)](https://arxiv.org/pdf/2112.11446.pdf), if the document has more then 20% non-alphanumeric characters, it is removed." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import NonAlphaNumericFilter\n", + "\n", + "nam = NonAlphaNumericFilter()\n", + "\n", + "# not filtered\n", + "assert nam(\"This is a test\") == False\n", + "\n", + "# filtered\n", + "assert nam(\"This is a test!!!!!!!\") == True\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SymbolToWordFilter\n", + "\n", + "The `SymbolToWordFilter` removes any document with a symbol-to-word ratio greater than 0.1 for either the hash symbol or the ellipsis. Based on [Gopher (Rae et al., 2021)](https://arxiv.org/pdf/2112.11446.pdf)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import SymbolToWordFilter\n", + "\n", + "stw = SymbolToWordFilter()\n", + "\n", + "assert stw(\"This is a test\") == False\n", + "\n", + "assert stw(\"spam#spam#spam#spam#spam#spam#spam#spam\") == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NumbersToCharacterFilter\n", + "\n", + "The `NumbersToCharacterFilter` removes any document where the 20% of the document are numbers." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import DigitToCharacter\n", + "\n", + "ntw = DigitToCharacter()\n", + "\n", + "assert ntw(\"Hello 123 world 456 this text 789 contains 1234 numbers more words\") == False\n", + "\n", + "assert ntw(\"Hello 34534 34534 \") == True\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## UrlRatioFilter\n", + "\n", + "The `UrlRatioFilter` removes any document where 20% of the document is a URL." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import UrlRatioFilter \n", + "\n", + "ur = UrlRatioFilter()\n", + "\n", + "assert ur(\"https://www.google.com\") == True\n", + "\n", + "assert ur(\"Example text with some urls http://www.example.com and more text https://www.example2.com and more text\") == False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## BulletpointRatioFilter \n", + "\n", + "The `BulletpointRatioFilter` removes documents that have more than 90% bulletpoints. Based on [Gopher (Rae et al., 2021)](https://arxiv.org/pdf/2112.11446.pdf)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import BulletpointRatioFilter\n", + "\n", + "br = BulletpointRatioFilter()\n", + "\n", + "assert br(\"This is a text with \\n- some bullets but\\nnot all\") == False\n", + "\n", + "assert br(\"- some bullets and\\n- some more\") == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## WhitespaceRatioFilter\n", + "\n", + "The `WhitespaceRatioFilter` is a filter that removes documents that more than 25% of the text is whitespace.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import WhitespaceRatioFilter\n", + "\n", + "wr = WhitespaceRatioFilter()\n", + "\n", + "assert wr(\"This is a test\") == False\n", + "\n", + "assert wr(\"Hello world! This text has extra whitespace.\") == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ParenthesesRationFilter\n", + "\n", + "The `ParenthesesRationFilter` is a filter that removes all sentences that have a parentheses ratio greater than 10%." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import ParenthesesRationFilter\n", + "\n", + "pr = ParenthesesRationFilter()\n", + "\n", + "assert pr(\"This is a normal sentence\") == False\n", + "\n", + "assert pr(\"This a (with ) ] {(e)\") == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LongWordFilter\n", + "\n", + "The `LongWordFilter` is a filter that removes documents that include words longer > 1000 character, e.g. js minfied files." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import LongWordFilter\n", + "\n", + "lw = LongWordFilter()\n", + "\n", + "assert lw(\"This is a test\") == False\n", + "\n", + "assert lw(f\"This is a test with a {'longword'*500}\") == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LengthFilter\n", + "\n", + "The `LengthFilter` removes documents below or above a certain number of words. Not tokens since its more expensive to compute." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import LengthFilter\n", + "\n", + "l = LengthFilter(min_length=1, max_length=100)\n", + "\n", + "assert l(\"hello world\") == False\n", + "\n", + "assert l(\"hello world \" * 100) == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RepeatedParagraphFilter, RepeatedLinesFilter\n", + "\n", + "The `RepeatedParagraphFilter` & `RepeatedLinesFilter` remove documents which have more than 30% repeated lines or paragraphs. Based on [Gopher (Rae et al., 2021)](https://arxiv.org/pdf/2112.11446.pdf) " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import RepeatedLinesFilter, RepeatedParagraphFilter\n", + "\n", + "rl = RepeatedLinesFilter()\n", + "rp = RepeatedParagraphFilter()\n", + "\n", + "assert rl(\"hello\\nworld\") == False\n", + "assert rl(\"hello\\nhello\\nhello\\nhello\") == True\n", + "\n", + "assert rp(\"hello\\n\\nworld\") == False\n", + "assert rp(\"hello\\n\\nhello\\n\\nhello\\n\\nhello\") == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TopNGramsFilter\n", + "\n", + "The `TopNGramsFilter` removes the document if the top n-gram makes more than 20% of the document." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import TopNGramsFilter\n", + "\n", + "tng = TopNGramsFilter()\n", + "\n", + "assert tng(\"This is a test for a longer sentence\") == False \n", + "\n", + "assert tng(\"The quick brown fox jumps over the lazy dog The quick brown\") == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## PunctuationFilter & EllipsisFilter\n", + "\n", + "The `PunctuationFilter` & `EllipsisFilter` removes the document if more than 15% of the \"linebreaks\" don't contain any punctuation or if more than 30% of the \"linebreaks\" contain an ellipsis." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import PunctuationFilter, EllipsisFilter\n", + "\n", + "pf = PunctuationFilter()\n", + "\n", + "assert pf(\"This is a sentence.\") == False\n", + "\n", + "assert pf(\"This is a sentence\\n But is not one.\\nNo oneyet.\") == True\n", + "\n", + "ef = EllipsisFilter()\n", + "\n", + "assert ef(\"This is a sentence.\") == False\n", + "\n", + "assert ef(\"This is a sentence\\n But is not one....\") == True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CommonWordFilter\n", + "\n", + "The `CommonWordFilter` removes documents if they don't include atleast 2 common words." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from easyllm.data.filters import CommonWordFilter\n", + "\n", + "cw = CommonWordFilter()\n", + "\n", + "assert cw(\"This is a sentence with a common word.\") == False\n", + "\n", + "assert cw(\"cat dog mouse\") == True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/datasets/filter-dataset.ipynb b/notebooks/datasets/filter-dataset.ipynb new file mode 100644 index 0000000..a55f94a --- /dev/null +++ b/notebooks/datasets/filter-dataset.ipynb @@ -0,0 +1,2316 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip uninstall easyllm -y\n", + "%pip install git+https://github.com/philschmid/easyllm.git@datafilter --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ds = load_dataset('philschmid/oscar-2301-de-minhash-dedup',split=\"train\")\n", + "# ds = load_dataset('wikipedia','20220301.de',split=\"train\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Perplexity filtering \n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nach § 80 Abs. 5 Satz 1 Halbsatz 2 VwGO kann das Gericht der Hauptsache die aufschiebende Wirkung der Klage ganz oder teilweise wiederherstellen. Ist die sofortige Vollziehung von der Behörde den formellen Anforderungen des § 80 Abs. 3 Satz 1 VwGO genügend angeordnet worden, so entscheidet das Gericht nach § 80 Abs. 5 Satz 1 Halbsatz 2 VwGO über die Wiederherstellung der aufschiebenden Wirkung der Klage auf der Grundlage einer eigenen Abwägung des Interesses des Antragstellers, von der Vollziehung des angefochtenen Verwaltungsakts bis zur endgültigen Entscheidung über seine Rechtmäßigkeit verschont zu bleiben, gegen das besondere öffentliche Interesse an dessen sofortiger Vollziehung (vgl. BVerwG, Beschl. v. 19.12.2014 - 7 VR 5.14 -, juris Rn. 9; Nds. OVG, Beschl. v. 10.09.2014 - 8 ME 87/14 -, juris Rn. 2). Im Rahmen der Interessenabwägung haben die Erfolgsaussichten des in der Hauptsache eingelegten Rechtsbehelfs eine entscheidende Bedeutung. Ergibt sich bei der im Rahmen des vorläufigen Rechtsschutzes gebotenen, aber grundsätzlich auch ausreichenden (vgl. Nds. OVG, Beschl. v. 16.8.2017 - 13 ME 173/17 -, juris Rn. 4, vgl. auch Beschl. v. 24.01.2018 - 7 ME 110/17 -, juris Rn. 28) summarischen Überprüfung, dass der Rechtsbehelf in der Hauptsache keinen Erfolg haben wird, weil sich der angegriffene Verwaltungsakt als offensichtlich rechtmäßig erweist, so überwiegt regelmäßig das öffentliche Interesse an der sofortigen Vollziehung des Verwaltungsakts. 
Erweist sich der Rechtsbehelf bei summarischer Überprüfung demgegenüber als offensichtlich erfolgreich, überwiegt regelmäßig das Interesse des Adressaten des Verwaltungsakts, von dessen Vollziehung vorerst verschont zu bleiben. Stellen sich die Erfolgsaussichten des Rechtsbehelfs hingegen als offen dar, so ist eine Abwägung der widerstreitenden Interessen erforderlich, bei der in Rechnung zu stellen ist, welche Gründe bei bestehender Unsicherheit im Hinblick auf die Erfolgsaussichten des Rechtsbehelfs für und gegen eine Aufrechterhaltung der sofortigen Vollziehung des Verwaltungsakts sprechen (vgl. Nds. OVG, Beschl. v. 10.5.2010 - 13 ME 181/09 -, juris Rn. 4). Außerdem ist zu berücksichtigen, dass die voraussichtliche Rechtmäßigkeit eines Verwaltungsakts für sich allein nur das allgemeine Interesse an seiner Vollziehung begründet, nicht aber zugleich auch deren, für die behördliche Anordnung nach § 80 Abs. 2 Satz 1 Nr. 4 VwGO erforderliche Dringlichkeit (vgl. grundlegend BVerfG, Beschl. v. 27.4.2005 - 1 BvR 223/05 -, NVwZ 2005, 1303; Beschl. v. 18.7.1973, - 1 BvR 23/73 -, BVerfGE 35, 382, 402; Nds. OVG, Beschl. v. 10.9.2014, a.a.O.; Finkelnburg/Dombert/Külpmann, Vorläufiger Rechtsschutz im Verwaltungsstreitverfahren, 7. Aufl., Rn. 757 f. m.w.N.).\n" + ] + } + ], + "source": [ + "print(ds[456][\"text\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8071b0d5472949deabe06d5600f46054", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "add url (num_proc=128): 0%| | 0/53172498 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# df = ds.to_pandas()\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Example dataframe\n", + "def plot_distribution(dfs):\n", + " # Get summary stats and quartiles\n", + " q1 = dfs['perplexity'].quantile(.05)\n", + " q2 = dfs['perplexity'].quantile(.5)\n", + " q3 = dfs['perplexity'].quantile(.95)\n", + "\n", + " # Create line chart \n", + " counts, bins = np.histogram(dfs['perplexity'], bins=30000)\n", + " bin_centers = 0.5*(bins[1:] + bins[:-1])\n", + " plt.plot(bin_centers, counts)\n", + "\n", + " # Add vertical lines for quartiles \n", + " plt.axvline(x=q1, color='r')\n", + " plt.axvline(x=q2, color='g')\n", + " plt.axvline(x=q3, color='b')\n", + "\n", + " plt.title('Perplexity Distribution')\n", + " plt.xlabel('Perplexity')\n", + " plt.ylabel('Frequency')\n", + " plt.xscale('log')\n", + "\n", + " plt.show()\n", + "\n", + "plot_distribution(df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "get some random samples from the dataset with low and high perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Low: 3.3\n", + "High: 1155.9099999999978\n" + ] + } + ], + "source": [ + "low = df.perplexity.quantile(0)\n", + "high = df.perplexity.quantile(0.9)\n", + "\n", + "print(f'Low: {low}')\n", + "print(f'High: {high}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_lowest sample:_\n", + "```\n", + "'Die Skulptur Madonna mit Kind in der katholischen Kirche St-Lucien in Angy, einer französischen Gemeinde im Département Oise in der Region Hauts-de-France, wurde im dritten Viertel des 14. Jahrhunderts geschaffen. 
Im Jahr 1912 wurde die gotische Skulptur als Monument historique in die Liste der geschützten Objekte (Base Palissy) in Frankreich aufgenommen.\\nDie 1,10 Meter hohe Skulptur aus Kalkstein ist farbig gefasst. Maria hält das Jesuskind auf dem linken Arm. Sein Gesicht wendet sich in Richtung des Betrachters. Maria, mit bäuerlichem Gesicht und roten Wangen, trägt auf ihrem Haupt eine Krone. Die vielen Falten von ihrem Kleid geben ihrer Erscheinung eine Fülle.'\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Filterstrategy: \n", + "\n", + "1. AlphanumericFilter remove sentence with more than 20% alphanumeric\n", + "2. ParenthesesRationFilter remove sentence with more than 5% parentheses\n", + "3. PunctuationFilter remove sentence with more than 15% missing punctuation\n", + "4. EllipsisFilter remove sentence with more than 30% ellipsis\n", + "5. LengthFilter: filter short documets < 5 words\n", + "6. LongWordFilter: for js stuff\n", + "7. CommonWordFilter: check if coherent sentence maybe not needed\n", + "8. RepeatedLinesFilter: remove repeated lines 30%\n", + "9. WhitespaceRatioFilter: remove sentence with more than 25% whitespace\n", + "10. UrlRatioFilter: remove sentence with more than 20% url\n", + "11. PerplexityFilter: remove sentence with perplexity > 1000\n", + "\n", + "\n", + "TODO: find law example which is super long and to filter it " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "wikipedia_filters = [\n", + " NonAlphaNumericFilter(),\n", + " LengthFilter(min_length=10),\n", + " CommonWordFilter(common_words=COMMON_WORDS_DE),\n", + " UrlRatioFilter(),\n", + " PerplexityFilter(language=\"de\",min_threshold=0,max_threshold=perplexity_threshold)\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "da7067303e814628917d39b02cab5c0e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "filter documents... (num_proc=128): 0%| | 0/53172498 [00:00 300_000:\n", + " return False\n", + " return True\n", + "\n", + "\n", + "# datasets filters keeps true elements, meaning if the filter is we want to set it to false\n", + "ds = ds.filter(apply_filters,num_proc=os.cpu_count(),\n", + " desc=\"filter documents...\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "37786e0a47ec4375a7495bc1c5ed7ff2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Saving the dataset (0/564 shards): 0%| | 0/44401239 [00:00
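For a quick end-to-end impression of the filters added in this PR, the sketch below applies a handful of them to a 🤗 Datasets corpus. It is hedged: the filter selection, the thresholds, and the `text` column are illustrative assumptions, not part of the PR.

```python
import os

from datasets import load_dataset

from easyllm.data.filters import (
    LengthFilter,
    LongWordFilter,
    NonAlphaNumericFilter,
    UrlRatioFilter,
)

# Each filter returns True when a document should be removed.
filters = [
    NonAlphaNumericFilter(),
    LengthFilter(min_length=10),
    LongWordFilter(),
    UrlRatioFilter(),
]


def keep_document(example):
    # Dataset.filter keeps rows for which the function returns True,
    # so a row survives only if no filter flags it for removal.
    return not any(f(example["text"]) for f in filters)


# Illustrative corpus with a "text" column; swap in your own dataset.
ds = load_dataset("wikipedia", "20220301.de", split="train")
ds = ds.filter(keep_document, num_proc=os.cpu_count(), desc="filter documents...")
print(ds)
```

This mirrors the `apply_filters` pattern used in `notebooks/datasets/filter-dataset.ipynb`, just with a smaller, explicitly named filter list.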