reformat how GenerationCleaner is called within the scope of the project
mccrindlebrian committed Jan 18, 2024
1 parent 566527a commit 76323eb
Showing 5 changed files with 44 additions and 22 deletions.
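In short, the commit moves generation cleaning out of the Task base class (where Task.generate cleaned every generation) and into each task's __init__, which now builds its own GenerationCleaner and applies it to the generated reference only; the matching query-cleaning calls are left commented out. A minimal, self-contained sketch of that call pattern follows; the stub cleaner stands in for the project's prompting.utils.clean_generation.GenerationCleaner and is not its actual implementation.

# Illustrative sketch only: a stub standing in for the project's GenerationCleaner,
# shown to highlight the per-task cleaning pattern introduced by this commit.


class StubGenerationCleaner:
    def apply(self, generation: str, task_name: str) -> str:
        # The real apply() prunes unfinished endings and strips quotes (see
        # prompting/utils/clean_generation.py); stripping whitespace here just
        # keeps the example self-contained.
        return generation.strip()


class ExampleTask:
    """Mirrors how date_qa.py, qa.py, and summarization.py now call the cleaner."""

    def __init__(self, reference: str):
        NAME = "example task"
        self.cleaner = StubGenerationCleaner()  # each task owns its own cleaner
        # query = self.cleaner.apply(generation=query, task_name=NAME)  # queries are not cleaned yet
        self.reference = self.cleaner.apply(generation=reference, task_name=NAME)


task = ExampleTask("  14 January 2024  ")
print(task.reference)  # -> 14 January 2024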
18 changes: 14 additions & 4 deletions prompting/tasks/date_qa.py
@@ -1,23 +1,33 @@
 from dataclasses import dataclass
 from prompting.tasks import Task
+from prompting.utils.clean_generation import GenerationCleaner
 
 
 @dataclass
 class DateQuestionAnsweringTask(Task):
     reward_definition = [
-        dict(name='date', weight = 1),
+        dict(name="date", weight=1),
     ]
 
     def __init__(self, llm_pipeline, context, create_reference=True):
+        NAME = "date-based question answering"
+        self.cleaner = GenerationCleaner()
         self.context = context
         self.section = self.context["section"]
         year, _, *event = self.context["event"].split()
         self.context["event"] = " ".join(event)
-        options = {'Births':' was born ', 'Deaths':' died ', 'Events':' '}
-        query = self.context["event"].strip(".") + options[self.section] + 'on what date?'
+        options = {"Births": " was born ", "Deaths": " died ", "Events": " "}
+
+        query = (
+            self.context["event"].strip(".") + options[self.section] + "on what date?"
+        )
+        # query = self.cleaner.apply(generation=query, task_name = NAME) #Might not want to apply cleaning to query.
+
         reference = self.context["date"] + ", " + year.strip()
+        reference = self.cleaner.apply(generation=reference, task_name=NAME)
+
         super().__init__(
-            name="date-based question answering",
+            name=NAME,
             desc="get help answering a question",
             goal="to get the answer to the following question",
             query=query,
12 changes: 8 additions & 4 deletions prompting/tasks/qa.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 from prompting.tasks import Task
+from prompting.utils.clean_generation import GenerationCleaner
 
 # TODO: introduce criteria for the query and reference answer (length, layout, etc.) and make these arguments
 # TODO
@@ -46,25 +47,28 @@ class QuestionAnsweringTask(Task):
     ]
 
     def __init__(self, llm_pipeline, context, create_reference=True):
+        NAME = "question-answering"
+        self.cleaner = GenerationCleaner()
         self.context = context
 
         self.query_system_prompt = QUERY_SYSTEM_PROMPT
-        self.query_prompt = QUERY_PROMPT_TEMPLATE.format(
-            context=self.context["text"]
-        )
+        self.query_prompt = QUERY_PROMPT_TEMPLATE.format(context=self.context["text"])
+
         query = self.generate_query(llm_pipeline)
+        # query = self.cleaner.apply(generation=query, task_name=NAME) #Might not want to apply cleaning to query.
 
         self.reference_system_prompt = REFERENCE_SYSTEM_PROMPT
         self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(
             context=self.context["text"], question=query
         )
         if create_reference:
             reference = self.generate_reference(llm_pipeline)
+            reference = self.cleaner.apply(generation=reference, task_name=NAME)
         else:
             reference = None
 
         super().__init__(
-            name="question-answering",
+            name=NAME,
             desc="get help on answering a question",
             goal="to get the answer to the following question",
             query=query,
20 changes: 15 additions & 5 deletions prompting/tasks/summarization.py
@@ -1,5 +1,8 @@
 from dataclasses import dataclass
 from prompting.tasks import Task
+from transformers import Pipeline
+from prompting.utils.clean_generation import GenerationCleaner
+
 
 # TODO: introduce criteria for the query and reference answer (length, layout, etc.) and make these arguments
 
@@ -29,25 +32,32 @@ class SummarizationTask(Task):
         dict(name="relevance", threshold=None, weight=1.0),
     ]
 
-    def __init__(self, llm_pipeline, context, create_reference=True):
+    cleaner_pipeline = []
+
+    def __init__(self, llm_pipeline: Pipeline, context: str, create_reference=True):
+        NAME = "summarization"
+        self.cleaner = GenerationCleaner()
         self.context = context
 
         self.query_prompt = None
-        # NOTE: We do not perform an inference here and just use the article title as the query. This is because the article title is usually a good summary of the article itself.
-        # Query is just the article title
+        # NOTE: We do not perform an inference here and just use the article title as the query.
+        # This is because the article title is usually a good summary of the article itself.
+        # Query is just the article title.
         query = self.context["title"]
 
         self.reference_system_prompt = SUMMARIZATION_SYSTEM_PROMPT
         self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(
             context=self.context["text"]
         )
         if create_reference:
-            reference = self.generate_reference(llm_pipeline)
+            reference = self.generate_reference(llm=llm_pipeline)
+            reference = self.cleaner.apply(generation=reference, task_name=NAME)
+
         else:
             reference = None
 
         super().__init__(
-            name="summarization",
+            name=NAME,
             desc="get help with summarization",
             goal="summarize the following topic",
             query=query,
10 changes: 3 additions & 7 deletions prompting/tasks/task.py
@@ -6,7 +6,6 @@
 from typing import List, Union
 from prompting.llm import HuggingFaceLLM
 from transformers import Pipeline
-
 from prompting.utils.clean_generation import GenerationCleaner
 
 
@@ -40,9 +39,7 @@ class Task(ABC):
     reference_prompt = ""
     query_system_prompt = ""
     query_prompt = ""
-
-    def __post_init__(self):
-        self.cleaner = GenerationCleaner()
+    cleaner = GenerationCleaner()  # TODO: Remove?
 
     def __str__(self):
         return f"{self.__class__.__name__}(name={self.name!r}, desc={self.desc!r}, goal={self.goal!r}, query={self.query!r}, reference={self.reference!r}, topic={self.topic!r}, subtopic={self.subtopic!r}, tags={self.tags!r})"
@@ -73,10 +70,9 @@ def generate(self, system: str, prompt: str, llm: Pipeline) -> str:
         """Uses the llm to generate a response to a prompt"""
 
         generation = HuggingFaceLLM(llm, system_prompt=system).query(prompt)
-        generation = self.cleaner.apply(generation=generation, task_name=self.name)
         return generation
 
-    def generate_reference(self, llm) -> str:
+    def generate_reference(self, llm: Pipeline) -> str:
         """Generates a reference answer to be used for scoring miner completions"""
         t0 = time.time()
         if not self.static_reference:
@@ -91,7 +87,7 @@ def generate_reference(self, llm) -> str:
         self.reference_time = time.time() - t0
         return self.reference
 
-    def generate_query(self, llm) -> str:
+    def generate_query(self, llm: Pipeline) -> str:
         """Generates a query to be used for generating the challenge"""
         t0 = time.time()
         if not self.static_query:
6 changes: 4 additions & 2 deletions prompting/utils/clean_generation.py
Expand Up @@ -50,10 +50,12 @@ def prune_ending(self, generation: str):
and not generation.endswith("!")
):
index = max(generation.rfind(char) for char in punctuation_chars)
return generation[: index + 1] #Go to the index of where the punctuation is, and include it (+1)
return generation[
: index + 1
] # Go to the index of where the punctuation is, and include it (+1)
else:
return generation

def remove_quotes(self, generation: str):
"""Remove quotes and spaces from the generation"""
return generation.strip("\"'")
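For context on the hunk above, the two methods shown behave roughly as sketched below. This is a restatement for illustration, not the module itself; the guard against generations containing no punctuation at all is an assumption, since that part of the if-condition sits outside the hunk.

# Approximate behaviour of GenerationCleaner.prune_ending / remove_quotes,
# restated from the diff above; the full condition is assumed.
punctuation_chars = ".?!"


def prune_ending(generation: str) -> str:
    # Assumed guard: leave the text alone if there is no punctuation to prune back to.
    if any(char in generation for char in punctuation_chars) and not generation.endswith(
        (".", "?", "!")
    ):
        index = max(generation.rfind(char) for char in punctuation_chars)
        return generation[: index + 1]  # keep the final punctuation mark
    return generation


def remove_quotes(generation: str) -> str:
    """Remove surrounding quotes from the generation."""
    return generation.strip("\"'")


print(prune_ending("The answer is Paris. And the next sent"))  # -> The answer is Paris.
print(remove_quotes('"The answer is Paris."'))  # -> The answer is Paris.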
