-
Notifications
You must be signed in to change notification settings - Fork 0
/
args.py
104 lines (95 loc) · 3.16 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from dataclasses import dataclass, field
from typing import List, Optional
def get_data_paths(dataset, model_tag):
    """Return the (train_path, test_path) pair for a dataset/model combo.

    Paths follow the project convention ./data/{dataset}_{split}_{model_tag}.pt.
    """
    base = f"./data/{dataset}"
    train_path = f"{base}_train_{model_tag}.pt"
    test_path = f"{base}_test_{model_tag}.pt"
    return train_path, test_path
@dataclass
class TrainScriptArguments:
    """Command-line arguments for the training script.

    `model_name` and `dataset` are required; every other field has a
    default. Fields defaulting to None were previously annotated as plain
    `str`; they are now `Optional[str]` to match their actual defaults.
    """

    model_name: str = field(
        metadata={
            "help": "Model identifier for models in huggingface/transformers."}
    )
    dataset: str = field(
        metadata={
            "help": "Dataset to be used for training. Assumes it exists in ./data/"}
    )
    # None means "derive from dataset/model via get_data_paths" in the caller
    # (not shown here) — TODO confirm against the training script.
    train_data_path: Optional[str] = field(
        metadata={"help": "Path to training dataset"},
        default=None
    )
    test_data_path: Optional[str] = field(
        metadata={"help": "Path to test dataset"},
        default=None
    )
    use_defaults: bool = field(
        metadata={
            "help": "True to use default training values from script. Defaults to True"},
        default=True
    )
    wandb_project: str = field(
        metadata={
            "help": "Name of wandb project for logging. Defaults to msc_question_generation"},
        default="msc_question_generation"
    )
    # Falls back to model_name when left as None (per the help text).
    tokenizer_name: Optional[str] = field(
        metadata={"help": "Tokenizer identifier, defaults to model_name"},
        default=None
    )
    max_source_length: int = field(
        metadata={"help": "Max input length for the source context + answer"},
        default=512
    )
    max_target_length: int = field(
        metadata={"help": "Max input length for the target question to generate"},
        default=32
    )
    is_dryrun: bool = field(
        metadata={
            "help": "Set True to notify wandb that we are offline. Defaults to True."},
        default=True
    )
    # Percentage (0-100) of the training data to use.
    data_size: int = field(
        metadata={
            "help": "Percentage of data to use during training. Defaults to 100."},
        default=100
    )
    # 0 is the sentinel for "not set"; any positive value overrides data_size.
    absolute_data_size: int = field(
        metadata={
            "help": "Absolute number of training rows. Overrides data_size. Defaults to 100%."},
        default=0
    )
@dataclass
class DataArguments:
    """Arguments for dataset selection and tokenization limits."""

    tokenizer_name: str = field(
        metadata={"help": "Tokenizer identifier, from huggingface/transformers"},
    )
    dataset: str = field(
        metadata={"help": "Dataset identifier from custom huggingface/nlp scripts"},
    )
    max_source_length: int = field(
        default=512,
        metadata={"help": "Max input length for the source context + answer"},
    )
    max_target_length: int = field(
        default=32,
        metadata={"help": "Max input length for the target question to generate"},
    )
@dataclass
class EvalScriptArguments:
    """Command-line arguments for the evaluation script.

    Only `model_name` is required. `tokenizer_name` and `model_path` were
    previously annotated as plain `str` despite defaulting to None; they are
    now `Optional[str]` to match.
    """

    model_name: str = field(
        metadata={
            "help": "Model identifier for models in huggingface/transformers."}
    )
    # Empty list (the default) means "evaluate on every test set".
    test_sets: List[str] = field(
        default_factory=list,
        metadata={
            "help": "Which sets to test on. If empty, test against all. Defaults to []"},
    )
    # Falls back to model_name when left as None (per the help text).
    tokenizer_name: Optional[str] = field(
        metadata={"help": "Tokenizer identifier, defaults to model_name"},
        default=None
    )
    model_path: Optional[str] = field(
        metadata={
            "help": "Path to local model checkpoint or hf/transformers. Defaults to model_name"},
        default=None
    )