Fixed Baseline loss function
masus04 committed Sep 14, 2018
1 parent 197147e commit e579151
Showing 2 changed files with 13 additions and 19 deletions.
@@ -107,6 +107,7 @@ def run(self):
 experiment.reset()
 """
 
+
 if __name__ == '__main__':
 
     # Player selection
31 changes: 12 additions & 19 deletions TicTacToe/players/baselinePlayer.py
@@ -1,8 +1,6 @@
 import torch
 import torch.nn.functional as F
 from torch.distributions import Categorical
-from numba import jit
-import numpy as np
 
 import TicTacToe.config as config
 from models import FCPolicyModel, LargeFCPolicyModel, HugeFCPolicyModel, ConvPolicyModel
@@ -55,15 +53,11 @@ def update(self):
 
         # ----------------------------------------------------------- #
 
-        # Use either Discount (and baseline),
         rewards = self.discount_rewards(self.rewards, self.gamma)
         rewards = self.rewards_baseline(rewards)
-        # Or Bootstrapping
-        # rewards = self.bootstrap_rewards()
         rewards = config.make_variable(rewards)
-        # rewards = self.normalize_rewards(rewards)  # For now nothing to normalize, standard deviation = 0
 
-        loss = calculate_loss(self.log_probs, self.state_values, rewards)
+        loss = self.calculate_loss(self.log_probs, self.state_values, rewards)
 
         self.optimizer.zero_grad()
         loss.backward()
@@ -77,6 +71,17 @@ def update(self):
 
         return abs(loss.data)
 
+    @staticmethod
+    def calculate_loss(log_probs, state_values, rewards):
+        policy_losses = []
+        value_losses = []
+
+        for log_prob, state_value, reward in zip(log_probs, state_values, rewards):
+            policy_losses.append(-log_prob * reward)
+            value_losses.append(F.smooth_l1_loss(state_value, reward))
+
+        return torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
+
 
 INTERMEDIATE_SIZE = 9*8
 
@@ -103,15 +108,3 @@ class ConvBaseLinePlayer(LearningPlayer):
     def __init__(self, lr=config.LR, strategy=None, weight_decay=0.003):
         super(ConvBaseLinePlayer, self).__init__(strategy=strategy if strategy is not None
                                                  else BaselinePGStrategy(lr, weight_decay=weight_decay, model=ConvPolicyModel(config=config, intermediate_size=INTERMEDIATE_SIZE)))
-
-
-# @jit
-def calculate_loss(log_probs, state_values, rewards):
-    policy_losses = []
-    value_losses = []
-
-    for log_prob, state_value, reward in zip(log_probs, state_values, rewards):
-        policy_losses.append(-log_prob * (reward - state_value.data))
-        value_losses.append(F.smooth_l1_loss(state_value, reward))
-
-    return torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

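Note (not part of the commit): before this change, update() first passed the discounted rewards through rewards_baseline() and the module-level calculate_loss then subtracted state_value.data from them again, so the policy term appears to have been baselined twice; the new staticmethod multiplies each log-probability by the already-baselined reward and keeps the smooth-L1 value loss unchanged. The snippet below is a minimal, self-contained sketch of the new loss shape on toy tensors; the toy values, and the assumption that log_probs, state_values and rewards are lists of scalar tensors as they appear in the diff, are illustrative only and not taken from the repository.

# Minimal sketch (not from the repository): a loss of the same shape as the
# new calculate_loss staticmethod, exercised on hand-made scalar tensors.
import torch
import torch.nn.functional as F


def calculate_loss(log_probs, state_values, rewards):
    # Policy term: REINFORCE, with the baseline assumed to be subtracted
    # upstream (discount_rewards followed by rewards_baseline in the diff).
    # Value term: smooth L1 regression of the critic towards the same target.
    policy_losses = []
    value_losses = []
    for log_prob, state_value, reward in zip(log_probs, state_values, rewards):
        policy_losses.append(-log_prob * reward)
        value_losses.append(F.smooth_l1_loss(state_value, reward))
    return torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()


# Toy episode of two moves: action log-probabilities, critic estimates,
# and already-baselined discounted rewards (values chosen arbitrarily).
log_probs = [torch.tensor(-0.7, requires_grad=True), torch.tensor(-1.2, requires_grad=True)]
state_values = [torch.tensor(0.3, requires_grad=True), torch.tensor(-0.1, requires_grad=True)]
rewards = [torch.tensor(0.5), torch.tensor(-0.2)]

loss = calculate_loss(log_probs, state_values, rewards)
loss.backward()  # gradients flow into both the policy and the value inputs
print(loss.item())

With the in-loss subtraction removed, the scale of the policy gradient depends only on what rewards_baseline() produces, which is consistent with the commit message.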