#
# Copyright (c) 2021 Facebook, Inc. and its affiliates.
#
# This file is part of NeuralDB.
# See https://github.com/facebookresearch/NeuralDB for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
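# Trains the SSG (support set generator) model: a SentenceTransformer
# bi-encoder fine-tuned with a contrastive loss on labelled (query, fact)
# pairs read from <folder>/train.jsonl, with positive pairs oversampled
# and binary-classification evaluation on <folder>/dev.jsonl.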
import argparse
import os

from sentence_transformers import SentencesDataset, InputExample, SentenceTransformer
from sentence_transformers.evaluation import BinaryClassificationEvaluator
from sentence_transformers.losses import ContrastiveLoss
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

from ssg_utils import read_NDB, create_dataset


def is_valid_folder(parser, arg):
    # Argparse type check: reject input paths that do not exist on disk.
    if not os.path.exists(arg):
        parser.error("The folder %s does not exist!" % arg)
    return arg


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the SSG model")
    parser.add_argument(
        "-i",
        dest="folder",
        required=True,
        help="input data folder",
        type=lambda x: is_valid_folder(parser, x),
    )
    parser.add_argument("-b", dest="batch_size", type=int, help="batch size", default=100)
    parser.add_argument("-e", dest="epochs", type=int, help="number of epochs", default=10)
    parser.add_argument("-o", dest="output", required=True, help="output path for the trained model")
    parser.add_argument("-d", dest="device", default="cuda:0", help="torch device to train on")
    args = parser.parse_args()

    folder = args.folder
    batch_size = args.batch_size
    epochs = args.epochs
    output = args.output
    device = args.device

    # Define the model, either from scratch or by loading a pre-trained model.
    model = SentenceTransformer("distilbert-base-nli-mean-tokens", device=device)

    # Read the train data and build labelled (query, fact) input pairs,
    # joining the pieces of each element into the two input texts.
    name = "train"
    data_file = folder + "/" + name + ".jsonl"
    db = read_NDB(data_file)
    dataset = create_dataset(db)
    train_examples = []
    weights = []
    for d in dataset:
        texts = ["[SEP]".join(d[0]), "".join(d[1])]
        label = d[2]
        # Upweight positive pairs 10:1 so the weighted sampler below can
        # counter the class imbalance in the training data.
        if label == 1:
            weights.append(10)
        else:
            weights.append(1)
        train_examples.append(InputExample(texts=texts, label=label))

    # Read the dev data.
    name = "dev"
    data_file = folder + "/" + name + ".jsonl"
    db = read_NDB(data_file)
    dataset = create_dataset(db)
    dev_examples = []
    for d in dataset:
        texts = ["[SEP]".join(d[0]), "".join(d[1])]
        label = d[2]
        dev_examples.append(InputExample(texts=texts, label=label))

    # Define the train dataset, the weighted dataloader, and the train loss.
    train_loss = ContrastiveLoss(model)
    train_dataset = SentencesDataset(train_examples, model)
    sampler = WeightedRandomSampler(weights=weights, num_samples=len(train_examples))
    train_dataloader = DataLoader(
        train_dataset, sampler=sampler, shuffle=False, batch_size=batch_size
    )
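
    # Score dev pairs as a binary classification task (similar vs.
    # dissimilar embeddings); the evaluator picks the best similarity
    # threshold on the dev data.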
    evaluator = BinaryClassificationEvaluator.from_input_examples(
        dev_examples, batch_size=batch_size
    )
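
    # Fine-tune with the contrastive objective, evaluating every 100 steps;
    # with an output_path set, fit() saves the best-scoring model there.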
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=epochs,
        warmup_steps=100,
        evaluator=evaluator,
        output_path=output,
        evaluation_steps=100,
    )
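
# Example invocation (paths are illustrative; the input folder must contain
# train.jsonl and dev.jsonl in the format expected by ssg_utils):
#
#   python train_ssg.py -i path/to/data -o path/to/model -b 100 -e 10 -d cuda:0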