diff --git a/Schafkopf.Training/Algos/PPOAgent.cs b/Schafkopf.Training/Algos/PPOAgent.cs
index d422e8c..820d929 100644
--- a/Schafkopf.Training/Algos/PPOAgent.cs
+++ b/Schafkopf.Training/Algos/PPOAgent.cs
@@ -1,5 +1,39 @@
 namespace Schafkopf.Training;
 
+public class PPOTrainingSession
+{
+    public void Train()
+    {
+        // var config = new PPOTrainingSettings();
+        // var rewardFunc = new GameReward();
+        // var memory = new PPORolloutBuffer(config);
+        // var predCache = new PPOPredictionCache(config.NumEnvs, config.StepsPerUpdate);
+        // var ppoModel = new PPOModel(config);
+        // var cardPicker = new VectorizedCardPicker(config, ppoModel, predCache);
+        // var vecEnv = new VectorizedCardPickerEnv(cardPicker, config.BatchSize);
+        // var heuristicGameCaller = new HeuristicAgent();
+        // var envProxies = Enumerable.Range(0, config.NumEnvs)
+        //     .Select(i => new EnvCardPicker(i, vecEnv)).ToArray();
+        // var agents = Enumerable.Range(0, config.NumEnvs)
+        //     .Select(i => new ComposedAgent(heuristicGameCaller, envProxies[i])).ToArray();
+
+        // var tables = Enumerable.Range(0, config.NumEnvs)
+        //     .Select(i => new Table(
+        //         new Player(0, agents[i]),
+        //         new Player(1, agents[i]),
+        //         new Player(2, agents[i]),
+        //         new Player(3, agents[i])
+        //     )).ToArray();
+        // var sessions = Enumerable.Range(0, config.NumEnvs)
+        //     .Select(i => new GameSession(tables[i], new CardsDeck())).ToArray();
+
+        // for (int i = 0; i < 10_000; i++)
+        // {
+        //     var games = sessions.AsParallel().Select(sess => sess.ProcessGame());
+        // }
+    }
+}
+
 public class UniformDistribution
 {
     public UniformDistribution(int? seed = null)
@@ -62,28 +96,12 @@ public PPOModel(PPOTrainingSettings config)
     private Matrix2D featureCache;
     private UniformDistribution uniform = new UniformDistribution();
 
-    public void Predict(
-        Matrix2D s0, Matrix2D outActions,
-        Matrix2D outPiOnehot, Matrix2D outV)
+    public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV)
     {
-        int batchSize = s0.NumRows;
-        int numFeatures = s0.NumCols;
-
         var predPi = strategy.PredictBatch(s0);
         var predV = valueFunc.PredictBatch(s0);
         Matrix2D.CopyData(predPi, outPiOnehot);
-
-        for (int i = 0; i < batchSize; i++)
-        {
-            unsafe
-            {
-                var probDist = new Span<double>(
-                    predPi.Data + i * numFeatures, numFeatures);
-                int action = uniform.Sample(probDist);
-                outActions.Data[i] = action;
-                outV.Data[i] = predV.At(0, 0);
-            }
-        }
+        Matrix2D.CopyData(predV, outV);
     }
 
     public void Train(PPORolloutBuffer memory)
@@ -197,45 +215,6 @@ public class PPOTrainingSettings
     public int ModelSnapshotInterval => TrainSteps / NumModelSnapshots;
 }
 
