Skip to content

Commit

Permalink
fix memory leak, make train loop run
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Dec 5, 2023
1 parent 359c12b commit 8b3446b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 30 deletions.
8 changes: 4 additions & 4 deletions Schafkopf.Training/Algos/MDP.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ public void Collect(PPORolloutBuffer buffer)
// throw new ArgumentException("The number of steps needs to be "
// + "divisible by 8 because each agent plays 8 cards per game!");

Console.Write($"collect data");

int numGames = buffer.Steps / 8;
int numSessions = buffer.NumEnvs / 4;
var envs = Enumerable.Range(0, numSessions)
Expand All @@ -34,13 +32,15 @@ public void Collect(PPORolloutBuffer buffer)

for (int gameId = 0; gameId < numGames + 1; gameId++)
{
Console.Write($"\rcollecting ppo training data { gameId+1 } / { numGames } ... ");
Console.Write($"\rcollecting ppo training data { gameId+1 } / { numGames+1 } ... ");
playGame(envs, states, batchesOfTurns);
prepareRewards(states, rewards);
fillBuffer(gameId, buffer, states, batchesOfTurns, rewards);
for (int i = 0; i < states.Length; i++)
states[i] = envs[i].Reset();
}

Console.WriteLine();
}

private void fillBuffer(
Expand Down Expand Up @@ -153,7 +153,7 @@ public TurnBatches(int numSessions)
piBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
piSparseBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
.Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
vBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
}
Expand Down
48 changes: 27 additions & 21 deletions Schafkopf.Training/Algos/PPOAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ public PPOModel Train(PPOTrainingSettings config)

if ((ep + 1) % 10 == 0)
{
model.RecompileCache(batchSize: 1);
double winRate = benchmark.Benchmark(agent);
model.RecompileCache(batchSize: config.BatchSize);
Console.WriteLine($"epoch {ep}: win rate vs. random agents is {winRate}");
}
}
Expand Down Expand Up @@ -126,6 +128,7 @@ public PPOModel(PPOTrainingSettings config)
private IOptimizer strategyOpt;
private IOptimizer valueFuncOpt;
private Matrix2D featureCache;
private ILoss mse = new MeanSquaredError();

public int BatchSize => config.BatchSize;

Expand All @@ -139,25 +142,33 @@ public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV)

// Runs the PPO update phase: iterates shuffled minibatches sampled from the
// rollout buffer for config.UpdateEpochs passes and fits both networks on each,
// printing a same-line progress counter to the console.
public void Train(PPORolloutBuffer memory)
{
// Total minibatch count is queried up front only for progress display;
// the actual batches come from the (lazy) SampleDataset enumerator below.
int numBatches = memory.NumBatches(
config.BatchSize, config.UpdateEpochs);
var batches = memory.SampleDataset(
config.BatchSize, config.UpdateEpochs);

// 1-based progress counter; `\r` rewrites the same console line each step.
int i = 1;
foreach (var batch in batches)
{
Console.Write($"\rtraining {i++} / {numBatches}");
updateModels(batch);
}
// Terminate the progress line so subsequent output starts on a fresh line.
Console.WriteLine();
}

private void updateModels(PPOTrainBatch batch)
{
// update strategy pi(s)
var predPi = strategy.PredictBatch(batch.StatesBefore);
var policyDeltas = strategy.Layers.Last().Cache.DeltasIn;
computePolicyDeltas(batch, predPi, policyDeltas);
strategy.FitBatch(policyDeltas, strategyOpt);
var strategyDeltas = strategy.Layers.Last().Cache.DeltasIn;
computePolicyDeltas(batch, predPi, strategyDeltas);
strategy.FitBatch(strategyDeltas, strategyOpt);

// update baseline V(s)
var predV = valueFunc.PredictBatch(batch.StatesBefore);
var valueDeltas = valueFunc.Layers.Last().Cache.DeltasIn;
computeValueDeltas(batch, predV, valueDeltas);
mse.LossDeltas(predV, batch.Returns, valueDeltas);
// TODO: add value clipping
valueFunc.FitBatch(valueDeltas, valueFuncOpt);
}

Expand All @@ -182,7 +193,8 @@ private void computePolicyDeltas(
advantages = normAdvantages;
}

var onehots = onehotIndices(batch.Actions, 32).Zip(Enumerable.Range(0, 32));
var onehots = onehotIndices(batch.Actions, config.NumActionDims)
.Zip(Enumerable.Range(0, config.NumActionDims));
foreach ((int p, int i) in onehots)
unsafe { newProbs.Data[i] = predPi.Data[p]; }
Matrix2D.BatchAdd(batch.OldProbs, 1e-8, policyRatios);
Expand All @@ -203,26 +215,17 @@ private void computePolicyDeltas(
unsafe { policyDeltas.Data[p] = policyDeltasSparse.Data[i]; }
}

