From 10a6b1c60068812a3695b5e5ed608c2d4f8f286a Mon Sep 17 00:00:00 2001 From: am9zZWY <46693545+am9zZWY@users.noreply.github.com> Date: Thu, 18 Jul 2024 18:25:20 +0200 Subject: [PATCH 1/2] Move summarization code to class --- engine/summarize.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/engine/summarize.py b/engine/summarize.py index a7df016..716260b 100644 --- a/engine/summarize.py +++ b/engine/summarize.py @@ -2,31 +2,20 @@ from pipeline import PipelineElement -# Load summarization pipeline -MODEL = "google/pegasus-xsum" -print(f"Loading summarization model {MODEL} ... This may take a few minutes.") -summarizer = pipeline("summarization", model=MODEL, tokenizer=MODEL) - - -def summarize_text(text: str, max_words: int = 15) -> str: - summary = summarizer(text, max_length=max_words * 2, min_length=max_words, do_sample=False)[0]['summary_text'] - - # Truncate to the specified number of words - words = summary.split() - if len(words) > max_words: - summary = ' '.join(words[:max_words]) + '...' - - return summary - class Summarizer(PipelineElement): """ Summarizes the input text. """ - def __init__(self): + def __init__(self, summary_model: str = "google/pegasus-xsum"): super().__init__("Summarizer") + # Load summarization pipeline + self.summary_model = summary_model + print(f"Loading summarization model {summary_model} ... This may take a few minutes.") + self.summarizer = pipeline("summarization", model=summary_model, tokenizer=summary_model) + async def process(self, data, link): """ Summarizes the input text. @@ -46,8 +35,19 @@ async def process(self, data, link): text = main_content.get_text() - summary = summarize_text(text) + summary = self._summarize_text(text) print(f"Summarized {link} to: {summary}") if not self.is_shutdown(): await self.propagate_to_next(summary) + + def _summarize_text(self, text: str, max_words: int = 15) -> str: + summary = self.summarizer(text, max_length=max_words * 2, min_length=max_words, do_sample=False)[0][ + 'summary_text'] + + # Truncate to the specified number of words + words = summary.split() + if len(words) > max_words: + summary = ' '.join(words[:max_words]) + '...' + + return summary From 0e7ae0fd32fd2bb474533e4d72808c68256ee2ef Mon Sep 17 00:00:00 2001 From: am9zZWY <46693545+am9zZWY@users.noreply.github.com> Date: Thu, 18 Jul 2024 19:50:49 +0200 Subject: [PATCH 2/2] Clean imports --- engine/index.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/engine/index.py b/engine/index.py index 2adb7f7..547ff01 100644 --- a/engine/index.py +++ b/engine/index.py @@ -1,8 +1,6 @@ -import logging import duckdb -import pandas as pd -from custom_db import upsert_page_to_index, add_title_to_index, add_snippet_to_index, load_pages +from custom_db import load_pages from pipeline import PipelineElement