-
Notifications
You must be signed in to change notification settings - Fork 60
/
validate.py
43 lines (41 loc) · 1.51 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gym
import env
import PPO_model
import torch
import time
import os
import copy
def get_validate_env(env_paras):
    '''
    Build and return the validation environment from the validation set.

    The instance directory is "./data_dev/<num_jobs><num_mas>/", where
    num_mas is left-padded with zeros to two characters. Every entry that
    os.listdir reports for that directory is handed to gym.make as a case.

    :param env_paras: mapping of environment parameters; "num_jobs" and
        "num_mas" are read here, and the whole mapping is forwarded to gym.make
    :return: the environment created by gym.make('fjsp-v0', ...)
    '''
    num_jobs = env_paras["num_jobs"]
    machines_tag = str(env_paras["num_mas"]).zfill(2)
    dev_dir = "./data_dev/{0}{1}/".format(num_jobs, machines_tag)
    # dev_dir already ends with "/", so plain concatenation yields the full path.
    case_paths = [dev_dir + entry for entry in os.listdir(dev_dir)]
    return gym.make('fjsp-v0', case=case_paths, env_paras=env_paras, data_source='file')
def validate(env_paras, env, model_policy):
    '''
    Validate the policy during training; the process is similar to testing.

    Repeatedly asks ``model_policy.act`` for actions and feeds them to
    ``env.step`` until ``dones.all()`` is true, then checks the schedule with
    ``env.validate_gantt``, reads the makespans and resets the environment.

    :param env_paras: mapping of environment parameters; only
        ``env_paras["batch_size"]`` is read here, for the log message
    :param env: validation environment; must expose ``state``, ``done_batch``,
        ``step``, ``validate_gantt``, ``makespan_batch`` and ``reset``
    :param model_policy: policy object whose ``act`` method returns the actions
    :return: ``(makespan, makespan_batch)`` -- deep copies of
        ``env.makespan_batch.mean()`` and of ``env.makespan_batch``
    '''
    start = time.time()
    batch_size = env_paras["batch_size"]
    memory = PPO_model.Memory()
    print('There are {0} dev instances.'.format(batch_size))  # validation set is also called development set
    state = env.state
    done = False
    dones = env.done_batch
    # Use logical `not` instead of bitwise `~`: on a plain Python bool `~`
    # yields -1 / -2 (both truthy, and deprecated since Python 3.12), which
    # would loop forever if `dones.all()` ever returned a built-in bool.
    while not done:
        # Inference only: no gradients are needed while validating.
        with torch.no_grad():
            actions = model_policy.act(state, memory, dones, flag_sample=False, flag_train=False)
        # The rewards returned by env.step are not used during validation.
        state, _, dones = env.step(actions)
        done = dones.all()
    # Only the first element of validate_gantt()'s result is inspected.
    gantt_result = env.validate_gantt()[0]
    if not gantt_result:
        print("Scheduling Error!!!!!!")
    # Copy the results before env.reset() so the returned values are
    # independent of the environment's own storage.
    makespan = copy.deepcopy(env.makespan_batch.mean())
    makespan_batch = copy.deepcopy(env.makespan_batch)
    env.reset()
    print('validating time: ', time.time() - start, '\n')
    return makespan, makespan_batch