From 359c12b857dc7fa403712e39fdd3bcde171ac4d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Tr=C3=B6ster?= Date: Mon, 4 Dec 2023 01:02:17 +0100 Subject: [PATCH] fix bugs: adjust shapes, init optimizer (wip) --- Schafkopf.Training/Algos/PPOAgent.cs | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/Schafkopf.Training/Algos/PPOAgent.cs b/Schafkopf.Training/Algos/PPOAgent.cs index 9d1dd75..89577d1 100644 --- a/Schafkopf.Training/Algos/PPOAgent.cs +++ b/Schafkopf.Training/Algos/PPOAgent.cs @@ -116,6 +116,8 @@ public PPOModel(PPOTrainingSettings config) featureCache = Matrix2D.Zeros(config.BatchSize, config.NumStateDims); strategyOpt = new AdamOpt(config.LearnRate); valueFuncOpt = new AdamOpt(config.LearnRate); + strategyOpt.Compile(strategy.GradsTape); + valueFuncOpt.Compile(valueFunc.GradsTape); } private PPOTrainingSettings config; @@ -383,18 +385,18 @@ private void shuffleDataset() private void cacheGAE(PPOTrainBatch cache) { - var nonterm_t1 = Matrix2D.Zeros(1, NumEnvs); - var lambda = Matrix2D.Zeros(1, NumEnvs); - var delta = Matrix2D.Zeros(1, NumEnvs); + var nonterm_t1 = Matrix2D.Zeros(NumEnvs, 1); + var lambda = Matrix2D.Zeros(NumEnvs, 1); + var delta = Matrix2D.Zeros(NumEnvs, 1); for (int t = Steps - 1; t >= 0; t--) { - var r_t0 = cache.Rewards.SliceRows(t, 1); - var term_t1 = cache.Terminals.SliceRows(t+1, 1); - var v_t0 = cache.OldBaselines.SliceRows(t, 1); - var v_t1 = cache.OldBaselines.SliceRows(t+1, 1); - var A_t0 = cache.Advantages.SliceRows(t, 1); - var G_t0 = cache.Returns.SliceRows(t, 1); + var r_t0 = cache.Rewards.SliceRows(t, NumEnvs); + var term_t1 = cache.Terminals.SliceRows(t+1, NumEnvs); + var v_t0 = cache.OldBaselines.SliceRows(t, NumEnvs); + var v_t1 = cache.OldBaselines.SliceRows(t+1, NumEnvs); + var A_t0 = cache.Advantages.SliceRows(t, NumEnvs); + var G_t0 = cache.Returns.SliceRows(t, NumEnvs); Matrix2D.BatchMul(term_t1, -1, nonterm_t1); Matrix2D.BatchAdd(nonterm_t1, 1, nonterm_t1);