-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
105 lines (95 loc) · 2.9 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Various configuration options for the chatbot task.
This file is intended to be modified. You can go in and change any
of the variables to run different experiments.
"""
from __future__ import annotations
import transformers
from zeno_build.evaluation.text_features.exact_match import avg_exact_match, exact_match
from zeno_build.evaluation.text_features.length import (
chat_context_length,
input_length,
label_length,
output_length,
)
from zeno_build.evaluation.text_metrics.critique import (
avg_bert_score,
avg_chrf,
avg_length_ratio,
bert_score,
chrf,
length_ratio,
)
from zeno_build.experiments import search_space
from zeno_build.models.dataset_config import DatasetConfig
from zeno_build.models.lm_config import LMConfig
from zeno_build.prompts.chat_prompt import ChatMessages, ChatTurn
# Hyperparameter search space for the sweep. Every dimension below is pinned
# to a single value.
# NOTE(review): "local-fra" has no matching key in `dataset_configs` as defined
# in this file ("flores200" and "local-fls") -- confirm how the runner resolves
# this preset before relying on it.
space = search_space.CombinatorialSearchSpace(
    {
        "dataset_preset": search_space.Constant("local-fra"),
        "model_preset": search_space.Categorical(["gpt-3.5-turbo"]),
        "prompt_preset": search_space.Discrete(["tt-def"]),
        "temperature": search_space.Discrete([0.3]),
        # NOTE(review): -1 is presumably a "no limit" sentinel -- verify
        # against the code that consumes context_length.
        "context_length": search_space.Discrete([-1]),
        "max_tokens": search_space.Constant(500),
        "top_p": search_space.Constant(1.0),
    }
)
# Number of trials to run.
num_trials: int = 1
# Dataset definitions, keyed by preset name.
dataset_configs: dict[str, DatasetConfig] = {
    # FLORES data from "facebook/flores", "swh_Latn-eng_Latn" pair.
    "flores200": DatasetConfig(
        dataset=("facebook/flores", "swh_Latn-eng_Latn"),
        split="devtest",
        data_column="sentence_swh_Latn",
        data_format="flores",
    ),
    # Local plain-text inputs; this is the format we actually use.
    "local-fls": DatasetConfig(
        dataset="",
        split="devtest",
        # Left empty: inputs are plain .txt files, so no column applies.
        data_column="",
        data_format="local",
    ),
}
# Model definitions, keyed by preset name. "text-davinci-003" goes through the
# "openai" provider; the two chat models go through "openai_chat".
model_configs: dict[str, LMConfig] = {
    "text-davinci-003": LMConfig(
        provider="openai",
        model="text-davinci-003",
    ),
    "gpt-3.5-turbo": LMConfig(
        provider="openai_chat",
        model="gpt-3.5-turbo",
    ),
    "gpt-4": LMConfig(
        provider="openai_chat",
        model="gpt-4",
    ),
}
# Prompt templates, keyed by preset name. The prompt text is already embedded
# in the dataset inputs, so the only preset ("tt-def") is the default: a single
# user turn with empty content.
prompt_messages: dict[str, ChatMessages] = {
    "tt-def": ChatMessages(
        messages=[ChatTurn(role="user", content="")],
    ),
}
# Scoring used by the hyperparameter sweep: chrf is the only distill function,
# and avg_chrf is the metric function for the sweep.
sweep_distill_functions = [
    chrf,
]
sweep_metric_function = avg_chrf
# Functions handed to the Zeno visualization. Order is kept exactly as
# originally listed.
zeno_distill_and_metric_functions = [
    # Length features (zeno_build.evaluation.text_features.length).
    output_length,
    input_length,
    label_length,
    chat_context_length,
    # Scores from text_metrics.critique, then exact_match from text_features.
    chrf,
    length_ratio,
    bert_score,
    exact_match,
    # The avg_* counterparts of the four scores above.
    avg_chrf,
    avg_length_ratio,
    avg_bert_score,
    avg_exact_match,
]