-public interface ICardPicker
-{
-    Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards);
-}
-
-public class ComposedAgent : ISchafkopfAIAgent
-{
-    public ComposedAgent(
-        ISchafkopfAIAgent agent,
-        ICardPicker cardPicker)
-    {
-        this.agent = agent;
-        this.cardPicker = cardPicker;
-    }
-
-    private ISchafkopfAIAgent agent;
-    private ICardPicker cardPicker;
-
-    public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
-        => cardPicker.ChooseCard(log, possibleCards);
-
-    public GameCall MakeCall(
-        ReadOnlySpan<GameCall> possibleCalls,
-        int position, Hand hand, int klopfer)
-        => agent.MakeCall(possibleCalls, position, hand, klopfer);
-
-    public bool CallKontra(GameLog log)
-        => agent.CallKontra(log);
-
-    public bool CallRe(GameLog log)
-        => agent.CallRe(log);
-
-    public bool IsKlopfer(int position, ReadOnlySpan<Card> firstFourCards)
-        => agent.IsKlopfer(position, firstFourCards);
-
-    public void OnGameFinished(GameLog final)
-        => agent.OnGameFinished(final);
-}
-
 public class PossibleCardPicker
 {
     private UniformDistribution uniform = new UniformDistribution();
@@ -247,6 +226,9 @@ public Card PickCard(
         => canPlaySampledCard(possibleCards, sampledCard) ? sampledCard
             : possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];
 
+    public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
+        => possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];
+
     private bool canPlaySampledCard(
         ReadOnlySpan<Card> possibleCards, Card sampledCard)
     {
@@ -273,143 +255,30 @@ private ReadOnlySpan<double> normProbDist(
     }
 }
 
-public class EnvCardPicker : ICardPicker
+public struct PPOPredictionCache
 {
-    public EnvCardPicker(int envId, VectorizedCardPicker vecAgent)
+    public PPOPredictionCache(int numEnvs, int steps)
     {
-        this.envId = envId;
-        this.vecAgent = vecAgent;
-    }
-
-    private int envId;
-    private VectorizedCardPicker vecAgent;
-
-    public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
-        => vecAgent.ChooseCard(envId, log, possibleCards);
-}
-
-public class VectorizedCardPicker
-{
-    public VectorizedCardPicker(
-        PPOTrainingSettings config,
-        PPOModel agent,
-        PPORolloutBuffer memory)
-    {
-        this.config = config;
-        this.agent = agent;
-        this.memory = memory;
-
-        inputs = new (GameLog, Card[], int)[config.BatchSize];
-        for (int i = 0; i < inputs.Length; i++)
-            inputs[i].Item2 = new Card[8];
-        outputs = new Card[config.BatchSize];
-        barr = new Barrier(config.BatchSize, (b) => predictBatched());
-
-        s0 = Matrix2D.Zeros(config.BatchSize, 90);
-        a0 = Matrix2D.Zeros(config.BatchSize, 1);
-        piOnehot = Matrix2D.Zeros(config.BatchSize, config.NumActionDims);
-        piSparse = Matrix2D.Zeros(config.BatchSize, 1);
-        V = Matrix2D.Zeros(config.BatchSize, 1);
-        expCache = new PPOTrainBatch(config.BatchSize, config.NumStateDims);
+        this.numEnvs = numEnvs;
+        int size = steps * numEnvs;
+        oldProbs = new double[size];
+        oldBaselines = new double[size];
     }
 
-    private PPOTrainingSettings config;
-    private PPOModel agent;
-    private PPORolloutBuffer memory;
-    private GameStateSerializer stateSerializer = new GameStateSerializer();
-    private PossibleCardPicker possCardsPicker = new PossibleCardPicker();
-
-    private int t = 0;
-    private (GameLog, Card[], int)[] inputs;
-    private Card[] outputs;
-    private Matrix2D s0, a0, piOnehot, piSparse, V;
-    private PPOTrainBatch expCache;
-
-    private Barrier barr;
+    private int numEnvs;
+    private double[] oldProbs;
+    private double[] oldBaselines;
 
-    public Card ChooseCard(
-        int envId, GameLog log, ReadOnlySpan<Card> possibleCards)
+    public void AppendStep(int t, ReadOnlySpan<double> pi, ReadOnlySpan<double> v)
     {
-        inputs[envId].Item1 = log;
-        possibleCards.CopyTo(inputs[envId].Item2);
-        inputs[envId].Item3 = possibleCards.Length;
-        barr.SignalAndWait();
-        return outputs[envId];
+        pi.CopyTo(oldProbs.AsSpan(t * numEnvs));
+        v.CopyTo(oldBaselines.AsSpan(t * numEnvs));
     }
 
-    private void predictBatched()
+    public void Export(Span<double> pi, Span<double> v)
     {
-        for (int i = 0; i < config.BatchSize; i++)
-        {
-            (var log, var _, int __) = inputs[i];
-            var state = stateSerializer.SerializeState(log);
-            unsafe { state.ExportFeatures(s0.Data + i * 90); }
-        }
-
-        agent.Predict(s0, a0, piOnehot, V);
-
-        for (int i = 0; i < config.BatchSize; i++)
-        {
-            (var _, var cardsCache, int cardsLen) = inputs[i];
-            var possCards = cardsCache.AsSpan(0, cardsLen);
-            Span<double> probDistAll;
-            unsafe { probDistAll = new Span<double>(piOnehot.Data + i * 32, 32); }
-
-            var sampledCard = new Card((byte)a0.At(i, 0));
-            sampledCard = possCardsPicker.PickCard(possCards, probDistAll, sampledCard);
-
-            unsafe
-            {
-                a0.Data[i] = sampledCard.Id;
-                piSparse.Data[i] = probDistAll[sampledCard.Id];
-            }
-        }
-
-        Matrix2D.CopyData(s0, expCache.StatesBefore);
-        Matrix2D.CopyData(a0, expCache.Actions);
-        Matrix2D.CopyData(piSparse, expCache.OldProbs);
-        Matrix2D.CopyData(V, expCache.OldBaselines);
-
-        // for (int i = 0; i < config.BatchSize; i++)
-        // {
-        //     (var log, var _, int __) = inputs[i];
-        //     GameReward.Reward(log);
-        // }
-
-        // TODO: include ways to determine the rewards and terminals
-        // Matrix2D.CopyData(cacheOnlyLastStep.Rewards, expCache.Rewards);
-        // Matrix2D.CopyData(cacheOnlyLastStep.Terminals, expCache.Terminals);
-
-        memory.AppendStep(expCache, t++);
-    }
-}
-
-public class PPOTrainingSession
-{
-    public void Train()
-    {
-        var config = new PPOTrainingSettings();
-        var rewardFunc = new GameReward();
-        var memory = new PPORolloutBuffer(config);
-        var ppoModel = new PPOModel(config);
-        var cardPicker = new VectorizedCardPicker(config, ppoModel, memory);
-        var heuristicGameCaller = new HeuristicAgent();
-        var envProxies = Enumerable.Range(0, config.NumEnvs)
-            .Select(i => new EnvCardPicker(i, cardPicker)).ToArray();
-        var agents = Enumerable.Range(0, config.NumEnvs)
-            .Select(i => new ComposedAgent(heuristicGameCaller, envProxies[i])).ToArray();
-
-        var tables = Enumerable.Range(0, config.NumEnvs)
-            .Select(i => new Table(
-                new Player(0, agents[i]),
-                new Player(1, agents[i]),
-                new Player(2, agents[i]),
-                new Player(3, agents[i])
-            )).ToArray();
-        var sessions = Enumerable.Range(0, config.NumEnvs)
-            .Select(i => new GameSession(tables[i], new CardsDeck())).ToArray();
-
-        // TODO: finish implementation
+        oldProbs.CopyTo(pi);
+        oldBaselines.CopyTo(v);
     }
 }
 
