diff --git a/Schafkopf.Training/Algos/MDP.cs b/Schafkopf.Training/Algos/MDP.cs
index ee296b3..f2324fd 100644
--- a/Schafkopf.Training/Algos/MDP.cs
+++ b/Schafkopf.Training/Algos/MDP.cs
@@ -162,46 +162,6 @@ public TurnBatches(int numSessions)
         public Matrix2D[] piSparseBatches { get; set; }
         public Matrix2D[] vBatches { get; set; }
     }
-
-    private class PossibleCardPicker
-    {
-        private UniformDistribution uniform = new UniformDistribution();
-
-        public Card PickCard(
-            ReadOnlySpan<Card> possibleCards,
-            ReadOnlySpan<double> predPi,
-            Card sampledCard)
-            => canPlaySampledCard(possibleCards, sampledCard) ? sampledCard
-                : possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];
-
-        public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
-            => possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];
-
-        private bool canPlaySampledCard(
-            ReadOnlySpan<Card> possibleCards, Card sampledCard)
-        {
-            foreach (var card in possibleCards)
-                if (card == sampledCard)
-                    return true;
-            return false;
-        }
-
-        private double[] probDistCache = new double[8];
-        private ReadOnlySpan<double> normProbDist(
-            ReadOnlySpan<double> probDistAll, ReadOnlySpan<Card> possibleCards)
-        {
-            double probSum = 0;
-            for (int i = 0; i < possibleCards.Length; i++)
-                probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK];
-            for (int i = 0; i < possibleCards.Length; i++)
-                probSum += probDistCache[i];
-            double scale = 1 / probSum;
-            for (int i = 0; i < possibleCards.Length; i++)
-                probDistCache[i] *= scale;
-
-            return probDistCache.AsSpan().Slice(0, possibleCards.Length);
-        }
-    }
 }
 
 public class CardPickerEnv
random agents is {winRate}"); + } + } + + return model; } } +public class PPOAgent : ISchafkopfAIAgent +{ + public PPOAgent(PPOModel model) + { + this.model = model; + } + + private PPOModel model; + private HeuristicAgent heuristicAgent = new HeuristicAgent(); + private GameStateSerializer stateSerializer = new GameStateSerializer(); + private PossibleCardPicker sampler = new PossibleCardPicker(); + private UniformDistribution uniform = new UniformDistribution(); + + private Matrix2D s0 = Matrix2D.Zeros(1, 90); + private Matrix2D piOh = Matrix2D.Zeros(1, 32); + private Matrix2D V = Matrix2D.Zeros(1, 1); + + public Card ChooseCard(GameLog log, ReadOnlySpan possibleCards) + { + var state = stateSerializer.SerializeState(log); + state.ExportFeatures(s0.SliceRowsRaw(0, 1)); + model.Predict(s0, piOh, V); + var predDist = piOh.SliceRowsRaw(0, 1); + var card = new Card((byte)uniform.Sample(predDist)); + return sampler.PickCard(possibleCards, predDist, card); + } + + public bool CallKontra(GameLog log) => heuristicAgent.CallKontra(log); + public bool CallRe(GameLog log) => heuristicAgent.CallRe(log); + public bool IsKlopfer(int position, ReadOnlySpan firstFourCards) + => heuristicAgent.IsKlopfer(position, firstFourCards); + public GameCall MakeCall( + ReadOnlySpan possibleCalls, + int position, Hand hand, int klopfer) + => heuristicAgent.MakeCall(possibleCalls, position, hand, klopfer); + public void OnGameFinished(GameLog final) => heuristicAgent.OnGameFinished(final); +} + public class PPOModel { public PPOModel(PPOTrainingSettings config) @@ -65,11 +147,13 @@ public void Train(PPORolloutBuffer memory) private void updateModels(PPOTrainBatch batch) { + // update strategy pi(s) var predPi = strategy.PredictBatch(batch.StatesBefore); var policyDeltas = strategy.Layers.Last().Cache.DeltasIn; computePolicyDeltas(batch, predPi, policyDeltas); strategy.FitBatch(policyDeltas, strategyOpt); + // update baseline V(s) var predV = valueFunc.PredictBatch(batch.StatesBefore); var valueDeltas = valueFunc.Layers.Last().Cache.DeltasIn; computeValueDeltas(batch, predV, valueDeltas); @@ -140,29 +224,44 @@ private IEnumerable onehotIndices(Matrix2D sparseClassIds, int numClasses) } } -public class PPOTrainingSettings +public class PossibleCardPicker { - public int NumObsFeatures { get; set; } - public int TotalSteps = 10_000_000; - public double LearnRate = 3e-4; - public double RewardDiscount = 0.99; - public double GAEDiscount = 0.95; - public double ProbClip = 0.2; - public double ValueClip = 0.2; - public double VFCoef = 0.5; - public double EntCoef = 0.01; - public bool NormAdvantages = true; - public bool ClipValues = true; - public int BatchSize = 64; - public int NumEnvs = 32; - public int NumStateDims = 90; - public int NumActionDims = 32; - public int StepsPerUpdate = 512; - public int UpdateEpochs = 4; - public int NumModelSnapshots = 20; + private UniformDistribution uniform = new UniformDistribution(); - public int TrainSteps => TotalSteps / NumEnvs; - public int ModelSnapshotInterval => TrainSteps / NumModelSnapshots; + public Card PickCard( + ReadOnlySpan possibleCards, + ReadOnlySpan predPi, + Card sampledCard) + => canPlaySampledCard(possibleCards, sampledCard) ? 
diff --git a/Schafkopf.Training/GameState.cs b/Schafkopf.Training/GameState.cs
index 53f6fb7..cb32f9d 100644
--- a/Schafkopf.Training/GameState.cs
+++ b/Schafkopf.Training/GameState.cs
@@ -85,7 +85,7 @@ private void serializeHistory(GameLog history, GameState[] statesCache, int skip
         };
 
         int t = 0;
-        for (int t_id = 0; t_id < Math.Ceiling((double)timesteps / 4); t_id++)
+        for (int t_id = 0; t_id < Math.Min(timesteps / 4 + 1, 8); t_id++)
         {
             scores.MoveNext();
 
@@ -95,13 +95,15 @@ private void serializeHistory(GameLog history, GameState[] statesCache, int skip
 
                 if (t >= skip)
                 {
+                    int actingPlayer = t == history.CardCount
+                        ? history.DrawingPlayerId : allActions[t].PlayerId;
                     var hand = hands.Current;
                     var score = scores.Current;
                     var state = statesCache[t].State;
-                    serializeState(state, normCalls, hand, t++, allActions, score);
+                    serializeState(state, normCalls, hand, t, actingPlayer, allActions, score);
                 }
 
-                if (t == timesteps) return;
+                if (t++ == history.CardCount) return; // TODO: check this condition
             }
         }
@@ -110,12 +112,12 @@ private void serializeHistory(GameLog history, GameState[] statesCache, int skip
         for (; t < 36; t++)
             serializeState(
                 statesCache[t].State, normCalls, Hand.EMPTY,
-                t, allActions, scores.Current);
+                t, t % 4, allActions, scores.Current);
     }
 
     private unsafe void serializeState(
         double[] state, ReadOnlySpan<GameCall> normCalls, Hand hand, int t,
-        ReadOnlySpan<GameAction> turnHistory, int[] augen)
+        int actingPlayer, ReadOnlySpan<GameAction> turnHistory, int[] augen)
     {
         if (state.Length < 90)
             throw new IndexOutOfRangeException("Memory overflow");
@@ -126,7 +128,7 @@ private unsafe void serializeState(
         // - turn history (64 floats)
         // - augen (4 floats)
 
-        int actingPlayer = t < 32 ? turnHistory[t].PlayerId : t & 0x3;
+        // int actingPlayer = t < 32 ? (t == 0 ? kommtRaus : turnHistory[t].PlayerId) : t % 4;
         var call = normCalls[actingPlayer];
 
         fixed (double* stateArr = &state[0])
@@ -250,17 +252,14 @@
 public static class GameLogEx
 {
     public static IEnumerable<GameAction> UnrollActions(this GameLog log)
     {
-        var turnCache = new Card[4];
         var action = new GameAction();
 
         foreach (var turn in log.Turns)
         {
             int p_id = turn.FirstDrawingPlayerId;
-            turn.CopyCards(turnCache);
-            for (int i = 0; i < turn.CardsCount; i++)
+            foreach (var card in turn.AllCards)
             {
-                var card = turnCache[p_id];
                 action.PlayerId = (byte)p_id;
                 action.CardPlayed = card;
                 yield return action;
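
For reference, the changed loop bound in serializeHistory differs from the old Math.Ceiling expression in two spots: it yields one extra iteration when timesteps is a multiple of 4, and it never exceeds 8 turns, presumably so the terminal state at t == history.CardCount can be covered without running past a full game. A quick standalone check, with made-up timesteps values:

using System;

// Compare the old and new outer-loop bounds of serializeHistory
// (timesteps values here are hypothetical).
foreach (int timesteps in new[] { 6, 8, 31, 32, 33 })
{
    int oldBound = (int)Math.Ceiling((double)timesteps / 4); // 2, 2, 8, 8, 9
    int newBound = Math.Min(timesteps / 4 + 1, 8);           // 2, 3, 8, 8, 8
    Console.WriteLine($"timesteps={timesteps}: old={oldBound}, new={newBound}");
}
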
diff --git a/Schafkopf.Training/Program.cs b/Schafkopf.Training/Program.cs
index f81dd25..94bf5a8 100644
--- a/Schafkopf.Training/Program.cs
+++ b/Schafkopf.Training/Program.cs
@@ -2,28 +2,10 @@
 
 public class Program
 {
-    private static FFModel createModel()
-        => new FFModel(new ILayer[] {
-            new DenseLayer(64),
-            new ReLULayer(),
-            new DenseLayer(64),
-            new ReLULayer(),
-            new DenseLayer(1),
-        });
-
     public static void Main(string[] args)
     {
-        var model = createModel();
-        var dataset = SupervisedSchafkopfDataset.GenerateDataset(
-            trainSize: 1_000_000, testSize: 10_000);
-        var optimizer = new AdamOpt(learnRate: 0.002);
-        var lossFunc = new MeanSquaredError();
-
-        var session = new SupervisedTrainingSession(
-            model, optimizer, lossFunc, dataset);
-        Console.WriteLine("Training started!");
-        Console.WriteLine($"loss before: loss={session.Eval()}");
-        session.Train(5, true, (ep, l) => Console.WriteLine($"loss ep. {ep}: loss={l}"));
-        Console.WriteLine("Training finished!");
+        var config = new PPOTrainingSettings();
+        var session = new PPOTrainingSession();
+        session.Train(config);
     }
 }
diff --git a/Schafkopf.Training/RandomPlayBenchmark.cs b/Schafkopf.Training/RandomPlayBenchmark.cs
index 339a389..51c3a08 100644
--- a/Schafkopf.Training/RandomPlayBenchmark.cs
+++ b/Schafkopf.Training/RandomPlayBenchmark.cs
@@ -2,7 +2,7 @@ namespace Schafkopf.Training;
 
 public class RandomPlayBenchmark
 {
-    public void Benchmark(ISchafkopfAIAgent agentToEval, int epochs = 10_000)
+    public double Benchmark(ISchafkopfAIAgent agentToEval, int epochs = 10_000)
     {
         var gameCaller = new HeuristicGameCaller(
             new GameMode[] { GameMode.Sauspiel, GameMode.Wenz, GameMode.Solo });
@@ -22,13 +22,16 @@ public void Benchmark(ISchafkopfAIAgent agentToEval, int epochs = 10_000)
         for (int i = 0; i < epochs; i++)
         {
             var log = session.ProcessGame();
+
+            // info: only evaluate games where cards were played
+            if (log.Call.Mode == GameMode.Weiter) { i--; continue; }
+
             var eval = new GameScoreEvaluation(log);
             bool isCaller = log.CallerIds.Contains(0);
             bool isWin = !eval.DidCallerWin ^ isCaller;
             wins += isWin ? 1 : 0;
         }
 
-        double winRate = (double)wins / epochs;
-        Console.WriteLine($"agent scored a win rate of {winRate}!");
+        return (double)wins / epochs; // win rate
     }
 }
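
Since Benchmark now returns the win rate instead of printing it, callers can combine training and evaluation however they like. Below is a minimal end-to-end usage sketch assembled only from the APIs introduced in this diff (no snapshotting or persistence shown); it mirrors what Program.Main does and additionally reports a final benchmark score:

using System;
using Schafkopf.Training;

// Train a PPO model with the default hyperparameters ...
var config = new PPOTrainingSettings();
var model = new PPOTrainingSession().Train(config);

// ... and evaluate the resulting policy against random play.
var benchmark = new RandomPlayBenchmark();
double winRate = benchmark.Benchmark(new PPOAgent(model), epochs: 10_000);
Console.WriteLine($"final win rate vs. random agents: {winRate}");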