-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrunner.py
62 lines (47 loc) · 2.13 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import numpy as np
import ray
import os
from attention_net import AttentionNet
from worker import Worker
from parameters import *
class Runner(object):
    """Own a local AttentionNet copy and run simulation episodes with it.

    Each call to ``job`` loads the supplied weights into the local network,
    runs one episode through a ``Worker`` and hands back what the worker
    collected.

    NOTE(review): the previous docstring stated that gradient computation is
    executed on this object; nothing in this class does so -- confirm whether
    that happens inside ``Worker``.
    """

    def __init__(self, metaAgentID):
        """Build the local network and move it to the configured device.

        Args:
            metaAgentID: Identifier stored on the runner and forwarded to
                every ``Worker`` it creates.
        """
        self.metaAgentID = metaAgentID
        # USE_GPU (from parameters) selects between the CUDA and CPU devices.
        device_name = 'cuda' if USE_GPU else 'cpu'
        self.device = torch.device(device_name)
        self.localNetwork = AttentionNet(INPUT_DIM, EMBEDDING_DIM)
        self.localNetwork.to(self.device)

    def get_weights(self):
        """Return the ``state_dict`` of the local network."""
        return self.localNetwork.state_dict()

    def set_weights(self, weights):
        """Load ``weights`` (a state dict) into the local network."""
        self.localNetwork.load_state_dict(weights)

    def singleThreadedJob(self, episodeNumber, budget_range, sample_size, sample_length):
        """Run one episode with a fresh ``Worker`` and return its outputs.

        Args:
            episodeNumber: Episode index; also decides whether images are
                saved for this episode.
            budget_range: Forwarded unchanged to ``Worker``.
            sample_size: Forwarded unchanged to ``Worker``.
            sample_length: Forwarded unchanged to ``Worker``.

        Returns:
            Tuple ``(worker.experience, worker.perf_metrics)``.
        """
        # Save images every SAVE_IMG_GAP episodes; a gap of 0 disables saving
        # (and the short-circuit avoids a modulo by zero).
        save_img = bool(SAVE_IMG_GAP != 0 and episodeNumber % SAVE_IMG_GAP == 0)

        worker = Worker(
            self.metaAgentID,
            self.localNetwork,
            episodeNumber,
            budget_range,
            sample_size,
            sample_length,
            self.device,
            save_image=save_img,
            greedy=False,
        )
        worker.work(episodeNumber)
        return worker.experience, worker.perf_metrics

    def job(self, global_weights, episodeNumber, budget_range, sample_size=SAMPLE_SIZE, sample_length=None):
        """Sync to ``global_weights``, run one episode and report the results.

        Args:
            global_weights: State dict loaded into the local network before
                the episode starts.
            episodeNumber: Episode index.
            budget_range: Forwarded to ``singleThreadedJob``.
            sample_size: Forwarded to ``singleThreadedJob``; defaults to
                ``SAMPLE_SIZE``.
            sample_length: Forwarded to ``singleThreadedJob``.

        Returns:
            Tuple ``(jobResults, metrics, info)`` where ``info`` holds the
            runner id and the episode number.
        """
        print("starting episode {} on metaAgent {}".format(episodeNumber, self.metaAgentID))

        # Start the episode from the weights supplied by the caller.
        self.set_weights(global_weights)

        jobResults, metrics = self.singleThreadedJob(
            episodeNumber, budget_range, sample_size, sample_length
        )

        return jobResults, metrics, {
            "id": self.metaAgentID,
            "episode_number": episodeNumber,
        }
@ray.remote(
    num_cpus=1,
    num_gpus=len(CUDA_DEVICE) / NUM_META_AGENT,
)
class RLRunner(Runner):
    """Ray actor version of ``Runner``.

    Each actor requests one CPU and ``len(CUDA_DEVICE) / NUM_META_AGENT``
    GPUs from Ray. All behavior is inherited from ``Runner``.
    """

    def __init__(self, metaAgentID):
        """Initialize the underlying ``Runner`` with ``metaAgentID``."""
        super().__init__(metaAgentID)
if __name__ == '__main__':
    # Manual smoke test: start Ray, create one remote runner, run a single
    # episode on it and print the performance metrics it reports.
    ray.init()
    runner = RLRunner.remote(0)

    # singleThreadedJob takes (episodeNumber, budget_range, sample_size,
    # sample_length); the previous call passed only the episode number, so the
    # remote task failed with a TypeError as soon as ray.get() was called.
    # NOTE(review): the budget range below is a placeholder, assumed to be a
    # (low, high) pair -- replace with the project's configured value.
    smoke_budget_range = (6, 8)
    # sample_size / sample_length mirror the defaults of Runner.job.
    job_id = runner.singleThreadedJob.remote(1, smoke_budget_range, SAMPLE_SIZE, None)

    # The remote call returns (jobResults, perf_metrics); only metrics are shown.
    _, perf_metrics = ray.get(job_id)
    print(perf_metrics)