forked from hanhaoy1/stgcn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
99 lines (88 loc) · 3.75 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from data_loader import read_dataset, load_data
import numpy as np
import pandas as pd
from collections import defaultdict
import datetime
import geohash
import json
def poi2poi(data, days=7):
    """Count POI co-visitation pairs per user.

    For every user (grouped by ``uid``), each pair of check-ins whose time
    gap is strictly less than ``days`` days contributes one count to the
    corresponding POI pair, stored with the smaller POI id first.
    Self-pairs (the same POI visited twice) are not counted.

    Parameters
    ----------
    data : pandas.DataFrame
        Check-in records with at least ``uid``, ``pid`` and ``time``
        columns; ``time`` values must subtract to a timedelta.
    days : int, optional
        Maximum gap in days for two visits to count as a pair (default 7).

    Returns
    -------
    collections.defaultdict
        Maps an ordered tuple ``(poi1, poi2)`` with ``poi1 < poi2`` to its
        co-occurrence count.
    """
    grouped = data.groupby(['uid']).apply(lambda x: x.sort_values('time'))
    poi_pairs = defaultdict(int)
    timedelta = datetime.timedelta(days=days)
    for _, visits in grouped.groupby(level=[0]):
        # Hoist column access out of the O(n^2) pair loops: repeated
        # per-row ``iloc`` lookups are far slower than list indexing.
        times = visits['time'].tolist()
        pids = visits['pid'].tolist()
        for i in range(len(times) - 1):
            for j in range(i + 1, len(times)):
                if times[j] - times[i] < timedelta:
                    poi1, poi2 = pids[i], pids[j]
                    if poi1 < poi2:
                        poi_pairs[(poi1, poi2)] += 1
                    elif poi1 > poi2:
                        poi_pairs[(poi2, poi1)] += 1
                else:
                    # Visits are time-sorted, so every later j is also
                    # outside the window -- stop the inner scan early.
                    break
    return poi_pairs
def region_neighbors(data, user_region=True):
    """Collect all geohash regions in the data and adjacent-region pairs.

    Parameters
    ----------
    data : pandas.DataFrame
        Records with a ``region`` column (and ``user_region`` when
        ``user_region`` is True).
    user_region : bool, optional
        When True, regions from the ``user_region`` column are included too.

    Returns
    -------
    tuple[set, set]
        The set of regions, and a set of ordered ``(smaller, larger)``
        tuples for every pair of regions that are geohash neighbors.
    """
    regions = set(data['region'])
    if user_region:
        regions = regions | set(data['user_region'])
    region_pairs = set()
    for cell in regions:
        for adjacent in geohash.neighbors(cell):
            # Only keep neighbors that actually occur in the data.
            if adjacent not in regions:
                continue
            pair = (adjacent, cell) if adjacent < cell else (cell, adjacent)
            region_pairs.add(pair)
    return regions, region_pairs
def process(dataset, train_ratio=0.8, user_region=True):
    """Preprocess one check-in dataset into id maps and graph edge lists.

    Loads the raw data, assigns integer ids to users, POIs and regions in
    disjoint ranges (users first, then POIs, then regions), performs a
    random train/test split, and writes everything under
    ``./dataset/<dataset>/``: the three ``*2id.json`` mappings,
    ``train.csv``/``test.csv``, and tab-separated edge lists
    (region-region, poi-poi, user-poi, poi-region).

    Parameters
    ----------
    dataset : str
        Dataset name; also the output subdirectory under ``./dataset/``.
    train_ratio : float, optional
        Fraction of rows assigned to the training split (default 0.8).
        The split is unseeded, so it differs between runs.
    user_region : bool, optional
        Whether the data has a ``user_region`` column to remap as well.
    """
    data = load_data(dataset)
    regions, region_pairs = region_neighbors(data, user_region=user_region)

    users = set(data['uid'])
    pois = set(data['pid'])
    n_users = len(users)
    n_pois = len(pois)
    # Disjoint id ranges: [0, n_users) users, then POIs, then regions.
    user2id = dict(zip(users, range(n_users)))
    poi2id = dict(zip(pois, range(n_users, n_users + n_pois)))
    region2id = dict(zip(regions, range(n_users + n_pois,
                                        n_users + n_pois + len(regions))))

    out_dir = f'./dataset/{dataset}'
    for fname, mapping in (('user2id.json', user2id),
                           ('poi2id.json', poi2id),
                           ('region2id.json', region2id)):
        with open(f'{out_dir}/{fname}', 'w') as f:
            json.dump(mapping, f)

    # Remap raw identifiers to their integer ids in place.
    data['uid'] = data['uid'].map(user2id)
    data['pid'] = data['pid'].map(poi2id)
    data['region'] = data['region'].map(region2id)
    region_pairs = [[region2id[a], region2id[b]] for a, b in region_pairs]
    if user_region:
        data['user_region'] = data['user_region'].map(region2id)

    mask = np.random.rand(len(data)) < train_ratio
    train = data[mask]
    test = data[~mask]
    train.to_csv(f'{out_dir}/train.csv')
    test.to_csv(f'{out_dir}/test.csv')

    poi_pairs_train = poi2poi(train)
    user_poi_train = train.groupby(['uid', 'pid']).size().reset_index(name='w')
    poi_region = data[['pid', 'region']].drop_duplicates()
    poi_region['w'] = 1

    with open(f'{out_dir}/region2region.txt', 'w') as f:
        # Adjacent regions get a fixed edge weight of 1.
        for r1, r2 in region_pairs:
            f.write(f'{r1}\t{r2}\t1\n')
    with open(f'{out_dir}/poi2poi_train.txt', 'w') as f:
        for (p1, p2), w in poi_pairs_train.items():
            f.write(f'{p1}\t{p2}\t{w}\n')
    user_poi_train.to_csv(f'{out_dir}/user_poi_train.txt',
                          sep='\t', header=False, index=False)
    poi_region.to_csv(f'{out_dir}/poi_region.txt',
                      sep='\t', header=False, index=False)
if __name__ == '__main__':
    # Script entry point: preprocess the Gowalla dataset with
    # user_region=False (presumably Gowalla lacks a user home-region
    # column -- confirm against load_data).
    # process('meituan')
    process('gowalla', user_region=False)