Commit 9946b3e

finish ppo implementation (to be tested)

Bonifatius94 committed Nov 22, 2023
1 parent 36475f7 commit 9946b3e
Showing 4 changed files with 214 additions and 119 deletions.
209 changes: 158 additions & 51 deletions Schafkopf.Training/Algos/PPOAgent.cs
@@ -1,6 +1,4 @@

- // TODO: train a policy to predict the likelihood
- // of selecting an action in a given state
namespace Schafkopf.Training;

public class UniformDistribution
{
@@ -25,10 +23,9 @@ public int Sample(ReadOnlySpan<double> probs)

public class PPOAgent : ISchafkopfAIAgent
{
- public PPOAgent(PPOTrainingSettings config, Action<ACSarsExp> expConsumer)
+ public PPOAgent(PPOTrainingSettings config)
{
this.config = config;
- this.expConsumer = expConsumer;

valueFunc = new FFModel(new ILayer[] {
new DenseLayer(64),
@@ -49,22 +46,19 @@ public PPOAgent(PPOTrainingSettings config, Action<ACSarsExp> expConsumer)

valueFunc.Compile(config.BatchSize, 90);
strategy.Compile(config.BatchSize, 90);

- trainExpCache = new SarsExp[config.BatchSize];
- for (int i = 0; i < config.BatchSize; i++)
- trainExpCache[i] = new SarsExp() {
- StateBefore = new GameState(),
- StateAfter = new GameState() };
- featureCache = Matrix2D.Zeros(config.BatchSize, 90);
+ strategyOpt = new AdamOpt(config.LearnRate);
+ valueFuncOpt = new AdamOpt(config.LearnRate);
}

private PPOTrainingSettings config;
- private Action<ACSarsExp> expConsumer;
private FFModel valueFunc;
private FFModel strategy;
+ private IOptimizer strategyOpt;
+ private IOptimizer valueFuncOpt;

- private Matrix2D featureCache;
private GameStateSerializer stateSerializer = new GameStateSerializer();
+ private Matrix2D featureCache = Matrix2D.Zeros(1, 90);
- private SarsExp[] trainExpCache;
private UniformDistribution uniform = new UniformDistribution();

public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
@@ -73,16 +67,29 @@ public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
var s0 = stateSerializer.SerializeState(log);
unsafe { s0.ExportFeatures(x.Data);}

- var predPi = strategy.PredictBatch(featureCache);
- var predQ = valueFunc.PredictBatch(featureCache);
+ var predPi = strategy.PredictBatch(x);
var probDist = normProbDist(predPi, possibleCards);
- int cardId = uniform.Sample(probDist);
+ int i = uniform.Sample(probDist);
+ var action = possibleCards[i];
+ return action;
+ }

- expConsumer(new ACSarsExp() {
+ public (Card, double, double) Predict(
+ GameLog log, ReadOnlySpan<Card> possibleCards)
+ {
+ var x = featureCache;
+ var s0 = stateSerializer.SerializeState(log);
+ unsafe { s0.ExportFeatures(x.Data); }

- });
+ var predPi = strategy.PredictBatch(x);
+ var predV = valueFunc.PredictBatch(x);
+ var probDist = normProbDist(predPi, possibleCards);
+ int i = uniform.Sample(probDist);

- return possibleCards[cardId];
+ var action = possibleCards[i];
+ var pi = predPi.At(0, action.Id);
+ var v = predV.At(0, action.Id);
+ return (action, pi, v);
}
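Predict pairs the sampled card with the policy probability and the critic output for that card id — the (a, π_old(a|s), V) triple a PPO rollout buffer records per step. A hypothetical collection step (the buffer call is assumed, not part of this commit):

```csharp
// sketch only: gather one transition during self-play
var (card, piOld, value) = agent.Predict(log, possibleCards);
// ... play the card, observe the reward, then store
// (state, card.Id, reward, piOld, value) in the rollout buffer
```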

private double[] probDistCache = new double[8];
@@ -104,20 +111,89 @@ private ReadOnlySpan<double> normProbDist(
return probDistCache.AsSpan().Slice(0, possibleCards.Length);
}
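The body of normProbDist is collapsed in this view; only its tail is visible above. A minimal sketch of what such a masked renormalization can look like, consistent with the visible return statement and the At(0, action.Id) lookups in Predict — the exact committed body may differ:

```csharp
// sketch, not the committed code: keep only the policy mass of the
// legal cards and rescale it into a proper probability distribution
private ReadOnlySpan<double> normProbDist(
    Matrix2D predPi, ReadOnlySpan<Card> possibleCards)
{
    double sum = 0;
    for (int i = 0; i < possibleCards.Length; i++)
        sum += probDistCache[i] = predPi.At(0, possibleCards[i].Id);
    for (int i = 0; i < possibleCards.Length; i++)
        probDistCache[i] /= sum + 1e-8; // guard against all-zero mass
    return probDistCache.AsSpan().Slice(0, possibleCards.Length);
}
```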

- public void Train(ReplayMemory memory)
+ public void Train(PPORolloutBuffer memory)
{
- int numBatches = memory.Size / trainExpCache.Length;
+ var batches = memory.SampleDataset(
+ config.BatchSize, config.UpdateEpochs);

- for (int i = 0; i < numBatches; i++)
- memory.SampleBatched(trainExpCache);
- updateModels(trainExpCache);
+ foreach (var batch in batches)
+ updateModels(batch);
}

+ private void updateModels(PPOTrainBatch batch)
+ {
+ var predPi = strategy.PredictBatch(batch.StatesBefore);
+ var policyDeltas = strategy.Layers.Last().Cache.DeltasIn;
+ computePolicyDeltas(batch, predPi, policyDeltas);
+ strategy.FitBatch(policyDeltas, strategyOpt);
+
+ var predV = valueFunc.PredictBatch(batch.StatesBefore);
+ var valueDeltas = valueFunc.Layers.Last().Cache.DeltasIn;
+ computeValueDeltas(batch, predV, valueDeltas);
+ valueFunc.FitBatch(valueDeltas, valueFuncOpt);
+ }

+ private void computePolicyDeltas(
+ PPOTrainBatch batch, Matrix2D predPi, Matrix2D policyDeltas)
+ {
+ var normAdvantages = Matrix2D.Zeros(batch.Size, 1);
+ var policyRatios = Matrix2D.Zeros(batch.Size, 1);
+ var derPolicyRatios = Matrix2D.Zeros(batch.Size, 1);
+ var newProbs = Matrix2D.Zeros(batch.Size, 1);
+ var clipMask = Matrix2D.Zeros(batch.Size, 1);
+ var clipMaskUpper = Matrix2D.Zeros(batch.Size, 1);
+ var policyDeltasSparse = Matrix2D.Zeros(batch.Size, 1);

+ var advantages = batch.Advantages;
+ if (config.NormAdvantages)
+ {
+ double mean = Matrix2D.Mean(batch.Advantages);
+ double stdDev = Matrix2D.StdDev(batch.Advantages);
+ Matrix2D.BatchSub(batch.Advantages, mean, normAdvantages); // center: A - mean
+ Matrix2D.BatchDiv(normAdvantages, stdDev, normAdvantages); // scale: / stddev
+ advantages = normAdvantages;
+ }

+ var onehots = onehotIndices(batch.Actions, 32).Zip(Enumerable.Range(0, batch.Size));
+ foreach ((int p, int i) in onehots)
+ unsafe { newProbs.Data[i] = predPi.Data[p]; }
+
+ // likelihood ratio r = pi_new / (pi_old + eps) and its
+ // derivative w.r.t. pi_new, i.e. 1 / (pi_old + eps)
+ Matrix2D.BatchAdd(batch.OldProbs, 1e-8, derPolicyRatios);
+ Matrix2D.BatchOneOver(derPolicyRatios, derPolicyRatios);
+ Matrix2D.ElemMul(newProbs, derPolicyRatios, policyRatios);
+
+ // pass gradients only where r lies inside [1-eps, 1+eps]
+ Matrix2D.ElemGeq(policyRatios, 1 + config.ProbClip, clipMaskUpper);
+ Matrix2D.ElemLeq(policyRatios, 1 - config.ProbClip, clipMask);
+ Matrix2D.ElemAdd(clipMask, clipMaskUpper, clipMask);
+ Matrix2D.ElemNeq(clipMask, 1, clipMask);
+
+ // sparse deltas: -A * dr/dpi_new, zeroed where clipped
+ Matrix2D.ElemMul(clipMask, derPolicyRatios, policyDeltasSparse);
+ Matrix2D.ElemMul(policyDeltasSparse, advantages, policyDeltasSparse);
+ Matrix2D.BatchMul(policyDeltasSparse, -1, policyDeltasSparse);
+
+ // scatter the per-row deltas into the (batchSize x 32) output
+ Matrix2D.BatchMul(policyDeltas, 0, policyDeltas);
+ foreach ((int p, int i) in onehots)
+ unsafe { policyDeltas.Data[p] = policyDeltasSparse.Data[i]; }
+ }
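For reference, the quantity computePolicyDeltas differentiates is the standard PPO clipped surrogate; the sparse deltas are the negative gradient of this objective with respect to the probability of the action actually taken:

```latex
r_t = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}, \qquad
L^\text{CLIP}(\theta) = \mathbb{E}_t\!\left[ \min\!\left( r_t \hat{A}_t,\;
    \text{clip}(r_t,\, 1 - \epsilon,\, 1 + \epsilon)\, \hat{A}_t \right) \right]
```

ProbClip plays the role of ε here. Zeroing the gradient whenever r_t leaves [1 − ε, 1 + ε] is a simplification of differentiating the min exactly (the exact gradient also passes when the unclipped term is the smaller one), and it matches the clip-mask approach the code takes.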

- private void updateModels(ReadOnlySpan<SarsExp> expsBatch)
+ private void computeValueDeltas(
+ PPOTrainBatch batch, Matrix2D predV, Matrix2D valueDeltas)
+ {
+ var mse = new MeanSquaredError();
+ var valueDeltasSparse = Matrix2D.Zeros(batch.Size, 1);
+ var vTaken = Matrix2D.Zeros(batch.Size, 1);
+
+ // gather the predicted value of the card actually played,
+ // then regress it against the empirical return
+ var onehots = onehotIndices(batch.Actions, 32).Zip(Enumerable.Range(0, batch.Size));
+ foreach ((int p, int i) in onehots)
+ unsafe { vTaken.Data[i] = predV.Data[p]; }
+ mse.LossDeltas(vTaken, batch.Returns, valueDeltasSparse);
+
+ // TODO: add value clipping
+
+ // scatter the deltas back to the taken actions' columns
+ Matrix2D.BatchMul(valueDeltas, 0, valueDeltas);
+ foreach ((int p, int i) in onehots)
+ unsafe { valueDeltas.Data[p] = valueDeltasSparse.Data[i]; }
+ }

+ private IEnumerable<int> onehotIndices(Matrix2D sparseClassIds, int numClasses)
+ {
+ // flat row-major index of (row i, class id stored in column 0)
+ for (int i = 0; i < sparseClassIds.NumRows; i++)
+ yield return i * numClasses + (int)sparseClassIds.At(i, 0);
+ }
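A quick sanity check of the flat indexing: in a row-major (batchSize × 32) matrix with one column per card id, row 3 holding action id 17 maps to 3 · 32 + 17 = 113 — the cell of that row's taken card. All three scatter/gather loops above rely on these offsets.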

public void OnGameFinished(GameLog final)
@@ -234,14 +310,14 @@ public IEnumerable<PPOTrainBatch> SampleBatched(int batchSize)
var trainBuf = new PPOTrainBatch() { Size = batchSize };
for (int i = 0; i < numBatches; i++)
{
- trainBuf.StatesBefore = Matrix2D.SliceRows(StatesBefore, p, batchSize);
- trainBuf.Actions = Matrix2D.SliceRows(Actions, p, batchSize);
- trainBuf.Rewards = Matrix2D.SliceRows(Rewards, p, batchSize);
- trainBuf.Terminals = Matrix2D.SliceRows(Terminals, p, batchSize);
- trainBuf.Returns = Matrix2D.SliceRows(Returns, p, batchSize);
- trainBuf.Advantages = Matrix2D.SliceRows(Advantages, p, batchSize);
- trainBuf.OldProbs = Matrix2D.SliceRows(OldProbs, p, batchSize);
- trainBuf.OldBaselines = Matrix2D.SliceRows(OldBaselines, p, batchSize);
+ trainBuf.StatesBefore = StatesBefore.SliceRows(p, batchSize);
+ trainBuf.Actions = Actions.SliceRows(p, batchSize);
+ trainBuf.Rewards = Rewards.SliceRows(p, batchSize);
+ trainBuf.Terminals = Terminals.SliceRows(p, batchSize);
+ trainBuf.Returns = Returns.SliceRows(p, batchSize);
+ trainBuf.Advantages = Advantages.SliceRows(p, batchSize);
+ trainBuf.OldProbs = OldProbs.SliceRows(p, batchSize);
+ trainBuf.OldBaselines = OldBaselines.SliceRows(p, batchSize);
yield return trainBuf;
p += batchSize;
}
@@ -250,14 +326,14 @@ public IEnumerable<PPOTrainBatch> SampleBatched(int batchSize)
public PPOTrainBatch SliceRows(int rowid, int length)
=> new PPOTrainBatch {
Size = length,
- StatesBefore = Matrix2D.SliceRows(StatesBefore, rowid, length),
- Actions = Matrix2D.SliceRows(Actions, rowid, length),
- Rewards = Matrix2D.SliceRows(Rewards, rowid, length),
- Terminals = Matrix2D.SliceRows(Terminals, rowid, length),
- Returns = Matrix2D.SliceRows(Returns, rowid, length),
- Advantages = Matrix2D.SliceRows(Advantages, rowid, length),
- OldProbs = Matrix2D.SliceRows(OldProbs, rowid, length),
- OldBaselines = Matrix2D.SliceRows(OldBaselines, rowid, length)
+ StatesBefore = StatesBefore.SliceRows(rowid, length),
+ Actions = Actions.SliceRows(rowid, length),
+ Rewards = Rewards.SliceRows(rowid, length),
+ Terminals = Terminals.SliceRows(rowid, length),
+ Returns = Returns.SliceRows(rowid, length),
+ Advantages = Advantages.SliceRows(rowid, length),
+ OldProbs = OldProbs.SliceRows(rowid, length),
+ OldBaselines = OldBaselines.SliceRows(rowid, length)
};
}
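Both SampleBatched and SliceRows appear to hand out row views over the same underlying matrices rather than copies (judging by the instance-method refactor above), so iterating minibatches allocates next to nothing. A hypothetical consumer, with buffer construction assumed:

```csharp
// sketch: every yielded batch aliases rows of the full buffer
foreach (var batch in fullBuffer.SampleBatched(64))
{
    // batch.StatesBefore is a 64-row view, not a copy;
    // nested views work the same way:
    var firstHalf = batch.SliceRows(0, 32);
}
```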

@@ -302,24 +378,55 @@ public void AppendStep(ACSarsExp[] expsOfStep, int t)
cache.WriteRow(expsOfStep[i], offset + i);
}

- public IEnumerable<PPOTrainBatch> SampleDataset(int batchSize)
+ public IEnumerable<PPOTrainBatch> SampleDataset(int batchSize, int epochs = 1)
{
- cacheGAE();
- Perm.Permutate(permCache);
- cacheWithoutLastStep.Shuffle(permCache);
+ cacheGAE(cache);

- foreach(var batch in cacheWithoutLastStep.SampleBatched(batchSize))
- yield return batch;
+ for (int i = 0; i < epochs; i++)
+ {
+ shuffleDataset();
+ foreach(var batch in cacheWithoutLastStep.SampleBatched(batchSize))
+ yield return batch;
+ }

copyOverlappingStep();
}

- private void cacheGAE()
+ private void shuffleDataset()
+ {
+ Perm.Permutate(permCache);
+ cacheWithoutLastStep.Shuffle(permCache);
+ }
+
+ private void cacheGAE(PPOTrainBatch cache)
{
+ var nonterm_t1 = Matrix2D.Zeros(1, numEnvs);
+ var lambda = Matrix2D.Zeros(1, numEnvs);
+ var delta = Matrix2D.Zeros(1, numEnvs);

for (int t = steps - 1; t >= 0; t--)
{
- // TODO: do the GAE magic
+ var r_t0 = cache.Rewards.SliceRows(t, 1);
+ var term_t1 = cache.Terminals.SliceRows(t+1, 1);
+ var v_t0 = cache.OldBaselines.SliceRows(t, 1);
+ var v_t1 = cache.OldBaselines.SliceRows(t+1, 1);
+ var A_t0 = cache.Advantages.SliceRows(t, 1);
+ var G_t0 = cache.Returns.SliceRows(t, 1);
+
+ // nonterm = 1 - done
+ Matrix2D.BatchMul(term_t1, -1, nonterm_t1);
+ Matrix2D.BatchAdd(nonterm_t1, 1, nonterm_t1);
+
+ // TD error: delta = r + gamma * nonterm * V(s') - V(s)
+ Matrix2D.ElemMul(v_t1, nonterm_t1, delta);
+ Matrix2D.BatchMul(delta, gamma, delta);
+ Matrix2D.ElemAdd(delta, r_t0, delta);
+ Matrix2D.ElemSub(delta, v_t0, delta);
+
+ // advantage recursion: A_t = delta + gamma * lambda * nonterm * A_{t+1}
+ Matrix2D.ElemMul(nonterm_t1, lambda, lambda);
+ Matrix2D.BatchMul(lambda, gamma * gaeGamma, lambda);
+ Matrix2D.ElemAdd(lambda, delta, lambda);
+
+ Matrix2D.CopyData(lambda, A_t0);
+ Matrix2D.ElemAdd(v_t0, A_t0, G_t0); // return-to-go G = A + V
}
}
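The loop is the standard backward recursion for generalized advantage estimation, with the gaeGamma field in the role of the usual λ and the overlapping (t+1)-th step kept by copyOverlappingStep supplying the bootstrap values for the last real step:

```latex
\delta_t = r_t + \gamma\,(1 - d_{t+1})\,V(s_{t+1}) - V(s_t)
\hat{A}_t = \delta_t + \gamma \lambda\,(1 - d_{t+1})\,\hat{A}_{t+1}, \qquad \hat{A}_T = 0
G_t = \hat{A}_t + V(s_t)
```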

59 changes: 0 additions & 59 deletions Schafkopf.Training/NeuralNet/Layers.cs
@@ -207,62 +207,3 @@ public void ApplyGrads()
// info: layer has no trainable params
}
}
-
- // public class FlattenLayer : ILayer
- // {
- // public FlattenLayer(int axis = 1)
- // => this.axis = axis;
-
- // private int axis;
-
- // public LayerCache Cache { get; private set; }
-
- // public int InputDims { get; private set; }
-
- // public int OutputDims { get; private set; }
-
- // public void Compile(int inputDims)
- // {
- // InputDims = inputDims;
- // }
-
- // public void CompileCache(Matrix2D inputs, Matrix2D deltasOut)
- // {
- // int flatDims = inputs.NumRows * inputs.NumCols;
- // OutputDims = axis == 0 ? flatDims : 1;
- // int batchSize = axis == 0 ? 1 : flatDims;
-
- // Cache = new LayerCache() {
- // Input = inputs,
- // Output = Matrix2D.Zeros(batchSize, OutputDims),
- // DeltasIn = Matrix2D.Zeros(batchSize, OutputDims),
- // DeltasOut = deltasOut,
- // Gradients = Matrix2D.Null()
- // };
- // }
-
- // public void Forward()
- // {
- // unsafe
- // {
- // int dataLen = Cache.Input.NumRows * Cache.Input.NumCols;
- // for (int i = 0; i < dataLen; i++)
- // Cache.Output.Data[i] = Cache.Input.Data[i];
- // }
- // }
-
- // public void Backward()
- // {
- // unsafe
- // {
- // int dataLen = Cache.DeltasIn.NumRows * Cache.DeltasIn.NumCols;
- // for (int i = 0; i < dataLen; i++)
- // Cache.DeltasOut.Data[i] = Cache.DeltasIn.Data[i];
- // }
- // }
-
- // public void ApplyGrads()
- // {
- // // info: layer has no trainable params
- // }
- // }