rework ppo agent (wip)
Bonifatius94 committed Nov 27, 2023
1 parent 29c9553 commit d71e7b9
Showing 1 changed file with 77 additions and 208 deletions.
285 changes: 77 additions & 208 deletions Schafkopf.Training/Algos/PPOAgent.cs
@@ -1,5 +1,39 @@
namespace Schafkopf.Training;

public class PPOTrainingSession
{
public void Train()
{
// var config = new PPOTrainingSettings();
// var rewardFunc = new GameReward();
// var memory = new PPORolloutBuffer(config);
// var predCache = new PPOPredictionCache(config.NumEnvs, config.StepsPerUpdate);
// var ppoModel = new PPOModel(config);
// var cardPicker = new VectorizedCardPicker(config, ppoModel, predCache);
// var vecEnv = new VectorizedCardPickerEnv(cardPicker, config.BatchSize);
// var heuristicGameCaller = new HeuristicAgent();
// var envProxies = Enumerable.Range(0, config.NumEnvs)
// .Select(i => new EnvCardPicker(i, vecEnv)).ToArray();
// var agents = Enumerable.Range(0, config.NumEnvs)
// .Select(i => new ComposedAgent(heuristicGameCaller, envProxies[i])).ToArray();

// var tables = Enumerable.Range(0, config.NumEnvs)
// .Select(i => new Table(
// new Player(0, agents[i]),
// new Player(1, agents[i]),
// new Player(2, agents[i]),
// new Player(3, agents[i])
// )).ToArray();
// var sessions = Enumerable.Range(0, config.NumEnvs)
// .Select(i => new GameSession(tables[i], new CardsDeck())).ToArray();

// for (int i = 0; i < 10_000; i++)
// {
// var games = sessions.AsParallel().Select(sess => sess.ProcessGame());
// }
}
}

public class UniformDistribution
{
public UniformDistribution(int? seed = null)
@@ -62,28 +96,12 @@ public PPOModel(PPOTrainingSettings config)
private Matrix2D featureCache;
private UniformDistribution uniform = new UniformDistribution();

public void Predict(
Matrix2D s0, Matrix2D outActions,
Matrix2D outPiOnehot, Matrix2D outV)
public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV)
{
int batchSize = s0.NumRows;
int numFeatures = s0.NumCols;

var predPi = strategy.PredictBatch(s0);
var predV = valueFunc.PredictBatch(s0);
Matrix2D.CopyData(predPi, outPiOnehot);

for (int i = 0; i < batchSize; i++)
{
unsafe
{
var probDist = new Span<double>(
predPi.Data + i * numFeatures, numFeatures);
int action = uniform.Sample(probDist);
outActions.Data[i] = action;
outV.Data[i] = predV.At(0, 0);
}
}
Matrix2D.CopyData(predV, outV);
}

public void Train(PPORolloutBuffer memory)
@@ -197,45 +215,6 @@ public class PPOTrainingSettings
public int ModelSnapshotInterval => TrainSteps / NumModelSnapshots;
}

public interface ICardPicker
{
Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards);
}

public class ComposedAgent : ISchafkopfAIAgent
{
public ComposedAgent(
ISchafkopfAIAgent agent,
ICardPicker cardPicker)
{
this.agent = agent;
this.cardPicker = cardPicker;
}

private ISchafkopfAIAgent agent;
private ICardPicker cardPicker;

public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
=> cardPicker.ChooseCard(log, possibleCards);

public GameCall MakeCall(
ReadOnlySpan<GameCall> possibleCalls,
int position, Hand hand, int klopfer)
=> agent.MakeCall(possibleCalls, position, hand, klopfer);

public bool CallKontra(GameLog log)
=> agent.CallKontra(log);

public bool CallRe(GameLog log)
=> agent.CallRe(log);

public bool IsKlopfer(int position, ReadOnlySpan<Card> firstFourCards)
=> agent.IsKlopfer(position, firstFourCards);

public void OnGameFinished(GameLog final)
=> agent.OnGameFinished(final);
}

public class PossibleCardPicker
{
private UniformDistribution uniform = new UniformDistribution();
@@ -247,6 +226,9 @@ public Card PickCard(
=> canPlaySampledCard(possibleCards, sampledCard) ? sampledCard
: possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
=> possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

private bool canPlaySampledCard(
ReadOnlySpan<Card> possibleCards, Card sampledCard)
{
@@ -273,143 +255,30 @@ private ReadOnlySpan<double> normProbDist(
}
}

public class EnvCardPicker : ICardPicker
public struct PPOPredictionCache
{
public EnvCardPicker(int envId, VectorizedCardPicker vecAgent)
public PPOPredictionCache(int numEnvs, int steps)
{
this.envId = envId;
this.vecAgent = vecAgent;
}

private int envId;
private VectorizedCardPicker vecAgent;

public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
=> vecAgent.ChooseCard(envId, log, possibleCards);
}

public class VectorizedCardPicker
{
public VectorizedCardPicker(
PPOTrainingSettings config,
PPOModel agent,
PPORolloutBuffer memory)
{
this.config = config;
this.agent = agent;
this.memory = memory;

inputs = new (GameLog, Card[], int)[config.BatchSize];
for (int i = 0; i < inputs.Length; i++)
inputs[i].Item2 = new Card[8];
outputs = new Card[config.BatchSize];
barr = new Barrier(config.BatchSize, (b) => predictBatched());

s0 = Matrix2D.Zeros(config.BatchSize, 90);
a0 = Matrix2D.Zeros(config.BatchSize, 1);
piOnehot = Matrix2D.Zeros(config.BatchSize, config.NumActionDims);
piSparse = Matrix2D.Zeros(config.BatchSize, 1);
V = Matrix2D.Zeros(config.BatchSize, 1);
expCache = new PPOTrainBatch(config.BatchSize, config.NumStateDims);
this.numEnvs = numEnvs;
int size = steps * numEnvs;
oldProbs = new double[size];
oldBaselines = new double[size];
}

private PPOTrainingSettings config;
private PPOModel agent;
private PPORolloutBuffer memory;
private GameStateSerializer stateSerializer = new GameStateSerializer();
private PossibleCardPicker possCardsPicker = new PossibleCardPicker();

private int t = 0;
private (GameLog, Card[], int)[] inputs;
private Card[] outputs;
private Matrix2D s0, a0, piOnehot, piSparse, V;
private PPOTrainBatch expCache;

private Barrier barr;
private int numEnvs;
private double[] oldProbs;
private double[] oldBaselines;

public Card ChooseCard(
int envId, GameLog log, ReadOnlySpan<Card> possibleCards)
public void AppendStep(int t, ReadOnlySpan<double> pi, ReadOnlySpan<double> v)
{
inputs[envId].Item1 = log;
possibleCards.CopyTo(inputs[envId].Item2);
inputs[envId].Item3 = possibleCards.Length;
barr.SignalAndWait();
return outputs[envId];
pi.CopyTo(oldProbs.AsSpan(t * numEnvs));
v.CopyTo(oldBaselines.AsSpan(t * numEnvs));
}

private void predictBatched()
public void Export(Span<double> pi, Span<double> v)
{
for (int i = 0; i < config.BatchSize; i++)
{
(var log, var _, int __) = inputs[i];
var state = stateSerializer.SerializeState(log);
unsafe { state.ExportFeatures(s0.Data + i * 90); }
}

agent.Predict(s0, a0, piOnehot, V);

for (int i = 0; i < config.BatchSize; i++)
{
(var _, var cardsCache, int cardsLen) = inputs[i];
var possCards = cardsCache.AsSpan(0, cardsLen);
Span<double> probDistAll;
unsafe { probDistAll = new Span<double>(piOnehot.Data + i * 32, 32); }

var sampledCard = new Card((byte)a0.At(i, 0));
sampledCard = possCardsPicker.PickCard(possCards, probDistAll, sampledCard);

unsafe
{
a0.Data[i] = sampledCard.Id;
piSparse.Data[i] = probDistAll[sampledCard.Id];
}
}

Matrix2D.CopyData(s0, expCache.StatesBefore);
Matrix2D.CopyData(a0, expCache.Actions);
Matrix2D.CopyData(piSparse, expCache.OldProbs);
Matrix2D.CopyData(V, expCache.OldBaselines);

// for (int i = 0; i < config.BatchSize; i++)
// {
// (var log, var _, int __) = inputs[i];
// GameReward.Reward(log);
// }

// TODO: include ways to determine the rewards and terminals
// Matrix2D.CopyData(cacheOnlyLastStep.Rewards, expCache.Rewards);
// Matrix2D.CopyData(cacheOnlyLastStep.Terminals, expCache.Terminals);

memory.AppendStep(expCache, t++);
}
}

public class PPOTrainingSession
{
public void Train()
{
var config = new PPOTrainingSettings();
var rewardFunc = new GameReward();
var memory = new PPORolloutBuffer(config);
var ppoModel = new PPOModel(config);
var cardPicker = new VectorizedCardPicker(config, ppoModel, memory);
var heuristicGameCaller = new HeuristicAgent();
var envProxies = Enumerable.Range(0, config.NumEnvs)
.Select(i => new EnvCardPicker(i, cardPicker)).ToArray();
var agents = Enumerable.Range(0, config.NumEnvs)
.Select(i => new ComposedAgent(heuristicGameCaller, envProxies[i])).ToArray();

var tables = Enumerable.Range(0, config.NumEnvs)
.Select(i => new Table(
new Player(0, agents[i]),
new Player(1, agents[i]),
new Player(2, agents[i]),
new Player(3, agents[i])
)).ToArray();
var sessions = Enumerable.Range(0, config.NumEnvs)
.Select(i => new GameSession(tables[i], new CardsDeck())).ToArray();

// TODO: finish implementation
oldProbs.CopyTo(pi);
oldBaselines.CopyTo(v);
}
}

@@ -491,25 +360,25 @@ public class PPORolloutBuffer
{
public PPORolloutBuffer(PPOTrainingSettings config)
{
numEnvs = config.NumEnvs;
steps = config.StepsPerUpdate;
NumEnvs = config.NumEnvs;
Steps = config.StepsPerUpdate;
gamma = config.RewardDiscount;
gaeGamma = config.GAEDiscount;

// info: the cache stores an extra timestep at the end
// which facilitates proper GAE computation

int size = steps * numEnvs;
int sizeWithExtraStep = (steps + 1) * numEnvs;
int size = Steps * NumEnvs;
int sizeWithExtraStep = (Steps + 1) * NumEnvs;
cache = new PPOTrainBatch(sizeWithExtraStep, config.NumStateDims);
cacheWithoutLastStep = cache.SliceRows(0, size);
cacheOnlyFirstStep = cache.SliceRows(0, numEnvs);
cacheOnlyLastStep = cache.SliceRows(size, numEnvs);
cacheOnlyFirstStep = cache.SliceRows(0, NumEnvs);
cacheOnlyLastStep = cache.SliceRows(size, NumEnvs);
permCache = Perm.Identity(size);
}

private int numEnvs;
private int steps;
public int NumEnvs;
public int Steps;
private double gamma;
private double gaeGamma;
private PPOTrainBatch cache;
Expand All @@ -518,21 +387,21 @@ public PPORolloutBuffer(PPOTrainingSettings config)
private PPOTrainBatch cacheOnlyLastStep;
private int[] permCache;

public bool IsReadyForModelUpdate(int t) => t > 0 && t % steps == 0;
public bool IsReadyForModelUpdate(int t) => t > 0 && t % Steps == 0;

public void AppendStep(PPOTrainBatch expsOfStep, int t)
{
int offset = IsReadyForModelUpdate(t)
? steps * numEnvs : (t % steps) * numEnvs;

Matrix2D.CopyData(expsOfStep.StatesBefore, cache.StatesBefore.SliceRows(offset, numEnvs));
Matrix2D.CopyData(expsOfStep.Actions, cache.Actions.SliceRows(offset, numEnvs));
Matrix2D.CopyData(expsOfStep.Rewards, cache.Rewards.SliceRows(offset, numEnvs));
Matrix2D.CopyData(expsOfStep.Terminals, cache.Terminals.SliceRows(offset, numEnvs));
Matrix2D.CopyData(expsOfStep.Returns, cache.Returns.SliceRows(offset, numEnvs));
Matrix2D.CopyData(expsOfStep.Advantages, cache.Advantages.SliceRows(offset, numEnvs));
Matrix2D.CopyData(expsOfStep.OldProbs, cache.OldProbs.SliceRows(offset, numEnvs));
Matrix2D.CopyData(expsOfStep.OldBaselines, cache.OldBaselines.SliceRows(offset, numEnvs));
? Steps * NumEnvs : (t % Steps) * NumEnvs;

Matrix2D.CopyData(expsOfStep.StatesBefore, cache.StatesBefore.SliceRows(offset, NumEnvs));
Matrix2D.CopyData(expsOfStep.Actions, cache.Actions.SliceRows(offset, NumEnvs));
Matrix2D.CopyData(expsOfStep.Rewards, cache.Rewards.SliceRows(offset, NumEnvs));
Matrix2D.CopyData(expsOfStep.Terminals, cache.Terminals.SliceRows(offset, NumEnvs));
Matrix2D.CopyData(expsOfStep.Returns, cache.Returns.SliceRows(offset, NumEnvs));
Matrix2D.CopyData(expsOfStep.Advantages, cache.Advantages.SliceRows(offset, NumEnvs));
Matrix2D.CopyData(expsOfStep.OldProbs, cache.OldProbs.SliceRows(offset, NumEnvs));
Matrix2D.CopyData(expsOfStep.OldBaselines, cache.OldBaselines.SliceRows(offset, NumEnvs));
}

public IEnumerable<PPOTrainBatch> SampleDataset(int batchSize, int epochs = 1)
Expand All @@ -557,11 +426,11 @@ private void shuffleDataset()

private void cacheGAE(PPOTrainBatch cache)
{
var nonterm_t1 = Matrix2D.Zeros(1, numEnvs);
var lambda = Matrix2D.Zeros(1, numEnvs);
var delta = Matrix2D.Zeros(1, numEnvs);
var nonterm_t1 = Matrix2D.Zeros(1, NumEnvs);
var lambda = Matrix2D.Zeros(1, NumEnvs);
var delta = Matrix2D.Zeros(1, NumEnvs);

for (int t = steps - 1; t >= 0; t--)
for (int t = Steps - 1; t >= 0; t--)
{
var r_t0 = cache.Rewards.SliceRows(t, 1);
var term_t1 = cache.Terminals.SliceRows(t+1, 1);
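For context on the GAE bookkeeping above (the rollout buffer stores one extra timestep at the end so that V(s_{t+1}) and the terminal flag of the following step are available for the last real step): below is a minimal, array-based sketch of the advantage recursion that cacheGAE appears to implement for a single environment. The helper name GaeSketch.Compute and the plain-array signature are illustrative assumptions; the actual code operates on Matrix2D row slices across all environments at once.

public static class GaeSketch
{
    public static void Compute(
        double[] rewards,     // r_t, length: steps
        double[] terminals,   // done flags, length: steps + 1 (extra step at the end)
        double[] values,      // V(s_t) baselines, length: steps + 1 (extra step at the end)
        double[] advantages,  // output A_t, length: steps
        double[] returns,     // output A_t + V(s_t), length: steps
        double gamma, double gaeLambda)
    {
        double runningAdvantage = 0.0;
        for (int t = rewards.Length - 1; t >= 0; t--)
        {
            // do not bootstrap across episode boundaries
            double nonterminal = 1.0 - terminals[t + 1];
            // TD error: delta_t = r_t + gamma * V(s_{t+1}) * nonterminal - V(s_t)
            double delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t];
            // GAE recursion: A_t = delta_t + gamma * lambda * nonterminal * A_{t+1}
            runningAdvantage = delta + gamma * gaeLambda * nonterminal * runningAdvantage;
            advantages[t] = runningAdvantage;
            returns[t] = advantages[t] + values[t];
        }
    }
}

With the extra (steps + 1)-th entry, the recursion at t = steps - 1 can read values[t + 1] and terminals[t + 1] without special-casing the last step, which is exactly what the extra timestep noted in the buffer comment provides.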