@@ -491,25 +360,25 @@ public class PPORolloutBuffer
 {
     public PPORolloutBuffer(PPOTrainingSettings config)
     {
-        numEnvs = config.NumEnvs;
-        steps = config.StepsPerUpdate;
+        NumEnvs = config.NumEnvs;
+        Steps = config.StepsPerUpdate;
         gamma = config.RewardDiscount;
         gaeGamma = config.GAEDiscount;
 
         // info: the cache stores an extra timestep at the end
         // which facilitates proper GAE computation
-        int size = steps * numEnvs;
-        int sizeWithExtraStep = (steps + 1) * numEnvs;
+        int size = Steps * NumEnvs;
+        int sizeWithExtraStep = (Steps + 1) * NumEnvs;
 
         cache = new PPOTrainBatch(sizeWithExtraStep, config.NumStateDims);
         cacheWithoutLastStep = cache.SliceRows(0, size);
-        cacheOnlyFirstStep = cache.SliceRows(0, numEnvs);
-        cacheOnlyLastStep = cache.SliceRows(size, numEnvs);
+        cacheOnlyFirstStep = cache.SliceRows(0, NumEnvs);
+        cacheOnlyLastStep = cache.SliceRows(size, NumEnvs);
         permCache = Perm.Identity(size);
     }
 
-    private int numEnvs;
-    private int steps;
+    public int NumEnvs;
+    public int Steps;
     private double gamma;
    private double gaeGamma;
     private PPOTrainBatch cache;
@@ -518,21 +387,21 @@ public PPORolloutBuffer(PPOTrainingSettings config)
     private PPOTrainBatch cacheOnlyLastStep;
     private int[] permCache;
 
-    public bool IsReadyForModelUpdate(int t) => t > 0 && t % steps == 0;
+    public bool IsReadyForModelUpdate(int t) => t > 0 && t % Steps == 0;
 
     public void AppendStep(PPOTrainBatch expsOfStep, int t)
     {
         int offset = IsReadyForModelUpdate(t)
-            ? steps * numEnvs : (t % steps) * numEnvs;
-
-        Matrix2D.CopyData(expsOfStep.StatesBefore, cache.StatesBefore.SliceRows(offset, numEnvs));
-        Matrix2D.CopyData(expsOfStep.Actions, cache.Actions.SliceRows(offset, numEnvs));
-        Matrix2D.CopyData(expsOfStep.Rewards, cache.Rewards.SliceRows(offset, numEnvs));
-        Matrix2D.CopyData(expsOfStep.Terminals, cache.Terminals.SliceRows(offset, numEnvs));
-        Matrix2D.CopyData(expsOfStep.Returns, cache.Returns.SliceRows(offset, numEnvs));
-        Matrix2D.CopyData(expsOfStep.Advantages, cache.Advantages.SliceRows(offset, numEnvs));
-        Matrix2D.CopyData(expsOfStep.OldProbs, cache.OldProbs.SliceRows(offset, numEnvs));
-        Matrix2D.CopyData(expsOfStep.OldBaselines, cache.OldBaselines.SliceRows(offset, numEnvs));
+            ? Steps * NumEnvs : (t % Steps) * NumEnvs;
+
+        Matrix2D.CopyData(expsOfStep.StatesBefore, cache.StatesBefore.SliceRows(offset, NumEnvs));
+        Matrix2D.CopyData(expsOfStep.Actions, cache.Actions.SliceRows(offset, NumEnvs));
+        Matrix2D.CopyData(expsOfStep.Rewards, cache.Rewards.SliceRows(offset, NumEnvs));
+        Matrix2D.CopyData(expsOfStep.Terminals, cache.Terminals.SliceRows(offset, NumEnvs));
+        Matrix2D.CopyData(expsOfStep.Returns, cache.Returns.SliceRows(offset, NumEnvs));
+        Matrix2D.CopyData(expsOfStep.Advantages, cache.Advantages.SliceRows(offset, NumEnvs));
+        Matrix2D.CopyData(expsOfStep.OldProbs, cache.OldProbs.SliceRows(offset, NumEnvs));
+        Matrix2D.CopyData(expsOfStep.OldBaselines, cache.OldBaselines.SliceRows(offset, NumEnvs));
     }
 
     public IEnumerable<PPOTrainBatch> SampleDataset(int batchSize, int epochs = 1)
@@ -557,11 +426,11 @@ private void shuffleDataset()
 
     private void cacheGAE(PPOTrainBatch cache)
     {
-        var nonterm_t1 = Matrix2D.Zeros(1, numEnvs);
-        var lambda = Matrix2D.Zeros(1, numEnvs);
-        var delta = Matrix2D.Zeros(1, numEnvs);
+        var nonterm_t1 = Matrix2D.Zeros(1, NumEnvs);
+        var lambda = Matrix2D.Zeros(1, NumEnvs);
+        var delta = Matrix2D.Zeros(1, NumEnvs);
 
-        for (int t = steps - 1; t >= 0; t--)
+        for (int t = Steps - 1; t >= 0; t--)
         {
             var r_t0 = cache.Rewards.SliceRows(t, 1);
             var term_t1 = cache.Terminals.SliceRows(t+1, 1);