// Computes the backprop deltas for the value head V(s): takes the MSE loss
// gradient of the predicted values against the rollout returns, then scatters
// the per-sample gradients into the (dense) delta buffer of the output layer.
private void computeValueDeltas(
PPOTrainBatch batch, Matrix2D predV, Matrix2D valueDeltas)
{
var mse = new MeanSquaredError();
// One sparse gradient entry per sample in the batch.
var valueDeltasSparse = Matrix2D.Zeros(batch.Size, 1);
mse.LossDeltas(predV, batch.Returns, valueDeltasSparse);

// TODO: add value clipping

// Multiply by 0 to zero-fill the dense delta buffer in place before scatter.
Matrix2D.BatchMul(valueDeltas, 0, valueDeltas);
// NOTE(review): 32 is hard-coded here while computePolicyDeltas uses
// config.NumActionDims for the same purpose — confirm these should agree.
var onehots = onehotIndices(batch.Actions, 32).Zip(Enumerable.Range(0, 32));
foreach ((int p, int i) in onehots)
unsafe { valueDeltas.Data[p] = valueDeltasSparse.Data[i]; }
}

// Lazily yields, for each sample, the flat (row-major) index of its action's
// one-hot position: row i's slot is i * numClasses + classId.
// NOTE(review): the loop runs over NumRows but reads At(0, i) — i.e. row 0,
// column i. Confirm Matrix2D.At argument order / that the sparse class ids
// are laid out along a single row; otherwise this indexes the wrong cell.
private IEnumerable<int> onehotIndices(Matrix2D sparseClassIds, int numClasses)
{
for (int i = 0; i < sparseClassIds.NumRows; i++)
yield return i * numClasses + (int)sparseClassIds.At(0, i);
}

// Rebuilds the forward/backward caches of both networks for a new batch size,
// e.g. switching to batchSize 1 for benchmarking and back to the training
// batch size afterwards (see the Train loop in PPOTrainingSession).
public void RecompileCache(int batchSize)
{
strategy.RecompileCache(batchSize);
valueFunc.RecompileCache(batchSize);
}
}

public class PossibleCardPicker
Expand Down Expand Up @@ -356,6 +359,9 @@ public PPORolloutBuffer(PPOTrainingSettings config)

// True whenever a full rollout of `Steps` timesteps has just been collected
// (i.e. t is a positive multiple of Steps); t == 0 never triggers an update.
public bool IsReadyForModelUpdate(int t)
{
    if (t <= 0)
        return false;
    return t % Steps == 0;
}

// Number of minibatches one call to SampleDataset will yield: the buffer
// (excluding the overlapping last step) holds Size / batchSize full batches
// per epoch, repeated for each update epoch. Integer division drops any
// partial trailing batch.
public int NumBatches(int batchSize, int epochs = 1)
{
    int batchesPerEpoch = cacheWithoutLastStep.Size / batchSize;
    return batchesPerEpoch * epochs;
}

public PPOTrainBatch? SliceStep(int t)
{
int offset = IsReadyForModelUpdate(t)
Expand All @@ -374,7 +380,7 @@ public IEnumerable<PPOTrainBatch> SampleDataset(int batchSize, int epochs = 1)
yield return batch;
}

copyOverlappingStep();
// copyOverlappingStep();
}

private void shuffleDataset()
Expand Down
16 changes: 11 additions & 5 deletions Schafkopf.Training/NeuralNet/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,23 @@ public FFModel(IList<ILayer> layers)

public void Compile(int batchSize, int inputDims)
{
BatchSize = batchSize;
GradsTape = new List<Matrix2D>();
var input = Matrix2D.Zeros(batchSize, inputDims);
var deltaOut = Matrix2D.Zeros(batchSize, inputDims);

foreach (var layer in Layers)
{
layer.Compile(inputDims);
inputDims = layer.OutputDims;
}

RecompileCache(batchSize);
}

public void RecompileCache(int batchSize)
{
BatchSize = batchSize;
GradsTape = new List<Matrix2D>();
int inputDims = Layers.First().InputDims;
var input = Matrix2D.Zeros(batchSize, inputDims);
var deltaOut = Matrix2D.Zeros(batchSize, inputDims);

foreach (var layer in Layers)
{
layer.CompileCache(input, deltaOut);
Expand Down

0 comments on commit 8b3446b

Please sign in to comment.