Skip to content

Commit

Permalink
fix bugs: adjust shapes, init optimizer (wip)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Bonifatius94 committed Dec 4, 2023
1 parent ecc5066 commit 359c12b
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions Schafkopf.Training/Algos/PPOAgent.cs
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,8 @@ public PPOModel(PPOTrainingSettings config)
featureCache = Matrix2D.Zeros(config.BatchSize, config.NumStateDims);
strategyOpt = new AdamOpt(config.LearnRate);
valueFuncOpt = new AdamOpt(config.LearnRate);
strategyOpt.Compile(strategy.GradsTape);
valueFuncOpt.Compile(valueFunc.GradsTape);
}

private PPOTrainingSettings config;
@@ -383,18 +385,18 @@ private void shuffleDataset()

private void cacheGAE(PPOTrainBatch cache)
{
var nonterm_t1 = Matrix2D.Zeros(1, NumEnvs);
var lambda = Matrix2D.Zeros(1, NumEnvs);
var delta = Matrix2D.Zeros(1, NumEnvs);
var nonterm_t1 = Matrix2D.Zeros(NumEnvs, 1);
var lambda = Matrix2D.Zeros(NumEnvs, 1);
var delta = Matrix2D.Zeros(NumEnvs, 1);

for (int t = Steps - 1; t >= 0; t--)
{
var r_t0 = cache.Rewards.SliceRows(t, 1);
var term_t1 = cache.Terminals.SliceRows(t+1, 1);
var v_t0 = cache.OldBaselines.SliceRows(t, 1);
var v_t1 = cache.OldBaselines.SliceRows(t+1, 1);
var A_t0 = cache.Advantages.SliceRows(t, 1);
var G_t0 = cache.Returns.SliceRows(t, 1);
var r_t0 = cache.Rewards.SliceRows(t, NumEnvs);
var term_t1 = cache.Terminals.SliceRows(t+1, NumEnvs);
var v_t0 = cache.OldBaselines.SliceRows(t, NumEnvs);
var v_t1 = cache.OldBaselines.SliceRows(t+1, NumEnvs);
var A_t0 = cache.Advantages.SliceRows(t, NumEnvs);
var G_t0 = cache.Returns.SliceRows(t, NumEnvs);

Matrix2D.BatchMul(term_t1, -1, nonterm_t1);
Matrix2D.BatchAdd(nonterm_t1, 1, nonterm_t1);
Expand Down

0 comments on commit 359c12b

Please sign in to comment.