-
Notifications
You must be signed in to change notification settings - Fork 12
/
evaluate_learned_policy.py
107 lines (74 loc) · 2.97 KB
/
evaluate_learned_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gym
import json
import datetime as dt
#from stable_baselines3.common.policies import MlpPolicy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
from stable_baselines3 import PPO, A2C
import argparse
from power_square_gym import EnvRLAM as powersquareEnvRLAM
from velocity_square_gym import EnvRLAM as velocitysquareEnvRLAM
from power_triangle_gym import EnvRLAM as powertriangleEnvRLAM
from velocity_triangle_gym import EnvRLAM as velocitytriangleEnvRLAM
from stable_baselines3.ppo import MlpPolicy, CnnPolicy
from stable_baselines3.common.cmd_util import make_vec_env
import numpy as np
import torch as th
import matplotlib.pyplot as plt
import os
import sys
from stable_baselines3.common import results_plotter
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import TD3
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnvWrapper
import numpy as np
import time
from collections import deque
import os.path as osp
import csv
def make_env(env_id, rank, seed=0):
    """Wrap an existing object in a zero-argument factory.

    Args:
        env_id: The object the factory will hand back. Despite the name,
            nothing is looked up or constructed from it; it is returned
            unchanged.
        rank: Accepted for signature compatibility; not used.
        seed: Accepted for signature compatibility; not used.

    Returns:
        A callable taking no arguments that returns ``env_id``.
    """
    # NOTE(review): presumably meant for vectorized-env constructors that
    # take a list of callables -- confirm against callers; nothing in this
    # file calls it.
    def _init():
        return env_id

    return _init
def parse_arguments(argv=None):
    """Parse the command-line options that select the evaluation setup.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the process command line, ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with:
            path:  'square' or 'triangle' (default 'triangle').
            param: 'velocity' or 'power' (default 'velocity').
            debug: always False; no command-line flag sets it.

    Raises:
        SystemExit: raised by argparse when an option value is not one of
            the listed choices (or on ``--help``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--path', dest='path',
        # Reject typos here instead of failing later with an unbound env.
        choices=('square', 'triangle'),
        default='triangle',
        help="Which scan path to use, options are 'square' and 'triangle'",
    )
    parser.add_argument(
        '--param', dest='param',
        choices=('velocity', 'power'),
        default='velocity',
        help="Which control parameter to vary, options are 'velocity' and 'power'",
    )
    parser.set_defaults(debug=False)
    return parser.parse_args(argv)
def main():
    """Load a saved PPO model and roll it out until the episode ends.

    The scan path and control parameter come from the command line (see
    ``parse_arguments``); the model file name is read interactively from
    stdin. The matching environment is built with plotting enabled, and the
    policy is stepped deterministically until the environment reports done.

    Raises:
        ValueError: if the path/parameter combination has no environment.
    """
    args = parse_arguments()
    path = args.path
    parameter = args.param

    # (path, parameter) -> (environment class, frameskip).
    # NOTE(review): power/triangle uses frameskip=2 while the others use 1;
    # kept as in the original -- confirm this is intentional.
    env_table = {
        ('square', 'power'): (powersquareEnvRLAM, 1),
        ('square', 'velocity'): (velocitysquareEnvRLAM, 1),
        ('triangle', 'power'): (powertriangleEnvRLAM, 2),
        ('triangle', 'velocity'): (velocitytriangleEnvRLAM, 1),
    }
    try:
        env_cls, frameskip = env_table[(path, parameter)]
    except KeyError:
        # Fail before prompting/loading rather than with an unbound `env`.
        raise ValueError(
            "Unsupported combination: path=%r, param=%r" % (path, parameter)
        ) from None

    model_filename = input("Enter filename of model zip file: ")

    env = env_cls(plot=True, frameskip=frameskip)
    model = PPO.load(model_filename)

    obs = env.reset()
    print("Evaluating model...")
    while True:
        # deterministic=True: take the policy's mode rather than sampling.
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        # np.any handles both a scalar done flag and an array of flags.
        if np.any(dones):
            break
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()