From b69d2095f8ed8c4e96a390ac9a081c296225034d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marco=20Tr=C3=B6ster?=
Date: Mon, 20 Nov 2023 23:56:38 +0100
Subject: [PATCH] add first draft for an actor critic agent

---
 Schafkopf.Training/Algos/PPOAgent.cs   | 130 +++++++++++++++++++++++++
 Schafkopf.Training/GameState.cs        |  18 ++--
 Schafkopf.Training/NeuralNet/Layers.cs |  59 +++++++++++
 Schafkopf.Training/RandomAgent.cs      |  24 +++++
 4 files changed, 225 insertions(+), 6 deletions(-)
 create mode 100644 Schafkopf.Training/Algos/PPOAgent.cs

diff --git a/Schafkopf.Training/Algos/PPOAgent.cs b/Schafkopf.Training/Algos/PPOAgent.cs
new file mode 100644
index 0000000..5085388
--- /dev/null
+++ b/Schafkopf.Training/Algos/PPOAgent.cs
@@ -0,0 +1,130 @@
+
+// TODO: train a policy to predict the likelihood
+// of selecting an action in a given state
+
+public class UniformDistribution
+{
+    private static readonly Random rng = new Random();
+
+    // draws an index i with probability probs[i] via inverse CDF
+    // sampling; expects probs to sum up to 1
+    public static int Sample(ReadOnlySpan<double> probs)
+    {
+        double p = rng.NextDouble();
+        double sum = 0;
+        for (int i = 0; i < probs.Length - 1; i++)
+        {
+            sum += probs[i];
+            if (p < sum)
+                return i;
+        }
+        return probs.Length - 1;
+    }
+}
+
+public class PPOAgent : ISchafkopfAIAgent
+{
+    private FFModel valueModel = new FFModel(new ILayer[] {
+        new DenseLayer(64),
+        new ReLULayer(),
+        new DenseLayer(64),
+        new ReLULayer(),
+        new DenseLayer(1),
+        new FlattenLayer()
+    });
+
+    private FFModel strategyModel = new FFModel(new ILayer[] {
+        new DenseLayer(64),
+        new ReLULayer(),
+        new DenseLayer(64),
+        new ReLULayer(),
+        new DenseLayer(1),
+        new FlattenLayer(),
+        new SoftmaxLayer()
+    });
+
+    private GameStateSerializer stateSerializer = new GameStateSerializer();
+    private Matrix2D featureCache = Matrix2D.Zeros(8, 92);
+
+    public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
+    {
+        var x = featureCache;
+        var s0 = stateSerializer.SerializeState(log);
+
+        // build one feature row per selectable card:
+        // 2 card features followed by the game state features
+        int p = 0;
+        for (int i = 0; i < possibleCards.Length; i++)
+        {
+            unsafe
+            {
+                var card = possibleCards[i];
+                x.Data[p++] = GameEncoding.Encode(card.Type);
+                x.Data[p++] = GameEncoding.Encode(card.Color);
+                s0.ExportFeatures(x.Data + p);
+                p += GameState.NUM_FEATURES;
+            }
+        }
+
+        // rate all selectable cards in one batch, then sample
+        // a card id from the predicted action distribution
+        var probDist = strategyModel.PredictBatch(featureCache);
+        ReadOnlySpan<double> probDistSlice;
+        unsafe { probDistSlice = new Span<double>(probDist.Data, possibleCards.Length); }
+        int id = UniformDistribution.Sample(probDistSlice);
+        return possibleCards[id];
+    }
+
+    public void OnGameFinished(GameLog final)
+    {
+        throw new NotImplementedException();
+    }
+
+    #region Misc
+
+    public bool CallKontra(GameLog log) => false;
+
+    public bool CallRe(GameLog log) => false;
+
+    public bool IsKlopfer(int position, ReadOnlySpan<Card> firstFourCards) => false;
+
+    private HeuristicGameCaller caller =
+        new HeuristicGameCaller(new GameMode[] { GameMode.Sauspiel });
+    public GameCall MakeCall(
+            ReadOnlySpan<GameCall> possibleCalls,
+            int position, Hand hand, int klopfer)
+        => caller.MakeCall(possibleCalls, position, hand, klopfer);
+
+    #endregion Misc
+}
+
+public class PPOTrainingSettings
+{
+    public int NumObsFeatures { get; set; }
+    public int TotalSteps = 10_000_000;
+    public double LearnRate = 3e-4;
+    public double RewardDiscount = 0.99;
+    public double GAEDiscount = 0.95;
+    public double ProbClip = 0.2;
+    public double ValueClip = 0.2;
+    public double VFCoef = 0.5;
+    public double EntCoef = 0.01;
+    public bool NormAdvantages = true;
+    public bool ClipValues = true;
+    public int BatchSize = 64;
+    public int NumEnvs = 32;
+    public int StepsPerUpdate = 512;
+    public int UpdateEpochs = 4;
+    public int NumModelSnapshots = 20;
+
+    public int TrainSteps => TotalSteps / NumEnvs;
+    public int ModelSnapshotInterval => TrainSteps / NumModelSnapshots;
+}
+
+public class PPOTrainingSession
+{
+    public void Train()
+    {
+
+    }
+}
+
+public class PPORolloutBuffer
+{
+    //
+}
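Note on the stubs above: PPOTrainingSession.Train() and PPORolloutBuffer are still empty. Given the RewardDiscount and GAEDiscount settings, the buffer will presumably produce generalized advantage estimates (GAE) for the PPO update. A minimal sketch of that recursion, assuming flat per-step arrays; the class, method, and parameter names are placeholders, not part of this patch:

    public static class GaeSketch
    {
        // delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        // A_t = delta_t + gamma * lambda * A_{t+1}, computed back-to-front
        public static double[] Estimate(
            double[] rewards, double[] values, bool[] terminals,
            double lastValue, double gamma = 0.99, double lambda = 0.95)
        {
            var advantages = new double[rewards.Length];
            double lastGae = 0;
            for (int t = rewards.Length - 1; t >= 0; t--)
            {
                double nextValue = t == rewards.Length - 1 ? lastValue : values[t + 1];
                double mask = terminals[t] ? 0 : 1; // stop bootstrapping at episode ends
                double delta = rewards[t] + gamma * mask * nextValue - values[t];
                advantages[t] = lastGae = delta + gamma * lambda * mask * lastGae;
            }
            return advantages;
        }
    }

Running the recursion back-to-front lets each step reuse the following step's advantage, so a whole rollout is processed in a single O(n) pass.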
diff --git a/Schafkopf.Training/GameState.cs b/Schafkopf.Training/GameState.cs
index 0122ad5..b6b673a 100644
--- a/Schafkopf.Training/GameState.cs
+++ b/Schafkopf.Training/GameState.cs
@@ -40,6 +40,8 @@ public void SerializeSarsExps(
         GameLog completedGame, SarsExp[] exps, Func<GameLog, int, double> reward)
     {
+        if (completedGame.CardCount != 32)
+            throw new ArgumentException("Can only process finished games!");
         serializeHistory(completedGame, stateBuffer);
 
         var actions = completedGame.UnrollActions().GetEnumerator();
@@ -61,16 +63,18 @@ public void SerializeSarsExps(
         }
     }
 
+    public GameState SerializeState(GameLog liveGame)
+    {
+        serializeHistory(liveGame, stateBuffer);
+        return stateBuffer[liveGame.CardCount - 1];
+    }
+
     private int playerPosOfTurn(GameLog log, int t_id, int p_id)
         => t_id == 8 ? p_id : normPlayerId(p_id, log.Turns[t_id].FirstDrawingPlayerId);
 
     private void serializeHistory(GameLog completedGame, GameState[] statesCache)
     {
-        if (completedGame.CardCount != 32)
-            throw new ArgumentException("Can only process finished games!");
-        if (statesCache.Length < 36)
-            throw new ArgumentException("");
-
+        int timesteps = completedGame.CardCount;
         var origCall = completedGame.Call;
         var hands = completedGame.UnrollHands().GetEnumerator();
         var scores = completedGame.UnrollAugen().GetEnumerator();
@@ -81,7 +85,7 @@ private void serializeHistory(GameLog completedGame, GameState[] statesCache)
         };
 
         int t = 0;
-        foreach (var turn in completedGame.Turns)
+        for (int t_id = 0; t_id < Math.Ceiling((double)timesteps / 4); t_id++)
         {
             scores.MoveNext();
 
@@ -93,6 +97,8 @@ private void serializeHistory(GameLog completedGame, GameState[] statesCache)
             var score = scores.Current;
             var state = statesCache[t].State;
             serializeState(state, normCalls, hand, t++, allActions, score);
+
+            if (t == timesteps) return;
         }
     }
diff --git a/Schafkopf.Training/NeuralNet/Layers.cs b/Schafkopf.Training/NeuralNet/Layers.cs
index e037c8e..d666e44 100644
--- a/Schafkopf.Training/NeuralNet/Layers.cs
+++ b/Schafkopf.Training/NeuralNet/Layers.cs
@@ -207,3 +207,62 @@ public void ApplyGrads()
         // info: layer has no trainable params
     }
 }
+
+public class FlattenLayer : ILayer
+{
+    public FlattenLayer(int axis = 1)
+        => this.axis = axis;
+
+    private int axis;
+
+    public LayerCache Cache { get; private set; }
+
+    public int InputDims { get; private set; }
+
+    public int OutputDims { get; private set; }
+
+    public void Compile(int inputDims)
+    {
+        InputDims = inputDims;
+    }
+
+    public void CompileCache(Matrix2D inputs, Matrix2D deltasOut)
+    {
+        // axis = 0 flattens into a single row (1 x n),
+        // axis = 1 flattens into a single column (n x 1)
+        int flatDims = inputs.NumRows * inputs.NumCols;
+        OutputDims = axis == 0 ? flatDims : 1;
+        int batchSize = axis == 0 ? 1 : flatDims;
+
+        Cache = new LayerCache() {
+            Input = inputs,
+            Output = Matrix2D.Zeros(batchSize, OutputDims),
+            DeltasIn = Matrix2D.Zeros(batchSize, OutputDims),
+            DeltasOut = deltasOut,
+            Gradients = Matrix2D.Null()
+        };
+    }
+
+    public void Forward()
+    {
+        unsafe
+        {
+            // flattening only reinterprets the shape; the data is copied 1:1
+            int dataLen = Cache.Input.NumRows * Cache.Input.NumCols;
+            for (int i = 0; i < dataLen; i++)
+                Cache.Output.Data[i] = Cache.Input.Data[i];
+        }
+    }
+
+    public void Backward()
+    {
+        unsafe
+        {
+            int dataLen = Cache.DeltasIn.NumRows * Cache.DeltasIn.NumCols;
+            for (int i = 0; i < dataLen; i++)
+                Cache.DeltasOut.Data[i] = Cache.DeltasIn.Data[i];
+        }
+    }
+
+    public void ApplyGrads()
+    {
+        // info: layer has no trainable params
+    }
+}
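The FlattenLayer's axis parameter decides which dimension survives. Reading CompileCache above: axis = 0 collapses everything into a single row, axis = 1 (the default) into a single column. A small usage sketch with the types from this patch; the shapes are inferred from the code, not tested:

    var flatten = new FlattenLayer(axis: 0);
    var input = Matrix2D.Zeros(8, 4);                // 8 samples, 4 features
    flatten.Compile(4);
    flatten.CompileCache(input, Matrix2D.Zeros(8, 4));
    flatten.Forward();
    // flatten.Cache.Output is now 1 x 32; Forward() copied all
    // 32 values unchanged, only the shape differs

With the default axis = 1, the same input would come out as a 32 x 1 column instead.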
diff --git a/Schafkopf.Training/RandomAgent.cs b/Schafkopf.Training/RandomAgent.cs
index 5bc3dff..3e2870e 100644
--- a/Schafkopf.Training/RandomAgent.cs
+++ b/Schafkopf.Training/RandomAgent.cs
@@ -11,6 +11,20 @@ public GameCall MakeCall(
         ReadOnlySpan<GameCall> possibleCalls,
         int position, Hand hand, int klopfer)
     {
+        if (allowedModes.Contains(GameMode.Solo))
+        {
+            var call = canCallSolo(possibleCalls, position, hand, klopfer);
+            if (call.Mode == GameMode.Solo)
+                return call;
+        }
+
+        if (allowedModes.Contains(GameMode.Wenz))
+        {
+            var call = canCallWenz(possibleCalls, position, hand, klopfer);
+            if (call.Mode == GameMode.Wenz)
+                return call;
+        }
+
         if (allowedModes.Contains(GameMode.Sauspiel))
         {
             var call = canCallSauspiel(possibleCalls, hand);
@@ -57,6 +71,16 @@ private GameCall canCallSauspiel(
         return sauspielCalls.OrderBy(x => hand.FarbeCount(x.GsuchteFarbe)).First();
     }
 
+    private GameCall canCallSolo(
+            ReadOnlySpan<GameCall> possibleCalls,
+            int position, Hand hand, int klopfer)
+        => GameCall.Weiter(); // TODO: implement logic for solo decision
+
+    private GameCall canCallWenz(
+            ReadOnlySpan<GameCall> possibleCalls,
+            int position, Hand hand, int klopfer)
+        => GameCall.Weiter(); // TODO: implement logic for wenz decision
 }
 
 public class RandomAgent : ISchafkopfAIAgent
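The canCallSolo/canCallWenz TODOs currently return Weiter() in every case, so the new branches in MakeCall never fire. For illustration, one possible shape of the solo heuristic; hand.TrumpfCount(call) is a hypothetical helper and the threshold of 6 is an arbitrary assumption, neither is part of this patch:

    private GameCall canCallSolo(
            ReadOnlySpan<GameCall> possibleCalls,
            int position, Hand hand, int klopfer)
    {
        // hypothetical: only bid a solo when holding a trump-heavy hand
        foreach (var call in possibleCalls)
            if (call.Mode == GameMode.Solo && hand.TrumpfCount(call) >= 6)
                return call;
        return GameCall.Weiter();
    }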