diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 36058317..7f86bd1f 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -43,4 +43,4 @@ jobs:
     - name: Test with pytest
       run: |
         # run tests in tests/ dir and only fail if there are failures or errors
-        pytest tests/ --verbose --failed-first --exitfirst --disable-warnings
\ No newline at end of file
+        pytest tests/ --verbose --failed-first --exitfirst --disable-warnings
diff --git a/tests/test_agent.py b/tests/test_agent.py
new file mode 100644
index 00000000..192657e2
--- /dev/null
+++ b/tests/test_agent.py
@@ -0,0 +1,82 @@
+import pytest
+from prompting.agent import Persona
+from prompting.agent import HumanAgent
+from prompting.tasks import Task, QuestionAnsweringTask, SummarizationTask, DebuggingTask, MathTask, DateQuestionAnsweringTask
+from prompting.tools import MockDataset, CodingDataset, WikiDataset, StackOverflowDataset, DateQADataset, MathDataset
+from prompting.mock import MockPipeline
+
+"""
+Things to test:
+ - Agent is initialized correctly
+ - Agent contains a persona
+ - Agent contains a task
+ - Agent can make queries
+ - Agent can make responses
+
+ - Persona is initialized correctly
+ - Persona contains a mood
+ - Persona contains a tone
+ - Persona contains a topic
+ - Persona contains a subject
+ - Persona contains a description
+ - Persona contains a goal
+ - Persona contains a query
+
+ - Task is initialized correctly
+ - Task contains a query
+ - Task contains a reference
+ - Task contains a context
+ - Task contains a complete flag
+
+
+"""
+TASKS = [
+    QuestionAnsweringTask,
+    SummarizationTask,
+    #DebuggingTask,
+    #MathTask,
+    DateQuestionAnsweringTask,
+    ]
+LLM_PIPELINE = MockPipeline("mock")
+CONTEXTS = {
+    QuestionAnsweringTask: WikiDataset().next(),
+    SummarizationTask: WikiDataset().next(),
+    DebuggingTask: CodingDataset().next(),
+    MathTask: MathDataset().next(),
+    DateQuestionAnsweringTask: DateQADataset().next(),
+}
+
+@pytest.mark.parametrize('task', TASKS)
+def test_agent_creation_with_dataset_context(task: Task):
+    context = CONTEXTS[task]
+    task = task(llm_pipeline=LLM_PIPELINE, context=context)
+    agent = HumanAgent(llm_pipeline=LLM_PIPELINE, task=task, begin_conversation=True)
+    assert agent is not None
+
+@pytest.mark.parametrize('task', TASKS)
+def test_agent_contains_persona(task: Task):
+    context = CONTEXTS[task]
+    task = task(llm_pipeline=LLM_PIPELINE, context=context)
+    agent = HumanAgent(llm_pipeline=LLM_PIPELINE, task=task, begin_conversation=True)
+    assert agent.persona is not None
+
+@pytest.mark.parametrize('task', TASKS)
+def test_agent_contains_task(task: Task):
+    context = CONTEXTS[task]
+    task = task(llm_pipeline=LLM_PIPELINE, context=context)
+    agent = HumanAgent(llm_pipeline=LLM_PIPELINE, task=task, begin_conversation=True)
+    assert agent.task is not None
+
+@pytest.mark.parametrize('task', TASKS)
+def test_agent_can_make_queries(task: Task):
+    context = CONTEXTS[task]
+    task = task(llm_pipeline=LLM_PIPELINE, context=context)
+    agent = HumanAgent(llm_pipeline=LLM_PIPELINE, task=task, begin_conversation=True)
+    assert agent.query is not None
+
+@pytest.mark.parametrize('task', TASKS)
+def test_agent_can_make_challenges(task: Task):
+    context = CONTEXTS[task]
+    task = task(llm_pipeline=LLM_PIPELINE, context=context)
+    agent = HumanAgent(llm_pipeline=LLM_PIPELINE, task=task)
+    assert agent.challenge is not None
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 00000000..8e1802e2
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,27 @@
+import pytest
+
+from prompting.tools import MockDataset, CodingDataset, WikiDataset, StackOverflowDataset, DateQADataset, MathDataset
+
+
+
+
+DATASETS = [
+    MockDataset,
+    CodingDataset,
+    WikiDataset,
+    StackOverflowDataset,
+    DateQADataset,
+    MathDataset,
+]
+
+
+@pytest.mark.parametrize('dataset', DATASETS)
+def test_create_dataset(dataset):
+    data = dataset()
+    assert data is not None
+
+
+@pytest.mark.parametrize('dataset', DATASETS)
+def test_dataset_next_returns_data(dataset):
+    data = dataset()
+    assert data.next() is not None
\ No newline at end of file
diff --git a/tests/test_dataset_task_integration.py b/tests/test_dataset_task_integration.py
new file mode 100644
index 00000000..7df3c569
--- /dev/null
+++ b/tests/test_dataset_task_integration.py
@@ -0,0 +1,50 @@
+import pytest
+from prompting.tasks import Task, QuestionAnsweringTask, SummarizationTask, DebuggingTask, MathTask, DateQuestionAnsweringTask
+from prompting.tools import MockDataset, CodingDataset, WikiDataset, StackOverflowDataset, DateQADataset, MathDataset
+from prompting.mock import MockPipeline
+
+
+"""
+What we want:
+
+- The task is initialized correctly using dataset
+- The task contains a query using dataset
+- The task contains a reference answer using dataset
+"""
+
+
+TASKS = [
+    QuestionAnsweringTask,
+    SummarizationTask,
+    #DebuggingTask,
+    #MathTask,
+    DateQuestionAnsweringTask,
+    ]
+CONTEXTS = {
+    QuestionAnsweringTask: WikiDataset().next(),
+    SummarizationTask: WikiDataset().next(),
+    DebuggingTask: CodingDataset().next(),
+    MathTask: MathDataset().next(),
+    DateQuestionAnsweringTask: DateQADataset().next(),
+}
+
+LLM_PIPELINE = MockPipeline("mock")
+
+@pytest.mark.parametrize('task', TASKS)
+def test_task_creation_with_dataset_context(task: Task):
+    context = CONTEXTS[task]
+    task = task(llm_pipeline=LLM_PIPELINE, context=context)
+    assert task is not None
+
+@pytest.mark.parametrize('task', TASKS)
+def test_task_contains_query(task: Task):
+    context = CONTEXTS[task]
+    task = task(llm_pipeline=LLM_PIPELINE, context=context)
+    assert task.query is not None
+
+@pytest.mark.parametrize('task', TASKS)
+def test_task_contains_reference(task: Task):
+    context = CONTEXTS[task]
+    task = task(llm_pipeline=LLM_PIPELINE, context=context)
+    assert task.reference is not None
+
diff --git a/tests/test_persona.py b/tests/test_persona.py
new file mode 100644
index 00000000..4f3097a5
--- /dev/null
+++ b/tests/test_persona.py
@@ -0,0 +1,14 @@
+import pytest
+from prompting.persona import Persona, create_persona
+
+def test_persona_initialization():
+    assert create_persona() is not None
+
+def test_persona_contains_mood():
+    assert create_persona().mood is not None
+
+def test_persona_contains_tone():
+    assert create_persona().tone is not None
+
+def test_persona_contains_profile():
+    assert create_persona().profile is not None
\ No newline at end of file
diff --git a/tests/test_tasks.py b/tests/test_tasks.py
new file mode 100644
index 00000000..1a5df0ea
--- /dev/null
+++ b/tests/test_tasks.py
@@ -0,0 +1,85 @@
+import pytest
+from prompting.tasks import Task, QuestionAnsweringTask, SummarizationTask, DebuggingTask, MathTask, DateQuestionAnsweringTask
+from prompting.mock import MockPipeline
+
+"""
+What we want to test for each task:
+- The task is initialized correctly
+- The task contains a query
+- The task contains a reference answer
+- Task contains a query_time
+- Task contains a reference_time
+- The task formats correctly
+- All task fields are present as expected
expected +- Tasks have reward definitions +""" + + +LLM_PIPELINE = MockPipeline("mock") +CONTEXT = {"text": "This is a context.", "title": "this is a title"} + +TASKS = [ + QuestionAnsweringTask, + SummarizationTask, + DebuggingTask, + MathTask, + DateQuestionAnsweringTask, + ] +CONTEXTS = { + QuestionAnsweringTask: {"text": "This is a context.", "title": "this is a title", "categories": ['some','categories']}, + SummarizationTask: {"text": "This is a context.", "title": "this is a title", "categories": ['some','categories']}, + DebuggingTask: {"code": "This is code","repo_name":'prompting',"path":'this/is/a/path', "language":'python'}, + MathTask: {"problem": "This is a problem","solution":'3.1415'}, + DateQuestionAnsweringTask: {"section": "Events", "event":"1953 - Battle of Hastings in UK", 'date':"1 January"}, +} + +# TODO: Math task only works when solution is floatable +# TODO: DateQA only accepts section in {Births, Deaths, Events} +# TODO: DateQA expect wiki entry for event + +@pytest.mark.parametrize('task', TASKS) +def test_create_task(task: Task): + context = CONTEXTS[task] + task(llm_pipeline=LLM_PIPELINE, context=context) + +@pytest.mark.parametrize('task', TASKS) +def test_task_contains_query(task: Task): + context = CONTEXTS[task] + task = task(llm_pipeline=LLM_PIPELINE, context=context) + assert task.query is not None + +@pytest.mark.parametrize('task', TASKS) +def test_task_contains_reference(task: Task): + context = CONTEXTS[task] + task = task(llm_pipeline=LLM_PIPELINE, context=context) + assert task.reference is not None + +# @pytest.mark.parametrize('task', TASKS) +# def test_task_contains_reward_definition(task: Task): +# context = CONTEXTS[task] +# task = task(llm_pipeline=LLM_PIPELINE, context=context) +# assert task.reward_definition is not None + +# @pytest.mark.parametrize('task', TASKS) +# def test_task_contains_goal(task: Task): +# context = CONTEXTS[task] +# task = task(llm_pipeline=LLM_PIPELINE, context=context) +# assert task.goal is not None + +# @pytest.mark.parametrize('task', TASKS) +# def test_task_contains_desc(task: Task): +# context = CONTEXTS[task] +# task = task(llm_pipeline=LLM_PIPELINE, context=context) +# assert task.desc is not None + +# @pytest.mark.parametrize('task', TASKS) +# def test_task_contains_query_time(task: Task): +# context = CONTEXTS[task] +# task = task(llm_pipeline=LLM_PIPELINE, context=context) +# assert task.reference_time>=0 + +# @pytest.mark.parametrize('task', TASKS) +# def test_task_contains_reference_time(task: Task): +# context = CONTEXTS[task] +# task = task(llm_pipeline=LLM_PIPELINE, context=context) +# assert task.query_time>=0