-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_mcts.py
32 lines (27 loc) · 855 Bytes
/
run_mcts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from model import MonteCarloTreeSearch
import random
from gym.envs.registration import register
import gym
from utilities.tree import Tree
def init_env():
    """Register (if needed) and build a non-slippery 4x4 FrozenLake environment.

    Registers the id ``FrozenLakeNotSlippery-v0``, which points at
    ``gym.envs.toy_text:FrozenLakeEnv`` with ``map_name='4x4'`` and
    ``is_slippery=False``. It then returns a new instance created by
    ``gym.make``.

    Calling this more than once in the same process is safe. Some gym
    releases raise ``gym.error.Error`` when an id is registered a second
    time. That error is treated as "already registered", and the existing
    registration is used.

    Returns:
        The environment instance returned by ``gym.make``.
    """
    try:
        register(
            id='FrozenLakeNotSlippery-v0',
            entry_point='gym.envs.toy_text:FrozenLakeEnv',
            kwargs={'map_name': '4x4', 'is_slippery': False},
        )
    except gym.error.Error:
        # The id is already registered, presumably by an earlier call to this
        # function, so the existing spec is reused.
        # NOTE(review): if something else registered this id with different
        # kwargs, that spec would be used instead. Confirm that nothing else
        # registers this id.
        pass
    return gym.make('FrozenLakeNotSlippery-v0')
def main(steps=10000, seed=2):
    """Run MCTS iterations on FrozenLake, print the tree, then call ``forward``.

    Seeds Python's ``random`` module and creates the environment with
    ``init_env``. Builds a ``MonteCarloTreeSearch`` over a fresh ``Tree``.
    Each of ``steps`` iterations does the following:

    * resets the environment;
    * takes a node from ``tree_policy()``;
    * obtains a reward from ``default_policy(node)``;
    * passes both to ``backward(node, reward)``.

    Afterwards the search's tree is printed with ``show()`` and
    ``forward()`` is called. The environment is closed on exit, including
    when an error occurs.

    Args:
        steps: Number of search iterations. The default of 10000 matches the
            original hard-coded value.
        seed: Value passed to ``random.seed``. The default of 2 matches the
            original hard-coded value.
            NOTE(review): only Python's ``random`` module is seeded here.
            Whether ``MonteCarloTreeSearch`` or the environment use other RNGs
            cannot be seen from this file.
    """
    random.seed(seed)
    env = init_env()
    try:
        mcts = MonteCarloTreeSearch(env=env, tree=Tree())
        for _ in range(steps):
            # Reset the environment before every iteration.
            env.reset()
            node = mcts.tree_policy()
            reward = mcts.default_policy(node)
            mcts.backward(node, reward)
        mcts.tree.show()
        mcts.forward()
    finally:
        # Release the environment even if the search raises part-way through.
        env.close()


if __name__ == "__main__":
    main()