#!/usr/bin/env python3
import asyncio
import numpy as np
import os
import os.path
import pickle
import sys
import torch
import torch.utils.data
import torch.multiprocessing as mp
from batchgenerators.dataloading import SlimDataLoaderBase
from lockfile import LockFile
import config

#this manages the training data for deep cfr

#directory where samples are stored
DATA_DIR = config.dataDir
#DATA_DIR = './data/'

#whether to store data in memory or on disk
IN_MEMORY = config.inMemory

#whether to cache each sample in-memory on read
#this only has an effect when IN_MEMORY is False
#(the cache itself is currently commented out in Dataset, so this flag does nothing yet)
BIG_CACHE = config.bigCache

#deletes the data from DATA_DIR
#does not delete the folder itself
def clearData():
    os.system('rm -r ' + DATA_DIR + '*')
    os.system('rm valloss.csv')
    os.system('rm trainloss.csv')
    os.system('rm stddev.csv')

#deletes data belonging to a certain name
def clearSamplesByName(name):
    #remove the entry in index
    if os.path.exists(DATA_DIR + 'index'):
        target = -1
        #find the line
        with open(DATA_DIR + 'index', 'r') as file:
            lines = list(file.readlines())
        for i in range(len(lines)):
            if lines[i].split(',')[0] == name:
                target = i
                break
        #write out all lines except the targeted one
        if target != -1:
            del lines[target]
            with open(DATA_DIR + 'index', 'w') as file:
                for line in lines:
                    print(line, file=file, end='')
    #delete the data files
    os.system('rm -r ' + DATA_DIR + name + '/*')
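
#for reference, the index file parsed above is plain csv with one 'id,count'
#line per sample set; the ids below are made up for illustration:
#    advantage,4096
#    strategy,1024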

#lock is a multiprocess manager lock
#id determines which dataset the samples belong to
#samples is a list of samples (the loaders below expect (data, label, iter) tuples)
#the nth sample will be written to data/id/n
#sharedDict is used for any shared data (which will probably only include in-memory datasets)
def addSamples(lock, id, samples, sharedDict):
    #writing our count to the index file must be thread-safe,
    #so take an actual lock on the index file
    lock = LockFile(DATA_DIR + 'index')
    lock.acquire()
    if IN_MEMORY:
        #we really should support replacement of in-memory samples
        if 'smp' + id not in sharedDict:
            sharedDict['smp' + id] = samples
        else:
            old = sharedDict['smp' + id]
            sharedDict['smp' + id] = old + samples
    else:
        lines = []
        count = 0
        target = -1
        if not os.path.exists(DATA_DIR):
            os.mkdir(DATA_DIR)
        #get the current number of samples
        if os.path.exists(DATA_DIR + 'index'):
            with open(DATA_DIR + 'index', 'r') as file:
                lines = list(file.readlines())
            for i in range(len(lines)):
                if lines[i].split(',')[0] == id:
                    count = int(lines[i].split(',')[1][:-1])
                    target = i
                    break
        if config.maxNumSamples[id] and count >= config.maxNumSamples[id]:
            #we're replacing existing files, so the count in the index stays put
            replace = True
        else:
            replace = False
            if target == -1:
                #brand new sample set
                lines.append(id + ',' + str(len(samples)) + '\n')
            else:
                lines[target] = id + ',' + str(count + len(samples)) + '\n'
            #write the updated count back out
            with open(DATA_DIR + 'index', 'w') as file:
                for line in lines:
                    print(line, file=file, end='')
        #replacement means no amount of concurrency is safe here
        if replace:
            #pick which samples get removed
            #we need to give old and new samples an equal chance so we don't introduce bias
            writeIndices = np.random.choice(count + len(samples), len(samples), replace=False)
            for i, sample in zip(writeIndices, samples):
                #indices past count belong to new samples, which simply don't get saved
                if i < count:
                    with open(DATA_DIR + id + '/' + str(i), 'wb+') as file:
                        pickle.dump(sample, file)
        else:
            #write out each sample to its own file
            if not os.path.exists(DATA_DIR + id):
                os.mkdir(DATA_DIR + id)
            for i in range(len(samples)):
                with open(DATA_DIR + id + '/' + str(count + i), 'wb+') as file:
                    pickle.dump(samples[i], file)
    lock.release()
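
#a minimal sketch of building samples for addSamples; not called anywhere
#the 'example' id and the shapes are assumptions for illustration, chosen to
#match what myCollate and Dataset below expect: a (data, label, iter) tuple per sample
def _exampleAddSamples(lock, sharedDict):
    samples = []
    for _ in range(8):
        data = torch.zeros(5, 3, dtype=torch.long)  #a length-5 sequence of 3-wide vectors
        label = np.zeros(4, dtype=np.float32)       #fixed-size label vector
        iter = np.array([0], dtype=np.float32)      #iteration value stored with each sample
        samples.append((data, label, iter))
    addSamples(lock, 'example', samples, sharedDict)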

def myCollate(batch):
    #based on the collate_fn here
    #https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/data_loader.py
    #sort by data length, longest first
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    data, labels, iters = zip(*batch)
    #labels and iters have a fixed size, so we can just stack them
    labels = torch.stack(labels)
    iters = torch.stack(iters)
    #sequences are padded with 0 vectors to make the lengths the same
    lengths = [len(d) for d in data]
    padded = torch.zeros(len(data), max(lengths), len(data[0][0]), dtype=torch.long)
    for i, d in enumerate(data):
        end = lengths[i]
        padded[i, :end] = d[:end]
    #need to know the lengths so we can pack later
    lengths = torch.tensor(lengths)
    return padded, lengths, labels, iters
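
#a sketch of how the padded batch and lengths are meant to be consumed; the
#packing step is an assumption based on the comment above, not code from this
#module, and _exampleConsume is not called anywhere
def _exampleConsume(batch):
    from torch.nn.utils.rnn import pack_padded_sequence
    padded, lengths, labels, iters = batch
    #myCollate sorts by length descending, which the default
    #enforce_sorted behavior of pack_padded_sequence requires
    return pack_padded_sequence(padded, lengths, batch_first=True)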

#following this
#https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/multithreaded_dataloading.py
class BatchDataLoader(SlimDataLoaderBase):
    def __init__(self, id, indices, batch_size, num_threads_in_mt):
        super(BatchDataLoader, self).__init__(None, batch_size, num_threads_in_mt)
        self.id = id
        self.size = len(indices)
        self.indices = indices
        self.current_position = 0
        self.was_initialized = False

    def reset(self):
        self.current_position = self.thread_id
        self.was_initialized = True

    def shuffle(self):
        np.random.shuffle(self.indices)

    def generate_train_batch(self):
        if not self.was_initialized:
            self.reset()
        batch = []
        for i in range(self.batch_size):
            if self.current_position < self.size:
                idx = self.indices[self.current_position]
                #each thread strides through the indices, offset by its thread id
                self.current_position += self.number_of_threads_in_multithreaded
                with open(DATA_DIR + self.id + '/' + str(idx), 'rb') as file:
                    sample = pickle.load(file)
                data, label, iter = sample
                label = torch.from_numpy(label)
                iter = torch.from_numpy(iter)
                batch.append([data, label, iter])
            elif len(batch) > 0:
                #partial final batch
                return myCollate(batch)
            else:
                #epoch finished; reset for the next one
                self.reset()
                raise StopIteration
        return myCollate(batch)
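
#a minimal single-threaded epoch over the loader; not called anywhere
#this assumes SlimDataLoaderBase makes the loader iterable by calling
#generate_train_batch until StopIteration, as in the linked example;
#the id, size, and batch size are placeholders
def _exampleEpoch(id, size):
    loader = BatchDataLoader(id, np.arange(size), 32, 1)
    loader.shuffle()
    for padded, lengths, labels, iters in loader:
        pass  #feed the batch to the model here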

class Dataset(torch.utils.data.Dataset):
    def __init__(self, id, sharedDict, outputSize):
        self.id = id
        self.size = 0
        if IN_MEMORY:
            self.sharedDict = sharedDict
            if 'smp' + id not in sharedDict:
                self.samples = []
            else:
                self.samples = sharedDict['smp' + id]
                self.size = len(self.samples)
        else:
            #self.sampleCache = {}
            #(the commented-out sampleCache was presumably meant for BIG_CACHE)
            #look our sample count up in the index
            with open(DATA_DIR + 'index', 'r') as file:
                for line in file.readlines():
                    if line.split(',')[0] == id:
                        self.size = int(line.split(',')[1][:-1])

    def __getitem__(self, idx):
        if IN_MEMORY:
            sample = self.samples[idx]
        else:
            with open(DATA_DIR + self.id + '/' + str(idx), 'rb') as file:
                sample = pickle.load(file)
        #data is a python list
        #(except when it isn't)
        data, label, iter = sample
        label = torch.from_numpy(label)
        iter = torch.from_numpy(iter)
        return data, label, iter

    def __len__(self):
        return self.size
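
#a sketch of wiring Dataset into a standard PyTorch DataLoader with the custom
#collate; not called anywhere, and the id, outputSize, and batch size are placeholders
def _exampleLoader(sharedDict):
    dataset = Dataset('example', sharedDict, 4)
    return torch.utils.data.DataLoader(
        dataset, batch_size=32, shuffle=True, collate_fn=myCollate)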

#simple multiprocess smoke test: each process adds one sample set under the 'test' id
def runner(lock, rank, sharedDict):
    async def task(rank):
        addSamples(lock, 'test', [np.array([rank] * 1000)], sharedDict)
    loop = asyncio.get_event_loop()
    loop.run_until_complete(task(rank))

def main():
    print('starting')
    m = mp.Manager()
    lock = m.Lock()
    sharedDict = m.dict()
    processes = []
    for rank in range(4):
        p = mp.Process(target=runner, args=(lock, rank, sharedDict))
        p.start()
        processes.append(p)
    print('started')
    print('waiting for processes to finish')
    for p in processes:
        p.join()
        print('join')
    print('done')

if __name__ == '__main__':
    main()