diff --git a/Othello/players/reinforcePlayer.py b/Othello/players/reinforcePlayer.py index 701bcc3..8e6d206 100644 --- a/Othello/players/reinforcePlayer.py +++ b/Othello/players/reinforcePlayer.py @@ -43,7 +43,7 @@ def update(self): policy_losses = [(-log_prob * reward) for log_prob, reward in zip(self.log_probs, rewards)] self.optimizer.zero_grad() - policy_loss = torch.cat(policy_losses).sum()/len(policy_losses) + policy_loss = torch.mean(policy_losses) policy_loss.backward() self.optimizer.step()