-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathdataset.py
40 lines (29 loc) · 1.48 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from datasets.twitter_customer_support.dataset import load_dataset as twitter_dataset
from datasets.twitter_customer_support.dataset import load_field as twitter_field
from datasets.twitter_customer_support.dataset import load_metadata as twitter_metadata
# Positions of each loader inside the tuples stored in ``dataset_field_map``.
DATASET_IDX = 0   # callable invoked as loader(args, device)
FIELD_IDX = 1     # callable invoked with no arguments
METADATA_IDX = 2  # callable invoked as loader(vocab)

# Every registered dataset name currently resolves to the same three
# twitter_customer_support loaders, so the triple is built once and shared.
_TWITTER_LOADERS = (twitter_dataset, twitter_field, twitter_metadata)

# Registry: dataset name -> (dataset loader, field loader, metadata loader).
# Key order is kept as-is because it is echoed in the "available datasets"
# error message.
dataset_field_map = {
    'twitter-applesupport': _TWITTER_LOADERS,
    'twitter-amazonhelp': _TWITTER_LOADERS,
    'twitter-delta': _TWITTER_LOADERS,
    'twitter-spotifycares': _TWITTER_LOADERS,
    'twitter-uber_support': _TWITTER_LOADERS,
    'twitter-all': _TWITTER_LOADERS,
    'twitter-small': _TWITTER_LOADERS,
}
def get_dataset_tuple(args):
    """Look up the loader tuple registered for ``args.dataset``.

    Args:
        args: Object exposing a ``dataset`` attribute with the dataset name.

    Returns:
        The value stored in ``dataset_field_map`` under that name.

    Raises:
        ValueError: If the name is not a key of ``dataset_field_map``. The
            message lists every registered name.
    """
    name = args.dataset
    if name in dataset_field_map:
        return dataset_field_map[name]

    # Unknown name: report it together with all registered alternatives.
    available = ', '.join(dataset_field_map)
    raise ValueError("There is no \"%s\" dataset, available datasets are: (%s)"
                     % (name, available))
def dataset_factory(args, device):
    """Call the dataset loader registered for ``args.dataset``.

    Args:
        args: Object exposing a ``dataset`` attribute; also forwarded to the
            loader unchanged.
        device: Forwarded to the loader unchanged.

    Returns:
        Whatever the selected loader returns for ``(args, device)``.

    Raises:
        ValueError: Propagated from ``get_dataset_tuple`` for unknown names.
    """
    load = get_dataset_tuple(args)[DATASET_IDX]
    return load(args, device)
def field_factory(args):
    """Call the field loader registered for ``args.dataset``.

    Args:
        args: Object exposing a ``dataset`` attribute; used only to select
            the loader, which itself is called with no arguments.

    Returns:
        Whatever the selected field loader returns.

    Raises:
        ValueError: Propagated from ``get_dataset_tuple`` for unknown names.
    """
    make_field = get_dataset_tuple(args)[FIELD_IDX]
    return make_field()
def metadata_factory(args, vocab):
    """Call the metadata loader registered for ``args.dataset``.

    Args:
        args: Object exposing a ``dataset`` attribute; used only to select
            the loader.
        vocab: Forwarded to the loader as its single argument.

    Returns:
        Whatever the selected metadata loader returns for ``vocab``.

    Raises:
        ValueError: Propagated from ``get_dataset_tuple`` for unknown names.
    """
    make_metadata = get_dataset_tuple(args)[METADATA_IDX]
    return make_metadata(vocab)