Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #16 #17

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,16 @@ def _save_model(self, step, curr_reward):
self.logger.warning("Saved Model @ Step: " + str(step) + ": " + self.model_name + ".")

def _forward(self, observation):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def _backward(self, reward, terminal):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def _eval_model(self): # evaluation during training
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def fit_model(self): # training
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def test_model(self): # testing pre-trained models
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")
6 changes: 3 additions & 3 deletions core/agent_single_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def _ensure_global_grads(self):
global_param._grad = local_param.grad

def _forward(self, observation):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def _backward(self, reward, terminal):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def run(self):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")
12 changes: 6 additions & 6 deletions core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def _get_experience(self):
terminal1 = self.exp_terminal1)

def _preprocessState(self, state):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

@property
def state_shape(self):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

@property
def action_dim(self):
Expand All @@ -65,13 +65,13 @@ def action_dim(self):
return self.env.action_space.n

def render(self): # render using the original gl window
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def visual(self): # visualize onto visdom
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def reset(self):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def step(self, action):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")
4 changes: 2 additions & 2 deletions core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, args):
self.output_dims = args.action_dim

def _init_weights(self):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")

def print_model(self):
self.logger.warning("<-----------------------------------> Model")
Expand All @@ -43,4 +43,4 @@ def _reset(self): # NOTE: should be called at each child's __init__
self.print_model()

def forward(self, input):
raise NotImplementedError("not implemented in base calss")
raise NotImplementedError("not implemented in base class")