feat(extra): prompt tuning #79

Merged · 6 commits · Aug 9, 2024
4 changes: 2 additions & 2 deletions benchmarks/README.md
@@ -4,10 +4,10 @@ This folder contains scripts that produce reproducible timings and evaluation metrics

## Setup environment

-Before installing any package, make sure you have Python 3.8 or higher installed on your machine. From the root directory of the project, install the dependencies:
+Before installing any package, make sure you have Python 3.9 or higher installed on your machine. From the root directory of the project, install the dependencies:

```bash
-pip install -e '.[benchmarks]'
+pip install -e '.[dev]'
```

## Benchmark list
13 changes: 13 additions & 0 deletions extra/README.md
@@ -0,0 +1,13 @@
# Extra

This folder contains scripts for research related to dbally. Links are provided where descriptions exist:

- [`Prompt tuning`](prompt_tuning/README.md)

## Setup environment

Before installing any package, make sure you have Python 3.9 or higher installed on your machine. From the root directory of the project, install the dependencies:

```bash
pip install -e '.[dev]'
```
42 changes: 42 additions & 0 deletions extra/prompt_tuning/README.md
@@ -0,0 +1,42 @@
# Prompt tuning

This folder contains scripts for prompt tuning and evaluation. Prompts (programs) used in dbally:

- `FILTERING_ASSESSOR` - assesses whether a question requires filtering.

All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database.

## Usage

Run evaluation of the filtering assessor baseline on the `superhero` database with `gpt-3.5-turbo`:

```bash
python evaluate.py program=filtering-assessor-baseline
```

Test multiple programs:

```bash
python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
```

Compare prompt performance on multiple LLMs:

```bash
python evaluate.py --multirun program=filtering-assessor-baseline llm=gpt-3.5-turbo,claude-3.5-sonnet
```
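
Any other value in the Hydra config tree can be overridden the same way. As a sketch, restricting the benchmark to a subset of difficulties (the `data.difficulties` key mirrors `config/data/superhero.yaml`; this particular override is an illustration, not a documented workflow):

```bash
python evaluate.py program=filtering-assessor-baseline 'data.difficulties=[simple,moderate]'
```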

### Log to Neptune

Before running the evaluation with Neptune, configure the following environment variables:

```bash
export NEPTUNE_API_TOKEN="API_TOKEN"
export NEPTUNE_PROJECT="WORKSPACE_NAME/PROJECT_NAME"
```

Export evaluation results to Neptune:

```bash
python evaluate.py neptune=True
```
7 changes: 7 additions & 0 deletions extra/prompt_tuning/config/config.yaml
@@ -0,0 +1,7 @@
defaults:
- data: superhero
- llm: gpt-3.5-turbo
- program: filtering-assessor-baseline
- _self_

neptune: False
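
Hydra composes these defaults with any command-line overrides. A quick way to inspect the final composed config without running the evaluation is the standard Hydra flag below (generic Hydra behavior, not something added by this PR):

```bash
python evaluate.py --cfg job
```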
4 changes: 4 additions & 0 deletions extra/prompt_tuning/config/data/superhero.yaml
@@ -0,0 +1,4 @@
path: "micpst/bird-iql"
split: "dev"
db_ids: ["superhero"]
difficulties: ["simple", "moderate", "challenging"]
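
Other BIRD dev databases could be targeted with an analogous config under `config/data/`. A hypothetical sketch (the `california_schools` db_id comes from the public BIRD dev split; whether it is available in the `micpst/bird-iql` dataset is an assumption):

```yaml
path: "micpst/bird-iql"
split: "dev"
db_ids: ["california_schools"]
difficulties: ["simple", "moderate", "challenging"]
```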
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/claude-3-haiku.yaml
@@ -0,0 +1,2 @@
model_name: claude-3-haiku-20240307
provider: Claude
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/claude-3-opus.yaml
@@ -0,0 +1,2 @@
model_name: claude-3-opus-20240229
provider: Claude
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/claude-3.5-sonnet.yaml
@@ -0,0 +1,2 @@
model_name: claude-3-5-sonnet-20240620
provider: Claude
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/gpt-3.5-turbo.yaml
@@ -0,0 +1,2 @@
model_name: gpt-3.5-turbo
provider: OpenAI
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/gpt-4-turbo.yaml
@@ -0,0 +1,2 @@
model_name: gpt-4-turbo
provider: OpenAI
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/gpt-4o.yaml
@@ -0,0 +1,2 @@
model_name: gpt-4o
provider: OpenAI
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml
@@ -0,0 +1,2 @@
type: FILTERING_ASSESSOR
name: FilteringAssessorBaseline
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/program/filtering-assessor-cot.yaml
@@ -0,0 +1,2 @@
type: FILTERING_ASSESSOR
name: FilteringAssessorCoT
101 changes: 101 additions & 0 deletions extra/prompt_tuning/evaluate.py
@@ -0,0 +1,101 @@
import asyncio
import logging
from enum import Enum
from pathlib import Path

import dspy
import hydra
import neptune
from dspy.evaluate import Evaluate
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig
from tuning.loaders import IQLGenerationDataLoader
from tuning.metrics import filtering_assess_acc
from tuning.programs import PROGRAMS
from tuning.utils import save, serialize_results

logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("anthropic").setLevel(logging.ERROR)
log = logging.getLogger(__name__)


class EvaluationType(Enum):
    """
    Enum representing the evaluation type.
    """

    FILTERING_ASSESSOR = "FILTERING_ASSESSOR"


EVALUATION_DATALOADERS = {
    EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader,
}

EVALUATION_METRICS = {
    EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc,
}


async def evaluate(config: DictConfig) -> None:
    """
    Function running evaluation for all datasets and evaluation tasks defined in hydra config.

    Args:
        config: Hydra configuration.
    """
    log.info("Starting evaluation: %s", config.program.name)

    dataloader = EVALUATION_DATALOADERS[config.program.type](config)
    metric = EVALUATION_METRICS[config.program.type]
    program = PROGRAMS[config.program.name]()

    dataset = await dataloader.load()

    # Resolve the LM client class from the provider name in the config (e.g. dspy.OpenAI).
    lm = dspy.__dict__[config.llm.provider](model=config.llm.model_name)
    dspy.settings.configure(lm=lm)

    evaluator = Evaluate(
        devset=dataset,
        metric=metric,
        num_threads=32,
        display_progress=True,
        return_outputs=True,
    )
    metric, results = evaluator(program)

    log.info("Evaluation finished. Saving results...")

    output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
    results_file = output_dir / "results.json"
    save(results_file, results=serialize_results(results))

    log.info("Evaluation results saved under directory: %s", output_dir)

    if config.neptune:
        run = neptune.init_run()
        run["sys/tags"].add(
            [
                config.program.type,
                config.program.name,
                *config.data.db_ids,
                *config.data.difficulties,
            ]
        )
        run["config"] = stringify_unsupported(config)
        run["evaluation/metrics/ACC"] = stringify_unsupported(metric)
        run["evaluation/results.json"].upload(results_file.as_posix())


@hydra.main(config_path="config", config_name="config", version_base="3.2")
def main(config: DictConfig) -> None:
    """
    Function running evaluation for all datasets and evaluation tasks defined in hydra config.

    Args:
        config: Hydra configuration.
    """
    asyncio.run(evaluate(config))


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
Empty file.
69 changes: 69 additions & 0 deletions extra/prompt_tuning/tuning/loaders.py
@@ -0,0 +1,69 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterable, List

import dspy.datasets
from dspy import Example


class DataLoader(ABC):
    """
    Data loader.
    """

    def __init__(self, config: Dict) -> None:
        self.config = config

    @abstractmethod
    async def load(self) -> Iterable:
        """
        Load the data.

        Returns:
            The loaded data.
        """


class HuggingFaceDataLoader(DataLoader):
    """
    Hugging Face data loader.
    """

    async def load(self) -> List[Example]:
        """
        Load the data from Hugging Face.

        Returns:
            The loaded data.
        """
        dataloader = dspy.datasets.DataLoader()
        dataset = dataloader.from_huggingface(
            dataset_name=self.config.data.path, split=self.config.data.split, input_keys=("question",)
        )
        # Keep only samples that have a question and match the configured db_ids / difficulties;
        # each filter is applied only when the corresponding config value is set.
        return [
            data
            for data in dataset
            if data["question"]
            and (data["db_id"] in self.config.data.db_ids if self.config.data.db_ids else True)
            and (data["difficulty"] in self.config.data.difficulties if self.config.data.difficulties else True)
        ]


class IQLGenerationDataLoader(HuggingFaceDataLoader):
    """
    Data loader for IQL generation evaluation.
    """

    async def load(self) -> List[Example]:
        """
        Load the data from Hugging Face and filter out samples without views.

        Returns:
            The loaded data.
        """
        dataset = await super().load()
        return [data for data in dataset if data["view_name"]]
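
For quick experiments outside the Hydra entrypoint, the loader can also be driven directly. A minimal sketch, assuming it is run from `extra/prompt_tuning/`; the ad-hoc `OmegaConf` config simply mirrors `config/data/superhero.yaml` and is not part of the PR:

```python
import asyncio

from omegaconf import OmegaConf

from tuning.loaders import IQLGenerationDataLoader

# Ad-hoc config mirroring config/data/superhero.yaml.
config = OmegaConf.create(
    {
        "data": {
            "path": "micpst/bird-iql",
            "split": "dev",
            "db_ids": ["superhero"],
            "difficulties": ["simple"],
        }
    }
)

# load() is async, so run it through asyncio when not going through evaluate.py.
examples = asyncio.run(IQLGenerationDataLoader(config).load())
print(len(examples), examples[0]["question"])
```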
3 changes: 3 additions & 0 deletions extra/prompt_tuning/tuning/metrics/__init__.py
@@ -0,0 +1,3 @@
from .iql import filtering_assess_acc

__all__ = ["filtering_assess_acc"]
19 changes: 19 additions & 0 deletions extra/prompt_tuning/tuning/metrics/iql.py
@@ -0,0 +1,19 @@
from typing import Dict

from dspy import Prediction


def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool:
    """
    IQL filtering decision metric.

    Args:
        gold: The ground truth data point.
        pred: The prediction.

    Returns:
        True if the predicted filtering decision matches the ground truth, False otherwise.
    """
    return ((gold["iql_filters"] is None and not gold["iql_filters_unsupported"]) and not pred.decision) or (
        (gold["iql_filters"] is not None or gold["iql_filters_unsupported"]) and pred.decision
    )
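
The metric rewards the program for predicting `True` exactly when the gold sample has any filters (supported or not) and `False` otherwise. A minimal sketch of both cases; the gold dictionaries below are made-up illustrations of the dataset fields, not real BIRD samples:

```python
from dspy import Prediction

from tuning.metrics import filtering_assess_acc

# Gold sample without filters: a correct program predicts decision=False.
gold_no_filters = {"iql_filters": None, "iql_filters_unsupported": False}
print(filtering_assess_acc(gold_no_filters, Prediction(decision=False)))  # True  (correct)
print(filtering_assess_acc(gold_no_filters, Prediction(decision=True)))   # False (wrong)

# Gold sample with filters: a correct program predicts decision=True.
gold_with_filters = {"iql_filters": "filter_by_height(200)", "iql_filters_unsupported": False}
print(filtering_assess_acc(gold_with_filters, Prediction(decision=True)))  # True  (correct)
```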
8 changes: 8 additions & 0 deletions extra/prompt_tuning/tuning/programs/__init__.py
@@ -0,0 +1,8 @@
from .iql import FilteringAssessorBaseline, FilteringAssessorCoT

PROGRAMS = {
    FilteringAssessorBaseline.__name__: FilteringAssessorBaseline,
    FilteringAssessorCoT.__name__: FilteringAssessorCoT,
}

__all__ = ["PROGRAMS", "FilteringAssessorBaseline", "FilteringAssessorCoT"]
49 changes: 49 additions & 0 deletions extra/prompt_tuning/tuning/programs/iql.py
@@ -0,0 +1,49 @@
from dspy import ChainOfThought, Module, Predict, Prediction

from ..signatures.iql import CheckQuestionFiltering


class FilteringAssessorBaseline(Module):
    """
    Program that assesses whether a question requires filtering.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decide = Predict(CheckQuestionFiltering)

    def forward(self, question: str) -> Prediction:
        """
        Assess whether a question requires filtering.

        Args:
            question: The question to assess.

        Returns:
            The prediction.
        """
        decision = self.decide(question=question).decision
        return Prediction(decision=decision.lower() == "true")


class FilteringAssessorCoT(Module):
    """
    Program that assesses whether a question requires filtering, using chain-of-thought prompting.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decide = ChainOfThought(CheckQuestionFiltering)

    def forward(self, question: str) -> Prediction:
        """
        Assess whether a question requires filtering.

        Args:
            question: The question to assess.

        Returns:
            The prediction.
        """
        decision = self.decide(question=question).decision
        return Prediction(decision=decision.lower() == "true")
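
A minimal sketch of running the baseline program directly, assuming execution from `extra/prompt_tuning/` and the OpenAI client class that `evaluate.py` resolves for the `OpenAI` provider; the sample question is made up:

```python
import dspy

from tuning.programs import FilteringAssessorBaseline

# Configure the underlying LM, as evaluate.py does for provider "OpenAI".
dspy.settings.configure(lm=dspy.OpenAI(model="gpt-3.5-turbo"))

program = FilteringAssessorBaseline()
prediction = program(question="Which superheroes are taller than 200 cm?")
print(prediction.decision)  # True -> the question requires row-level filtering
```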
3 changes: 3 additions & 0 deletions extra/prompt_tuning/tuning/signatures/__init__.py
@@ -0,0 +1,3 @@
from .iql import CheckQuestionFiltering

__all__ = ["CheckQuestionFiltering"]
20 changes: 20 additions & 0 deletions extra/prompt_tuning/tuning/signatures/iql.py
@@ -0,0 +1,20 @@
from dspy import InputField, OutputField, Signature


class CheckQuestionFiltering(Signature):
    """
    Given a question, determine whether the answer requires initial data filtering in order to compute it.
    Initial data filtering is a process in which the result set is reduced to only include the rows that
    meet certain criteria specified in the question.
    """

    question = InputField(
        prefix="Question: ",
    )
    decision = OutputField(
        prefix="Decision: ",
        desc=(
            "indicates whether the answer to the question requires initial data filtering. "
            "(Respond with True or False)"
        ),
    )