-
Notifications
You must be signed in to change notification settings - Fork 3
/
merge_comment_emb.py
81 lines (59 loc) · 2.11 KB
/
merge_comment_emb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import math
import pandas as pd
import os
import json
import numpy as np
import torch
from tqdm.notebook import tqdm
import pickle
from arguments import ModelArguments, DataArguments, TrainingArguments
from model.model import RecComModel
def load_json(file):
    """Read *file* as UTF-8 text and return the decoded JSON value."""
    with open(file, encoding="utf-8") as handle:
        return json.load(handle)
def load_pkl(file):
    """Unpickle and return the object stored at path *file*.

    Security: ``pickle.load`` can execute arbitrary code while loading, so
    only call this on files produced by a trusted source.
    """
    # Use a distinct name for the handle instead of rebinding the path parameter.
    with open(file, "rb") as handle:
        return pickle.load(handle)
def parse_args(argv=None):
    """Parse the command-line options of the embedding-merge script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``emb_size``, ``file_prefix`` and ``output_file``.
    """
    parser = argparse.ArgumentParser(
        description="Merge sharded comment-embedding pickles into one .npy matrix."
    )
    parser.add_argument(
        "--emb_size", type=int, default=256,
        help="Number of columns of the merged embedding matrix.",
    )
    parser.add_argument(
        "--file_prefix", type=str, default="comment_embs_",
        help="Only .pkl files in the data directory whose names start with "
             "this prefix are merged.",
    )
    parser.add_argument(
        "--output_file", type=str, default="comment_embs.npy",
        help="File name of the merged array, saved inside the data directory.",
    )
    return parser.parse_args(argv)
def _find_shard_files(data_path, file_prefix):
    """Return the paths of ``<file_prefix>*.pkl`` files directly inside *data_path*.

    The list is sorted so the merge order -- and therefore which shard wins
    if two shards contain the same comment -- does not depend on the
    arbitrary order in which ``os.listdir`` returns entries.
    """
    return sorted(
        os.path.join(data_path, name)
        for name in os.listdir(data_path)
        if name.startswith(file_prefix) and name.endswith(".pkl")
    )


def _merge_shards(file_list, comment2id, n_comments, emb_size):
    """Scatter every shard's ``{comment: vector}`` entries into one float32 matrix.

    Args:
        file_list: Paths of pickled dicts mapping a comment to its vector.
        comment2id: Maps ``str(comment)`` to its row index in the output.
        n_comments: Number of rows of the output matrix.
        emb_size: Number of columns of the output matrix.

    Returns:
        ``(all_emb, seen)``: the ``(n_comments, emb_size)`` float32 matrix
        (rows that no shard provides stay zero; NaN/inf values become 0.0)
        and the set of comment keys found across all shards.

    Raises:
        KeyError: if a shard contains a comment absent from ``comment2id``.
        ValueError: if a vector cannot be broadcast to ``emb_size`` columns.
    """
    all_emb = np.zeros((n_comments, emb_size), dtype=np.float32)
    seen = set()
    for path in file_list:
        print(path)
        embs = load_pkl(path)
        # Drain the dict while copying so each source vector can be freed as
        # soon as it has been written into all_emb.
        while embs:
            com, vec = embs.popitem()
            seen.add(com)
            # comment2id is keyed by str(comment); look up the same form.
            try:
                idx = comment2id[str(com)]
            except KeyError:
                raise KeyError(
                    f"comment {com!r} from {path} is not in all_comments.pkl"
                ) from None
            all_emb[idx] = np.nan_to_num(
                np.asarray(vec, dtype=np.float32),
                nan=0.0, posinf=0.0, neginf=0.0,
            )
        del embs
    return all_emb, seen


if __name__ == '__main__':
    args = parse_args()
    print(vars(args))
    data_path = DataArguments().data_path
    print(data_path)

    # NOTE(review): pickle.load can run arbitrary code -- data_path must be trusted.
    all_comments = load_pkl(os.path.join(data_path, "all_comments.pkl"))
    # Output row i holds the embedding of all_comments[i] (if two comments
    # have the same str() form, the later index wins).
    comment2id = {str(comment): i for i, comment in enumerate(all_comments)}
    n_comments = len(all_comments)
    print("n_comments", n_comments)
    print(all_comments[:10])

    file_list = _find_shard_files(data_path, args.file_prefix)
    for f in file_list:
        print(f)
    print("file num:", len(file_list))
    if not file_list:
        # Without any shard the output would be an all-zero matrix; fail loudly.
        raise SystemExit(
            f"no '{args.file_prefix}*.pkl' files found in {data_path}"
        )

    all_emb, seen = _merge_shards(file_list, comment2id, n_comments, args.emb_size)
    # Presumably '[PAD]' is an entry of all_comments that no shard embeds; it is
    # counted as covered so the two totals below match when every real comment
    # received a vector -- TODO confirm.
    seen.add('[PAD]')
    print(len(all_comments))
    print(len(seen))
    np.save(os.path.join(data_path, args.output_file), all_emb)