From 6fa71c532cd91d41a14d85734f2b0e1d654cdd92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Tr=C3=B6ster?= Date: Sun, 3 Dec 2023 21:16:29 +0100 Subject: [PATCH] fix training data collection (wip) --- Schafkopf.Training.Tests/EnvTests.cs | 49 +++ .../FeatureVectorTests.cs | 358 +++++++++--------- Schafkopf.Training/Algos/MDP.cs | 35 +- Schafkopf.Training/Algos/PPOAgent.cs | 5 +- 4 files changed, 251 insertions(+), 196 deletions(-) create mode 100644 Schafkopf.Training.Tests/EnvTests.cs diff --git a/Schafkopf.Training.Tests/EnvTests.cs b/Schafkopf.Training.Tests/EnvTests.cs new file mode 100644 index 0000000..304ad25 --- /dev/null +++ b/Schafkopf.Training.Tests/EnvTests.cs @@ -0,0 +1,49 @@ +using Schafkopf.Lib; + +namespace Schafkopf.Training.Tests; + +public class CardPickerEnvTests +{ + [Fact] + public void Test_CanPlayGame() + { + var rules = new GameRules(); + var cardCache = new Card[8]; + var rng = new Random(); + + var env = new CardPickerEnv(); + var state = env.Reset(); + foreach (int i in Enumerable.Range(0, 32)) + { + var possActions = rules.PossibleCards(state, cardCache); + var action = possActions[rng.Next(possActions.Length)]; + (state, var __, var ___) = env.Step(action); + Assert.Equal(i+1, state.CardCount); + } + + Assert.Equal(32, state.CardCount); // assert that no exception occurred + } + + [Fact(Skip="raises error (needs to be fixed)")] + public void Test_CanPlayConsequtiveGames() + { + var rules = new GameRules(); + var cardCache = new Card[8]; + var rng = new Random(); + var env = new CardPickerEnv(); + + foreach (int _ in Enumerable.Range(0, 1000)) + { + var state = env.Reset(); + foreach (int i in Enumerable.Range(0, 32)) + { + var possActions = rules.PossibleCards(state, cardCache); + var action = possActions[rng.Next(possActions.Length)]; + (state, var __, var ___) = env.Step(action); + Assert.Equal(i+1, state.CardCount); + } + + Assert.Equal(32, state.CardCount); // assert that no exception occurred + } + } +} diff --git a/Schafkopf.Training.Tests/FeatureVectorTests.cs b/Schafkopf.Training.Tests/FeatureVectorTests.cs index 0a5acfa..a2d41c4 100644 --- a/Schafkopf.Training.Tests/FeatureVectorTests.cs +++ b/Schafkopf.Training.Tests/FeatureVectorTests.cs @@ -1,175 +1,183 @@ -// using Schafkopf.Lib; - -// namespace Schafkopf.Training.Tests; - -// public class FeatureVectorTests -// { -// #region HistoryGenerator - -// private Turn[] playRandomGame(GameCall call, Hand[] initialHands) -// { -// var gameRules = new GameRules(); -// var handsWithMeta = initialHands -// .Select(h => h.CacheTrumpf(call.IsTrumpf)).ToArray(); - -// int p_id = 0; -// var history = new Turn[8]; -// var turn = Turn.InitFirstTurn(0, call); -// for (int t_id = 0; t_id < 7; t_id++) -// { -// for (int i = 0; i < 4; i++) -// { -// var hand = handsWithMeta[p_id]; -// var card = hand.Where(c => gameRules.CanPlayCard(call, c, turn, hand)).First(); -// turn = turn.NextCard(card); -// handsWithMeta[p_id] = hand.Discard(card); -// p_id = (p_id + 1) % 4; -// } -// history[t_id] = turn; -// p_id = turn.WinnerId; -// turn = Turn.InitNextTurn(turn); -// } - -// for (int i = 0; i < 4; i++) -// { -// var card = handsWithMeta[p_id].First(); -// turn = turn.NextCard(card); -// p_id = (p_id + 1) % 4; -// } -// history[7] = turn; - -// return history; -// } - -// private GameLog generateHistoryWithCall(GameCall expCall) -// { -// var deck = new CardsDeck(); -// var callGen = new GameCallGenerator(); -// GameCall[] possCalls; -// Hand[] initialHands; - -// do { -// deck.Shuffle(); -// initialHands = deck.ToArray(); -// possCalls = callGen.AllPossibleCalls(0, initialHands, GameCall.Weiter()).ToArray(); -// possCalls.Contains(expCall); -// } while (!possCalls.Contains(expCall)); - -// var history = playRandomGame(expCall, initialHands); - -// return new GameLog() { -// Call = expCall, -// InitialHands = initialHands, -// Turns = history -// }; -// } - -// #endregion HistoryGenerator - -// [Fact(Skip = "code is not ready yet")] -// public void Test_CanEncodeSauspielCall() -// { -// // TODO: transform this into a theory with multiple calls - -// var serializer = new GameStateSerializer(); -// var call = GameCall.Sauspiel(0, 1, CardColor.Schell); -// var history = generateHistoryWithCall(call); - -// var states = serializer.NewBuffer(); -// serializer.Serialize(history, states); - -// Assert.True(states.All(x => x.State[0] == 0.25)); -// Assert.True(states.All(x => x.State[1] == 0)); -// Assert.True(states.All(x => x.State[2] == 0)); -// Assert.True(states.All(x => x.State[3] == 0.25)); -// Assert.True(states.All(x => x.State[4] == 0.25)); -// Assert.True(states.All(x => x.State[5] == 0)); -// } - -// private void assertValidHandEncoding(GameState state, Hand hand) -// { -// int p = 6; -// var cards = hand.ToArray(); - -// for (int i = 0; i < cards.Length; i++) -// { -// Assert.Equal((double)cards[i].Type / 8, state.State[p++]); -// Assert.Equal((double)cards[i].Color / 4, state.State[p++]); -// } - -// for (int i = cards.Length; i < 8; i++) -// { -// Assert.Equal(-1, state.State[p++]); -// Assert.Equal(-1, state.State[p++]); -// } -// } - -// [Fact(Skip = "code is not ready yet")] -// public void Test_CanEncodeHands() -// { -// var serializer = new GameStateSerializer(); -// var call = GameCall.Sauspiel(0, 1, CardColor.Schell); -// var history = generateHistoryWithCall(call); - -// var states = serializer.NewBuffer(); -// serializer.Serialize(history, states); - -// foreach ((var hand, var state) in history.UnrollHands().Zip(states)) -// assertValidHandEncoding(state, hand); -// } - -// private void assertValidTurnHistory( -// GameState state, ReadOnlySpan history, int t) -// { -// int p = 22; - -// for (int i = 0; i < t; i++) -// { -// var cardPlayed = history[i].CardPlayed; -// Assert.Equal((double)cardPlayed.Type / 8, state.State[p++]); -// Assert.Equal((double)cardPlayed.Color / 4, state.State[p++]); -// } - -// for (int i = t; i < 32; i++) -// { -// Assert.Equal(-1, state.State[p++]); -// Assert.Equal(-1, state.State[p++]); -// } -// } - -// [Fact(Skip = "code is not ready yet")] -// public void Test_CanEncodeTurnHistory() -// { -// var serializer = new GameStateSerializer(); -// var call = GameCall.Sauspiel(0, 1, CardColor.Schell); -// var history = generateHistoryWithCall(call); -// var allActions = history.UnrollActions().ToArray(); - -// var states = serializer.NewBuffer(); -// serializer.Serialize(history, states); - -// foreach ((int t, var state) in Enumerable.Range(0, 33).Zip(states)) -// assertValidTurnHistory(state, allActions, t); -// } - -// private void assertValidAugen(GameState state, int[] augen) -// { -// for (int i = 0; i < 4; i++) -// Assert.Equal((double)augen[i] / 120, state.State[i+86]); -// } - -// [Fact(Skip = "code is not ready yet")] -// public void Test_CanEncodeAugen() -// { -// var serializer = new GameStateSerializer(); -// var call = GameCall.Sauspiel(0, 1, CardColor.Schell); -// var history = generateHistoryWithCall(call); -// var allAugen = history.UnrollAugen().Select(x => x.ToArray()).ToArray(); - -// var states = serializer.NewBuffer(); -// serializer.Serialize(history, states); - -// foreach ((int t, var state) in Enumerable.Range(0, 33).Zip(states)) -// assertValidAugen(state, allAugen[t / 4]); -// } -// } +using Schafkopf.Lib; + +namespace Schafkopf.Training.Tests; + +public class FeatureVectorTests +{ + [Fact] + public void Test_CanSerializeCompleteGame() + { + var serializer = new GameStateSerializer(); + var call = GameCall.Sauspiel(0, 1, CardColor.Schell); + var history = generateHistoryWithCall(call); + + var newExp = () => new SarsExp() { StateBefore = new GameState() }; + var states = Enumerable.Range(0, 32).Select(i => newExp()).ToArray(); + serializer.SerializeSarsExps(history, states); + + Assert.True(true); // serialization does not throw exception + } + + [Fact] + public void Test_CanSerializeLiveGame() + { + var serializer = new GameStateSerializer(); + var call = GameCall.Sauspiel(0, 1, CardColor.Schell); + var completeGame = generateHistoryWithCall(call); + var actions = completeGame.UnrollActions().ToArray(); + + int kommtRaus = completeGame.Turns[0].FirstDrawingPlayerId; + var liveGame = GameLog.NewLiveGame( + call, completeGame.InitialHands, kommtRaus); + + foreach (var action in actions) + { + serializer.SerializeState(liveGame); + liveGame.NextCard(action.CardPlayed); + } + + Assert.True(true); // serialization does not throw exception + } + + #region HistoryGenerator + + private GameLog playRandomGame(GameCall call, Hand[] initialHands) + { + var gameRules = new GameRules(); + var liveGame = GameLog.NewLiveGame(call, initialHands, 0); + var cardsCache = new Card[8]; + + foreach (var _ in Enumerable.Range(0, 32)) + liveGame.NextCard(gameRules.PossibleCards( + liveGame, cardsCache).ToArray().First()); + + return liveGame; + } + + private GameLog generateHistoryWithCall(GameCall expCall) + { + var deck = new CardsDeck(); + var callGen = new GameCallGenerator(); + GameCall[] possCalls; + Hand[] initialHands; + + do { + deck.Shuffle(); + initialHands = deck.ToArray(); + possCalls = callGen.AllPossibleCalls( + 0, initialHands, GameCall.Weiter()).ToArray(); + possCalls.Contains(expCall); + } while (!possCalls.Contains(expCall)); + + return playRandomGame(expCall, initialHands); + } + + #endregion HistoryGenerator + + // [Fact(Skip = "code is not ready yet")] + // public void Test_CanEncodeSauspielCall() + // { + // // TODO: transform this into a theory with multiple calls + + // var serializer = new GameStateSerializer(); + // var call = GameCall.Sauspiel(0, 1, CardColor.Schell); + // var history = generateHistoryWithCall(call); + + // var states = serializer.NewBuffer(); + // serializer.Serialize(history, states); + + // Assert.True(states.All(x => x.State[0] == 0.25)); + // Assert.True(states.All(x => x.State[1] == 0)); + // Assert.True(states.All(x => x.State[2] == 0)); + // Assert.True(states.All(x => x.State[3] == 0.25)); + // Assert.True(states.All(x => x.State[4] == 0.25)); + // Assert.True(states.All(x => x.State[5] == 0)); + // } + + // private void assertValidHandEncoding(GameState state, Hand hand) + // { + // int p = 6; + // var cards = hand.ToArray(); + + // for (int i = 0; i < cards.Length; i++) + // { + // Assert.Equal((double)cards[i].Type / 8, state.State[p++]); + // Assert.Equal((double)cards[i].Color / 4, state.State[p++]); + // } + + // for (int i = cards.Length; i < 8; i++) + // { + // Assert.Equal(-1, state.State[p++]); + // Assert.Equal(-1, state.State[p++]); + // } + // } + + // [Fact(Skip = "code is not ready yet")] + // public void Test_CanEncodeHands() + // { + // var serializer = new GameStateSerializer(); + // var call = GameCall.Sauspiel(0, 1, CardColor.Schell); + // var history = generateHistoryWithCall(call); + + // var states = serializer.NewBuffer(); + // serializer.Serialize(history, states); + + // foreach ((var hand, var state) in history.UnrollHands().Zip(states)) + // assertValidHandEncoding(state, hand); + // } + + // private void assertValidTurnHistory( + // GameState state, ReadOnlySpan history, int t) + // { + // int p = 22; + + // for (int i = 0; i < t; i++) + // { + // var cardPlayed = history[i].CardPlayed; + // Assert.Equal((double)cardPlayed.Type / 8, state.State[p++]); + // Assert.Equal((double)cardPlayed.Color / 4, state.State[p++]); + // } + + // for (int i = t; i < 32; i++) + // { + // Assert.Equal(-1, state.State[p++]); + // Assert.Equal(-1, state.State[p++]); + // } + // } + + // [Fact(Skip = "code is not ready yet")] + // public void Test_CanEncodeTurnHistory() + // { + // var serializer = new GameStateSerializer(); + // var call = GameCall.Sauspiel(0, 1, CardColor.Schell); + // var history = generateHistoryWithCall(call); + // var allActions = history.UnrollActions().ToArray(); + + // var states = serializer.NewBuffer(); + // serializer.Serialize(history, states); + + // foreach ((int t, var state) in Enumerable.Range(0, 33).Zip(states)) + // assertValidTurnHistory(state, allActions, t); + // } + + // private void assertValidAugen(GameState state, int[] augen) + // { + // for (int i = 0; i < 4; i++) + // Assert.Equal((double)augen[i] / 120, state.State[i+86]); + // } + + // [Fact(Skip = "code is not ready yet")] + // public void Test_CanEncodeAugen() + // { + // var serializer = new GameStateSerializer(); + // var call = GameCall.Sauspiel(0, 1, CardColor.Schell); + // var history = generateHistoryWithCall(call); + // var allAugen = history.UnrollAugen().Select(x => x.ToArray()).ToArray(); + + // var states = serializer.NewBuffer(); + // serializer.Serialize(history, states); + + // foreach ((int t, var state) in Enumerable.Range(0, 33).Zip(states)) + // assertValidAugen(state, allAugen[t / 4]); + // } +} diff --git a/Schafkopf.Training/Algos/MDP.cs b/Schafkopf.Training/Algos/MDP.cs index f2324fd..ed72ba9 100644 --- a/Schafkopf.Training/Algos/MDP.cs +++ b/Schafkopf.Training/Algos/MDP.cs @@ -21,20 +21,25 @@ public void Collect(PPORolloutBuffer buffer) // throw new ArgumentException("The number of steps needs to be " // + "divisible by 8 because each agent plays 8 cards per game!"); + Console.Write($"collect data"); + int numGames = buffer.Steps / 8; int numSessions = buffer.NumEnvs / 4; var envs = Enumerable.Range(0, numSessions) .Select(i => new CardPickerEnv()).ToArray(); var states = envs.Select(env => env.Reset()).ToArray(); var batchesOfTurns = Enumerable.Range(0, 8) - .Select(i => new TurnBatches(numSessions)).ToArray(); + .Select(i => new TurnBatches(buffer.NumEnvs)).ToArray(); var rewards = Matrix2D.Zeros(8, buffer.NumEnvs); for (int gameId = 0; gameId < numGames + 1; gameId++) { + Console.Write($"\rcollecting ppo training data {(gameId)} / { numGames } ... "); playGame(envs, states, batchesOfTurns); prepareRewards(states, rewards); fillBuffer(gameId, buffer, states, batchesOfTurns, rewards); + for (int i = 0; i < states.Length; i++) + states[i] = envs[i].Reset(); } } @@ -130,11 +135,8 @@ private void playGame(CardPickerEnv[] envs, GameLog[] states, TurnBatches[] batc for (int envId = 0; envId < envs.Length; envId++) { - // info: rewards and terminals are - // determined after the game is over - (var newState, double reward, bool isTerminal) = - envs[envId].Step(new Card((byte)actions[envId])); - states[envId] = newState; + var action = new Card((byte)actions[envId]); + states[envId] = envs[envId].Step(action).Item1; } } } @@ -177,15 +179,16 @@ public class CardPickerEnv public GameLog Reset() { kommtRaus = (kommtRaus + 1) % 4; - deck.Shuffle(); - deck.InitialHands(initialHandsCache); - // info: klopfer is not required to train a card picker - int klopfer = 0; // askForKlopfer(initialHandsCache); - var call = makeCalls(klopfer, initialHandsCache, kommtRaus); - log = GameLog.NewLiveGame(call, initialHandsCache, kommtRaus, klopfer); + GameCall call; int klopfer = 0; + do { + deck.Shuffle(); + deck.InitialHands(initialHandsCache); + call = makeCalls(klopfer, initialHandsCache, kommtRaus); + } + while (call.Mode == GameMode.Weiter); - return log; + return log = GameLog.NewLiveGame(call, initialHandsCache, kommtRaus, klopfer); } public (GameLog, double, bool) Step(Card cardToPlay) @@ -193,13 +196,7 @@ public GameLog Reset() if (log.CardCount >= 32) throw new InvalidOperationException("Game is already finished!"); - // info: kontra/re is not required to train a card picker - // if (log.CardCount <= 1) - // askForKontraRe(log); - log.NextCard(cardToPlay); - - // info: reward doesn't relate to the next state, compute it in calling scope return (log, 0.0, log.CardCount >= 28); } diff --git a/Schafkopf.Training/Algos/PPOAgent.cs b/Schafkopf.Training/Algos/PPOAgent.cs index af16fd7..dcdff8d 100644 --- a/Schafkopf.Training/Algos/PPOAgent.cs +++ b/Schafkopf.Training/Algos/PPOAgent.cs @@ -14,7 +14,7 @@ public class PPOTrainingSettings public bool NormAdvantages = true; public bool ClipValues = true; public int BatchSize = 64; - public int NumEnvs = 32; + public int NumEnvs = 64; public int NumStateDims = 90; public int NumActionDims = 32; public int StepsPerUpdate = 512; @@ -125,9 +125,10 @@ public PPOModel(PPOTrainingSettings config) private FFModel strategy; private IOptimizer strategyOpt; private IOptimizer valueFuncOpt; - private Matrix2D featureCache; + public int BatchSize => config.BatchSize; + public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV) { var predPi = strategy.PredictBatch(s0);