Skip to content

Commit

Permalink
Adds task tests
Browse files Browse the repository at this point in the history
  • Loading branch information
steffencruz committed Jan 18, 2024
1 parent 5b4f737 commit a55b446
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest
from prompting.tasks import Task, QuestionAnsweringTask, SummarizationTask, DebuggingTask, MathTask, DateQuestionAnsweringTask
from prompting.mock import MockPipeline

"""
What we want to test for each task:
- The task is initialized correctly
- The task contains a query
- The task contains a reference answer
- The task contains a query_time
- The task contains a reference_time
- The task formats correctly
- All task fields are present as expected
- Tasks have reward definitions
"""


# Mock pipeline passed as ``llm_pipeline`` to the tasks built in these tests.
LLM_PIPELINE = MockPipeline("mock")

# Generic sample context (not referenced by the tests visible in this section).
CONTEXT = {
    "text": "This is a context.",
    "title": "this is a title",
}

# Task classes used to parametrize the tests.
TASKS = [
    QuestionAnsweringTask,
    SummarizationTask,
    DebuggingTask,
    MathTask,
    DateQuestionAnsweringTask,
]

# Sample context for each task class, keyed by the class itself.
CONTEXTS = {
    QuestionAnsweringTask: {
        "text": "This is a context.",
        "title": "this is a title",
        "categories": ["some", "categories"],
    },
    SummarizationTask: {
        "text": "This is a context.",
        "title": "this is a title",
        "categories": ["some", "categories"],
    },
    DebuggingTask: {
        "code": "This is code",
        "repo_name": "prompting",
        "path": "this/is/a/path",
        "language": "python",
    },
    MathTask: {
        "problem": "This is a problem",
        "solution": "3.1415",
    },
    DateQuestionAnsweringTask: {
        "section": "Events",
        "event": "1066 - Battle of Hastings in UK",
        "date": "1 January 2021",
    },
}
# TODO: MathTask only works when the solution can be converted to a float
# TODO: DateQA only accepts a section in {Births, Deaths, Events}
# TODO: DateQA expects a wiki entry for the event

@pytest.mark.parametrize('task', TASKS)
def test_create_task(task: Task):
    """Check that each task class can be constructed from its sample context.

    There is no explicit assertion: the test passes as long as the
    constructor call does not raise.
    """
    # ``task`` is one of the classes in TASKS; its sample context is looked
    # up in CONTEXTS using the class itself as the key.
    task(context=CONTEXTS[task], llm_pipeline=LLM_PIPELINE)

@pytest.mark.parametrize('task', TASKS)
def test_task_contains_query(task: Task):
    """Check that a constructed task exposes a ``query`` that is not None."""
    # Build the task from the mock pipeline and the class's sample context.
    instance = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task])
    assert instance.query is not None

@pytest.mark.parametrize('task', TASKS)
def test_task_contains_reference(task: Task):
    """Check that a constructed task exposes a ``reference`` that is not None."""
    # Build the task from the mock pipeline and the class's sample context.
    instance = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task])
    assert instance.reference is not None


# @pytest.mark.parametrize('task', TASKS)
# def test_task_contains_query_time(task: Task):
# context = CONTEXTS[task]
# task = task(llm_pipeline=LLM_PIPELINE, context=context)
# assert task.query_time>=0

# @pytest.mark.parametrize('task', TASKS)
# def test_task_contains_reference_time(task: Task):
# context = CONTEXTS[task]
# task = task(llm_pipeline=LLM_PIPELINE, context=context)
# assert task.reference_time>=0

0 comments on commit a55b446

Please sign in to comment.