diff --git a/Schafkopf.Training/Algos/MDP.cs b/Schafkopf.Training/Algos/MDP.cs index aac5b97..49658cd 100644 --- a/Schafkopf.Training/Algos/MDP.cs +++ b/Schafkopf.Training/Algos/MDP.cs @@ -21,8 +21,6 @@ public void Collect(PPORolloutBuffer buffer) // throw new ArgumentException("The number of steps needs to be " // + "divisible by 8 because each agent plays 8 cards per game!"); - Console.Write($"collect data"); - int numGames = buffer.Steps / 8; int numSessions = buffer.NumEnvs / 4; var envs = Enumerable.Range(0, numSessions) @@ -34,13 +32,15 @@ public void Collect(PPORolloutBuffer buffer) for (int gameId = 0; gameId < numGames + 1; gameId++) { - Console.Write($"\rcollecting ppo training data { gameId+1 } / { numGames } ... "); + Console.Write($"\rcollecting ppo training data { gameId+1 } / { numGames+1 } ... "); playGame(envs, states, batchesOfTurns); prepareRewards(states, rewards); fillBuffer(gameId, buffer, states, batchesOfTurns, rewards); for (int i = 0; i < states.Length; i++) states[i] = envs[i].Reset(); } + + Console.WriteLine(); } private void fillBuffer( @@ -153,7 +153,7 @@ public TurnBatches(int numSessions) piBatches = Enumerable.Range(0, 4) .Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray(); piSparseBatches = Enumerable.Range(0, 4) - .Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray(); + .Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray(); vBatches = Enumerable.Range(0, 4) .Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray(); } diff --git a/Schafkopf.Training/Algos/PPOAgent.cs b/Schafkopf.Training/Algos/PPOAgent.cs index 89577d1..4e9e791 100644 --- a/Schafkopf.Training/Algos/PPOAgent.cs +++ b/Schafkopf.Training/Algos/PPOAgent.cs @@ -43,7 +43,9 @@ public PPOModel Train(PPOTrainingSettings config) if ((ep + 1) % 10 == 0) { + model.RecompileCache(batchSize: 1); double winRate = benchmark.Benchmark(agent); + model.RecompileCache(batchSize: config.BatchSize); Console.WriteLine($"epoch {ep}: win rate vs. random agents is {winRate}"); } } @@ -126,6 +128,7 @@ public PPOModel(PPOTrainingSettings config) private IOptimizer strategyOpt; private IOptimizer valueFuncOpt; private Matrix2D featureCache; + private ILoss mse = new MeanSquaredError(); public int BatchSize => config.BatchSize; @@ -139,25 +142,33 @@ public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV) public void Train(PPORolloutBuffer memory) { + int numBatches = memory.NumBatches( + config.BatchSize, config.UpdateEpochs); var batches = memory.SampleDataset( config.BatchSize, config.UpdateEpochs); + int i = 1; foreach (var batch in batches) + { + Console.Write($"\rtraining {i++} / {numBatches}"); updateModels(batch); + } + Console.WriteLine(); } private void updateModels(PPOTrainBatch batch) { // update strategy pi(s) var predPi = strategy.PredictBatch(batch.StatesBefore); - var policyDeltas = strategy.Layers.Last().Cache.DeltasIn; - computePolicyDeltas(batch, predPi, policyDeltas); - strategy.FitBatch(policyDeltas, strategyOpt); + var strategyDeltas = strategy.Layers.Last().Cache.DeltasIn; + computePolicyDeltas(batch, predPi, strategyDeltas); + strategy.FitBatch(strategyDeltas, strategyOpt); // update baseline V(s) var predV = valueFunc.PredictBatch(batch.StatesBefore); var valueDeltas = valueFunc.Layers.Last().Cache.DeltasIn; - computeValueDeltas(batch, predV, valueDeltas); + mse.LossDeltas(predV, batch.Returns, valueDeltas); + // TODO: add value clipping valueFunc.FitBatch(valueDeltas, valueFuncOpt); } @@ -182,7 +193,8 @@ private void computePolicyDeltas( advantages = normAdvantages; } - var onehots = onehotIndices(batch.Actions, 32).Zip(Enumerable.Range(0, 32)); + var onehots = onehotIndices(batch.Actions, config.NumActionDims) + .Zip(Enumerable.Range(0, config.NumActionDims)); foreach ((int p, int i) in onehots) unsafe { newProbs.Data[i] = predPi.Data[p]; } Matrix2D.BatchAdd(batch.OldProbs, 1e-8, policyRatios); @@ -203,26 +215,17 @@ private void computePolicyDeltas( unsafe { policyDeltas.Data[p] = policyDeltasSparse.Data[i]; } } - private void computeValueDeltas( - PPOTrainBatch batch, Matrix2D predV, Matrix2D valueDeltas) - { - var mse = new MeanSquaredError(); - var valueDeltasSparse = Matrix2D.Zeros(batch.Size, 1); - mse.LossDeltas(predV, batch.Returns, valueDeltasSparse); - - // TODO: add value clipping - - Matrix2D.BatchMul(valueDeltas, 0, valueDeltas); - var onehots = onehotIndices(batch.Actions, 32).Zip(Enumerable.Range(0, 32)); - foreach ((int p, int i) in onehots) - unsafe { valueDeltas.Data[p] = valueDeltasSparse.Data[i]; } - } - private IEnumerable onehotIndices(Matrix2D sparseClassIds, int numClasses) { for (int i = 0; i < sparseClassIds.NumRows; i++) yield return i * numClasses + (int)sparseClassIds.At(0, i); } + + public void RecompileCache(int batchSize) + { + strategy.RecompileCache(batchSize); + valueFunc.RecompileCache(batchSize); + } } public class PossibleCardPicker @@ -356,6 +359,9 @@ public PPORolloutBuffer(PPOTrainingSettings config) public bool IsReadyForModelUpdate(int t) => t > 0 && t % Steps == 0; + public int NumBatches(int batchSize, int epochs = 1) + => cacheWithoutLastStep.Size / batchSize * epochs; + public PPOTrainBatch? SliceStep(int t) { int offset = IsReadyForModelUpdate(t) @@ -374,7 +380,7 @@ public IEnumerable SampleDataset(int batchSize, int epochs = 1) yield return batch; } - copyOverlappingStep(); + // copyOverlappingStep(); } private void shuffleDataset() diff --git a/Schafkopf.Training/NeuralNet/Model.cs b/Schafkopf.Training/NeuralNet/Model.cs index 41959e9..4afd258 100644 --- a/Schafkopf.Training/NeuralNet/Model.cs +++ b/Schafkopf.Training/NeuralNet/Model.cs @@ -14,17 +14,23 @@ public FFModel(IList layers) public void Compile(int batchSize, int inputDims) { - BatchSize = batchSize; - GradsTape = new List(); - var input = Matrix2D.Zeros(batchSize, inputDims); - var deltaOut = Matrix2D.Zeros(batchSize, inputDims); - foreach (var layer in Layers) { layer.Compile(inputDims); inputDims = layer.OutputDims; } + RecompileCache(batchSize); + } + + public void RecompileCache(int batchSize) + { + BatchSize = batchSize; + GradsTape = new List(); + int inputDims = Layers.First().InputDims; + var input = Matrix2D.Zeros(batchSize, inputDims); + var deltaOut = Matrix2D.Zeros(batchSize, inputDims); + foreach (var layer in Layers) { layer.CompileCache(input, deltaOut);