From c24f281f665a02e71bf7a87887d6ff7ab445d407 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marco=20Tr=C3=B6ster?=
Date: Thu, 30 Nov 2023 13:08:04 +0100
Subject: [PATCH] restructure code

---
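Note (illustrative only, not part of the commit): UniformDistribution moves into its own
file Algos/Distributions.cs. A minimal usage sketch, assuming an already normalized
probability vector such as the one produced by PossibleCardPicker.normProbDist; the
literal numbers below are made up:

    var uniform = new UniformDistribution(seed: 42);
    double[] probs = { 0.5, 0.25, 0.25 };  // must sum to 1
    int idx = uniform.Sample(probs);       // index into probs, e.g. into the playable cards
    int any = uniform.Sample(3);           // uniform fallback over 3 classes
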
 Schafkopf.Training.Tests/ReplayMemoryTests.cs |  72 +++++------
 Schafkopf.Training/Algos/Distributions.cs     |  24 ++++
 Schafkopf.Training/Algos/HeuristicAgent.cs    |  37 +++++-
 Schafkopf.Training/Algos/MDP.cs               | 102 ++++++++++-----
 Schafkopf.Training/Algos/PPOAgent.cs          | 119 +----------------
 Schafkopf.Training/Dataset.cs                 |  27 +++-
 Schafkopf.Training/GameState.cs               |   2 +-
 Schafkopf.Training/ReplayMemory.cs            | 117 -----------------
 8 files changed, 192 insertions(+), 308 deletions(-)
 create mode 100644 Schafkopf.Training/Algos/Distributions.cs
 delete mode 100644 Schafkopf.Training/ReplayMemory.cs

diff --git a/Schafkopf.Training.Tests/ReplayMemoryTests.cs b/Schafkopf.Training.Tests/ReplayMemoryTests.cs
index 5933e60..1ebcb9a 100644
--- a/Schafkopf.Training.Tests/ReplayMemoryTests.cs
+++ b/Schafkopf.Training.Tests/ReplayMemoryTests.cs
@@ -1,47 +1,47 @@
-namespace Schafkopf.Training.Tests;
+// namespace Schafkopf.Training.Tests;

-public class ReplayMemoryTests
-{
-    [Fact]
-    public void Test_CanFillCacheUntilOverflow()
-    {
-        var memory = new ReplayMemory(100);
+// public class ReplayMemoryTests
+// {
+//     [Fact]
+//     public void Test_CanFillCacheUntilOverflow()
+//     {
+//         var memory = new ReplayMemory(100);

-        for (int i = 0; i < 50; i++)
-            memory.Append(new SarsExp());
+//         for (int i = 0; i < 50; i++)
+//             memory.Append(new SarsExp());

-        Assert.Equal(50, memory.Size);
+//         Assert.Equal(50, memory.Size);

-        for (int i = 0; i < 50; i++)
-            memory.Append(new SarsExp());
+//         for (int i = 0; i < 50; i++)
+//             memory.Append(new SarsExp());

-        Assert.Equal(100, memory.Size);
-    }
+//         Assert.Equal(100, memory.Size);
+//     }

-    [Fact]
-    public void Test_CanInsertIntoOverflowingCache()
-    {
-        var memory = new ReplayMemory(100);
+//     [Fact]
+//     public void Test_CanInsertIntoOverflowingCache()
+//     {
+//         var memory = new ReplayMemory(100);

-        for (int i = 0; i < 200; i++)
-            memory.Append(new SarsExp());
+//         for (int i = 0; i < 200; i++)
+//             memory.Append(new SarsExp());

-        Assert.Equal(100, memory.Size);
-    }
+//         Assert.Equal(100, memory.Size);
+//     }

-    [Fact(Skip = "requires the states to be initialized with unique data")]
-    public void Test_CanReplaceOverflowingDataWithNewData()
-    {
-        var memory = new ReplayMemory(100);
-        var overflowingData = Enumerable.Range(0, 50).Select(x => new SarsExp()).ToArray();
-        var insertedData = Enumerable.Range(0, 100).Select(x => new SarsExp()).ToArray();
+//     [Fact(Skip = "requires the states to be initialized with unique data")]
+//     public void Test_CanReplaceOverflowingDataWithNewData()
+//     {
+//         var memory = new ReplayMemory(100);
+//         var overflowingData = Enumerable.Range(0, 50).Select(x => new SarsExp()).ToArray();
+//         var insertedData = Enumerable.Range(0, 100).Select(x => new SarsExp()).ToArray();

-        foreach (var exp in overflowingData)
-            memory.Append(exp);
-        foreach (var exp in insertedData)
-            memory.Append(exp);
+//         foreach (var exp in overflowingData)
+//             memory.Append(exp);
+//         foreach (var exp in insertedData)
+//             memory.Append(exp);

-        Assert.True(overflowingData.All(exp => !memory.Contains(exp)));
-        Assert.True(insertedData.All(exp => memory.Contains(exp)));
-    }
-}
+//         Assert.True(overflowingData.All(exp => !memory.Contains(exp)));
+//         Assert.True(insertedData.All(exp => memory.Contains(exp)));
+//     }
+// }
diff --git a/Schafkopf.Training/Algos/Distributions.cs b/Schafkopf.Training/Algos/Distributions.cs
new file mode 100644
index 0000000..7bb75a9
--- /dev/null
+++ b/Schafkopf.Training/Algos/Distributions.cs
@@ -0,0 +1,24 @@
+namespace Schafkopf.Training;
+
+public class UniformDistribution
+{
+    public UniformDistribution(int? seed = null)
+        => rng = seed != null ? new Random(seed.Value) : new Random();
+
+    private Random rng;
+
+    public int Sample(ReadOnlySpan<double> probs)
+    {
+        double p = rng.NextDouble();
+        double sum = 0;
+        for (int i = 0; i < probs.Length - 1; i++)
+        {
+            sum += probs[i];
+            if (p < sum)
+                return i;
+        }
+        return probs.Length - 1;
+    }
+
+    public int Sample(int numClasses) => rng.Next(0, numClasses);
+}
diff --git a/Schafkopf.Training/Algos/HeuristicAgent.cs b/Schafkopf.Training/Algos/HeuristicAgent.cs
index e99c0a2..c6a5880 100644
--- a/Schafkopf.Training/Algos/HeuristicAgent.cs
+++ b/Schafkopf.Training/Algos/HeuristicAgent.cs
@@ -2,9 +2,9 @@ namespace Schafkopf.Training;

 public class HeuristicAgent : ISchafkopfAIAgent
 {
-    private Random rng = new Random();
     private HeuristicGameCaller caller = new HeuristicGameCaller(
         new GameMode[] { GameMode.Sauspiel });
+    private HeuristicCardPicker cardPicker = new HeuristicCardPicker();

     public GameCall MakeCall(
         ReadOnlySpan<GameCall> possibleCalls,
@@ -12,7 +12,7 @@ public GameCall MakeCall(
         => caller.MakeCall(possibleCalls, position, hand, klopfer);

     public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
-        => possibleCards[rng.Next(0, possibleCards.Length)];
+        => cardPicker.ChooseCard(log, possibleCards);

     public bool IsKlopfer(int position, ReadOnlySpan<Card> firstFourCards) => false;
     public bool CallKontra(GameLog log) => false;
@@ -20,6 +20,39 @@ public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
     public void OnGameFinished(GameLog final) => throw new NotImplementedException();
 }

+public class HeuristicCardPicker
+{
+    private Random rng = new Random();
+
+    public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
+    {
+        if (log.Call.Mode == GameMode.Solo)
+            return chooseCardForSolo(log, possibleCards);
+        else if (log.Call.Mode == GameMode.Wenz)
+            return chooseCardForWenz(log, possibleCards);
+        else // if (log.Call.Mode == GameMode.Sauspiel)
+            return chooseCardForSauspiel(log, possibleCards);
+    }
+
+    private Card chooseCardForSolo(GameLog log, ReadOnlySpan<Card> possibleCards)
+    {
+        // TODO: implement heuristic
+        return possibleCards[rng.Next(0, possibleCards.Length)];
+    }
+
+    private Card chooseCardForWenz(GameLog log, ReadOnlySpan<Card> possibleCards)
+    {
+        // TODO: implement heuristic
+        return possibleCards[rng.Next(0, possibleCards.Length)];
+    }
+
+    private Card chooseCardForSauspiel(GameLog log, ReadOnlySpan<Card> possibleCards)
+    {
+        // TODO: implement heuristic
+        return possibleCards[rng.Next(0, possibleCards.Length)];
+    }
+}
+
 public class HeuristicGameCaller
 {
     public HeuristicGameCaller(IEnumerable<GameMode> modes)
diff --git a/Schafkopf.Training/Algos/MDP.cs b/Schafkopf.Training/Algos/MDP.cs
index e636a2f..ee296b3 100644
--- a/Schafkopf.Training/Algos/MDP.cs
+++ b/Schafkopf.Training/Algos/MDP.cs
@@ -2,49 +2,24 @@ namespace Schafkopf.Training;

 public class CardPickerExpCollector
 {
-    public CardPickerExpCollector(
-        PPOModel strategy, PossibleCardPicker cardSampler)
+    public CardPickerExpCollector(PPOModel strategy)
     {
         this.strategy = strategy;
-        this.cardSampler = cardSampler;
     }

     private GameRules rules = new GameRules();
     private GameStateSerializer stateSerializer = new GameStateSerializer();
     private PPOModel strategy;
-    private PossibleCardPicker cardSampler;
-
-    private struct TurnBatches
-    {
-        public TurnBatches(int numSessions)
-        {
-            s0Batches = Enumerable.Range(0, 4)
-                .Select(i => Matrix2D.Zeros(numSessions, 90)).ToArray();
-            a0Batches = Enumerable.Range(0, 4)
-                .Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
-            piBatches = Enumerable.Range(0, 4)
-                .Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
-            piSparseBatches = Enumerable.Range(0, 4)
-                .Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
-            vBatches = Enumerable.Range(0, 4)
-                .Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
-        }
-
-        public Matrix2D[] s0Batches { get; set; }
-        public Matrix2D[] a0Batches { get; set; }
-        public Matrix2D[] piBatches { get; set; }
-        public Matrix2D[] piSparseBatches { get; set; }
-        public Matrix2D[] vBatches { get; set; }
-    }
+    private PossibleCardPicker cardSampler = new PossibleCardPicker();

     public void Collect(PPORolloutBuffer buffer)
     {
         if (buffer.NumEnvs % 4 != 0)
             throw new ArgumentException("The number of envs needs to be "
                 + "divisible by 4 because 4 agents are playing the game!");
-        if (buffer.Steps % 8 != 0)
-            throw new ArgumentException("The number of steps needs to be "
-                + "divisible by 8 because each agent plays 8 cards per game!");
+        // if (buffer.Steps % 8 != 0)
+        //     throw new ArgumentException("The number of steps needs to be "
+        //         + "divisible by 8 because each agent plays 8 cards per game!");

         int numGames = buffer.Steps / 8;
         int numSessions = buffer.NumEnvs / 4;
@@ -95,7 +70,7 @@ private void fillBuffer(
             unsafe
             {
                 expBuf.Actions.Data[rowid] = a0Batch.Data[envId];
-                expBuf.Rewards.Data[rowid] = rewards.Data[envId];
+                expBuf.Rewards.Data[rowid] = r1Batch.Data[envId];
                 expBuf.Terminals.Data[rowid] = t_id == 7 ? 1 : 0;
                 expBuf.OldProbs.Data[rowid] = piSparseBatch.Data[envId];
                 expBuf.OldBaselines.Data[rowid] = vBatch.Data[envId];
@@ -136,7 +111,7 @@ private void playGame(CardPickerEnv[] envs, GameLog[] states, TurnBatches[] batc
         for (int envId = 0; envId < states.Length; envId++)
         {
             var s0 = stateSerializer.SerializeState(states[envId]);
-            unsafe { s0.ExportFeatures(s0Batch.Data + envId * 90); }
+            s0.ExportFeatures(s0Batch.SliceRowsRaw(envId, 1));
         }

         strategy.Predict(s0Batch, piBatch, vBatch);
@@ -164,6 +139,69 @@ private void playGame(CardPickerEnv[] envs, GameLog[] states, TurnBatches[] batc
             }
         }
     }
+
+    private struct TurnBatches
+    {
+        public TurnBatches(int numSessions)
+        {
+            s0Batches = Enumerable.Range(0, 4)
+                .Select(i => Matrix2D.Zeros(numSessions, 90)).ToArray();
+            a0Batches = Enumerable.Range(0, 4)
+                .Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
+            piBatches = Enumerable.Range(0, 4)
+                .Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
+            piSparseBatches = Enumerable.Range(0, 4)
+                .Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
+            vBatches = Enumerable.Range(0, 4)
+                .Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
+        }
+
+        public Matrix2D[] s0Batches { get; set; }
+        public Matrix2D[] a0Batches { get; set; }
+        public Matrix2D[] piBatches { get; set; }
+        public Matrix2D[] piSparseBatches { get; set; }
+        public Matrix2D[] vBatches { get; set; }
+    }
+
+    private class PossibleCardPicker
+    {
+        private UniformDistribution uniform = new UniformDistribution();
+
+        public Card PickCard(
+            ReadOnlySpan<Card> possibleCards,
+            ReadOnlySpan<double> predPi,
+            Card sampledCard)
+            => canPlaySampledCard(possibleCards, sampledCard) ? sampledCard
+                : possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];
+
+        public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
+            => possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];
+
+        private bool canPlaySampledCard(
+            ReadOnlySpan<Card> possibleCards, Card sampledCard)
+        {
+            foreach (var card in possibleCards)
+                if (card == sampledCard)
+                    return true;
+            return false;
+        }
+
+        private double[] probDistCache = new double[8];
+        private ReadOnlySpan<double> normProbDist(
+            ReadOnlySpan<double> probDistAll, ReadOnlySpan<Card> possibleCards)
+        {
+            double probSum = 0;
+            for (int i = 0; i < possibleCards.Length; i++)
+                probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK];
+            for (int i = 0; i < possibleCards.Length; i++)
+                probSum += probDistCache[i];
+            double scale = 1 / probSum;
+            for (int i = 0; i < possibleCards.Length; i++)
+                probDistCache[i] *= scale;
+
+            return probDistCache.AsSpan().Slice(0, possibleCards.Length);
+        }
+    }
 }

 public class CardPickerEnv
diff --git a/Schafkopf.Training/Algos/PPOAgent.cs b/Schafkopf.Training/Algos/PPOAgent.cs
index 527c20a..07f6b97 100644
--- a/Schafkopf.Training/Algos/PPOAgent.cs
+++ b/Schafkopf.Training/Algos/PPOAgent.cs
@@ -4,59 +4,10 @@ public class PPOTrainingSession
 {
     public void Train()
     {
-        // var config = new PPOTrainingSettings();
-        // var rewardFunc = new GameReward();
-        // var memory = new PPORolloutBuffer(config);
-        // var predCache = new PPOPredictionCache(config.NumEnvs, config.StepsPerUpdate);
-        // var ppoModel = new PPOModel(config);
-        // var cardPicker = new VectorizedCardPicker(config, ppoModel, predCache);
-        // var vecEnv = new VectorizedCardPickerEnv(cardPicker, config.BatchSize);
-        // var heuristicGameCaller = new HeuristicAgent();
-        // var envProxies = Enumerable.Range(0, config.NumEnvs)
-        //     .Select(i => new EnvCardPicker(i, vecEnv)).ToArray();
-        // var agents = Enumerable.Range(0, config.NumEnvs)
-        //     .Select(i => new ComposedAgent(heuristicGameCaller, envProxies[i])).ToArray();
-
-        // var tables = Enumerable.Range(0, config.NumEnvs)
-        //     .Select(i => new Table(
-        //         new Player(0, agents[i]),
-        //         new Player(1, agents[i]),
-        //         new Player(2, agents[i]),
-        //         new Player(3, agents[i])
-        //     )).ToArray();
-        // var sessions = Enumerable.Range(0, config.NumEnvs)
-        //     .Select(i => new GameSession(tables[i], new CardsDeck())).ToArray();
-
-        // for (int i = 0; i < 10_000; i++)
-        // {
-        //     var games = sessions.AsParallel().Select(sess => sess.ProcessGame());
-        // }
+
     }
 }

-public class UniformDistribution
-{
-    public UniformDistribution(int? seed = null)
-        => rng = seed != null ? new Random(seed.Value) : new Random();
-
-    private Random rng;
-
-    public int Sample(ReadOnlySpan<double> probs)
-    {
-        double p = rng.NextDouble();
-        double sum = 0;
-        for (int i = 0; i < probs.Length - 1; i++)
-        {
-            sum += probs[i];
-            if (p < sum)
-                return i;
-        }
-        return probs.Length - 1;
-    }
-
-    public int Sample(int numClasses) => rng.Next(0, numClasses);
-}
-
 public class PPOModel
 {
     public PPOModel(PPOTrainingSettings config)
@@ -94,7 +45,6 @@ public PPOModel(PPOTrainingSettings config)
     private IOptimizer valueFuncOpt;

     private Matrix2D featureCache;
-    private UniformDistribution uniform = new UniformDistribution();

     public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV)
     {
@@ -215,73 +165,6 @@ public class PPOTrainingSettings
     public int ModelSnapshotInterval => TrainSteps / NumModelSnapshots;
 }

-public class PossibleCardPicker
-{
-    private UniformDistribution uniform = new UniformDistribution();
-
-    public Card PickCard(
-        ReadOnlySpan<Card> possibleCards,
-        ReadOnlySpan<double> predPi,
-        Card sampledCard)
-        => canPlaySampledCard(possibleCards, sampledCard) ? sampledCard
-            : possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];
-
-    public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
-        => possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];
-
-    private bool canPlaySampledCard(
-        ReadOnlySpan<Card> possibleCards, Card sampledCard)
-    {
-        foreach (var card in possibleCards)
-            if (card == sampledCard)
-                return true;
-        return false;
-    }
-
-    private double[] probDistCache = new double[8];
-    private ReadOnlySpan<double> normProbDist(
-        ReadOnlySpan<double> probDistAll, ReadOnlySpan<Card> possibleCards)
-    {
-        double probSum = 0;
-        for (int i = 0; i < possibleCards.Length; i++)
-            probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK];
-        for (int i = 0; i < possibleCards.Length; i++)
-            probSum += probDistCache[i];
-        double scale = 1 / probSum;
-        for (int i = 0; i < possibleCards.Length; i++)
-            probDistCache[i] *= scale;
-
-        return probDistCache.AsSpan().Slice(0, possibleCards.Length);
-    }
-}
-
-public struct PPOPredictionCache
-{
-    public PPOPredictionCache(int numEnvs, int steps)
-    {
-        this.numEnvs = numEnvs;
-        int size = steps * numEnvs;
-        oldProbs = new double[size];
-        oldBaselines = new double[size];
-    }
-
-    private int numEnvs;
-    private double[] oldProbs;
-    private double[] oldBaselines;
-
-    public void AppendStep(int t, ReadOnlySpan<double> pi, ReadOnlySpan<double> v)
-    {
-        pi.CopyTo(oldProbs.AsSpan(t * numEnvs));
-        v.CopyTo(oldBaselines.AsSpan(t * numEnvs));
-    }
-
-    public void Export(Span<double> pi, Span<double> v)
-    {
-        oldProbs.CopyTo(pi);
-        oldBaselines.CopyTo(v);
-    }
-}
-
 public struct PPOTrainBatch
 {
     public PPOTrainBatch(int size, int numStateDims)
diff --git a/Schafkopf.Training/Dataset.cs b/Schafkopf.Training/Dataset.cs
index 11297b5..ab991d6 100644
--- a/Schafkopf.Training/Dataset.cs
+++ b/Schafkopf.Training/Dataset.cs
@@ -1,5 +1,25 @@
 namespace Schafkopf.Training;

+public struct SarsExp : IEquatable<SarsExp>
+{
+    public SarsExp() { }
+
+    public GameState StateBefore = new GameState();
+    public GameState StateAfter = new GameState();
+    public Card Action = new Card();
+    public double Reward = 0.0;
+    public bool IsTerminal = false;
+
+    public bool Equals(SarsExp other)
+        => StateBefore.Equals(other.StateBefore)
+            && StateAfter.Equals(other.StateAfter)
+            && Action == other.Action
+            && Reward == other.Reward
+            && IsTerminal == other.IsTerminal;
+
+    public override int GetHashCode() => 0;
+}
+
 public class SupervisedSchafkopfDataset
 {
     public static FlatFeatureDataset GenerateDataset(
@@ -24,8 +44,9 @@ private static (Matrix2D, Matrix2D) generateDataset(int size)
                 var card = exp.Action;
                 x.Data[p++] = GameEncoding.Encode(card.Type);
                 x.Data[p++] = GameEncoding.Encode(card.Color);
-                exp.StateBefore.ExportFeatures(x.Data + p);
-                p += GameState.NUM_FEATURES;
+                var stateDest = new Span<double>(x.Data + p, GameState.NUM_FEATURES);
+                exp.StateBefore.ExportFeatures(stateDest);
+                p += stateDest.Length;
                 y.Data[i++] = exp.Reward;
             }
         }
@@ -39,6 +60,8 @@ private static IEnumerable<SarsExp> generateExperiences(
     {
         var gameCaller = new HeuristicGameCaller(
             new GameMode[] { GameMode.Sauspiel });
+
+        // TODO: supervised transfer learning requires pre-trained agent / heuristic
         var agent = new RandomAgent(gameCaller);
         var table = new Table(
             new Player(0, agent), new Player(1, agent),
diff --git a/Schafkopf.Training/GameState.cs b/Schafkopf.Training/GameState.cs
index c0c4550..53f6fb7 100644
--- a/Schafkopf.Training/GameState.cs
+++ b/Schafkopf.Training/GameState.cs
@@ -11,7 +11,7 @@ public GameState() { }
     public void LoadFeatures(double[] other)
         => Array.Copy(other, State, NUM_FEATURES);

-    public unsafe void ExportFeatures(double* other)
+    public unsafe void ExportFeatures(Span<double> other)
     {
         for (int i = 0; i < NUM_FEATURES; i++)
             other[i] = State[i];
diff --git a/Schafkopf.Training/ReplayMemory.cs b/Schafkopf.Training/ReplayMemory.cs
deleted file mode 100644
index a85a3e0..0000000
--- a/Schafkopf.Training/ReplayMemory.cs
+++ /dev/null
@@ -1,117 +0,0 @@
-namespace Schafkopf.Training;
-
-public struct SarsExp : IEquatable<SarsExp>
-{
-    public SarsExp() { }
-
-    public GameState StateBefore = new GameState();
-    public GameState StateAfter = new GameState();
-    public Card Action = new Card();
-    public double Reward = 0.0;
-    public bool IsTerminal = false;
-
-    public bool Equals(SarsExp other)
-        => StateBefore.Equals(other.StateBefore)
-            && StateAfter.Equals(other.StateAfter)
-            && Action == other.Action
-            && Reward == other.Reward
-            && IsTerminal == other.IsTerminal;
-
-    public override int GetHashCode() => 0;
-}
-
-public struct ACSarsExp : IEquatable<ACSarsExp>
-{
-    public ACSarsExp() { }
-
-    public GameState StateBefore = new GameState();
-    public GameState StateAfter = new GameState();
-    public Card Action = new Card();
-    public double Reward = 0.0;
-    public bool IsTerminal = false;
-    public double OldProb = 0.0;
-    public double OldBaseline = 0.0;
-
-    public bool Equals(ACSarsExp other)
-        => StateBefore.Equals(other.StateBefore)
-            && StateAfter.Equals(other.StateAfter)
-            && Action == other.Action
-            && Reward == other.Reward
-            && IsTerminal == other.IsTerminal
-            && OldProb == other.OldProb;
-
-    public override int GetHashCode() => 0;
-}
-
-public class ReplayMemory
-{
-    private static readonly Random rng = new Random();
-
-    public ReplayMemory(int size)
-    {
-        totalSize = size;
-        memory = new SarsExp[totalSize];
-        for (int i = 0; i < size; i++)
-            memory[i] = new SarsExp();
-    }
-
-    private int totalSize;
-    private bool isFilled = false;
-    private bool overflow = false;
-    private int insertPos = 0;
-    private SarsExp[] memory;
-
-    public int Size => isFilled ? totalSize : insertPos;
-
-    public void Append(SarsExp exp)
-    {
-        var origExp = memory[insertPos];
-        origExp.StateBefore.LoadFeatures(exp.StateBefore.State);
-        origExp.StateAfter.LoadFeatures(exp.StateAfter.State);
-        memory[insertPos++] = origExp;
-
-        overflow = insertPos == totalSize;
-        isFilled = isFilled || overflow;
-        insertPos %= totalSize;
-    }
-
-    public void AppendBatched(SarsExp[] exps)
-    {
-        bool overflow = false;
-        foreach (var exp in exps)
-        {
-            Append(exp);
-            overflow |= this.overflow;
-        }
-        this.overflow = overflow;
-    }
-
-    public void SampleRandom(SarsExp[] cache)
-    {
-        int maxId = isFilled ? totalSize : insertPos;
-        for (int i = 0; i < cache.Length; i++)
-            cache[i] = memory[rng.Next() % maxId];
-    }
-
-    public void SampleBatched(SarsExp[] cache)
-    {
-        int maxId = isFilled ? totalSize : insertPos;
-        int offset = rng.Next(0, totalSize - cache.Length);
-        for (int i = 0; i < cache.Length; i++)
-            cache[i] = memory[(offset + i) % maxId];
-    }
-
-    public void Shuffle()
-    {
-        for (int i = 0; i < totalSize; i++)
-        {
-            int j = rng.Next(i, totalSize);
-            var temp = memory[j];
-            memory[j] = memory[i];
-            memory[i] = temp;
-        }
-    }
-
-    public bool Contains(SarsExp exp)
-        => memory.Contains(exp);
-}