diff --git a/Schafkopf.Training.Tests/FeatureVectorTests.cs b/Schafkopf.Training.Tests/FeatureVectorTests.cs index f277987..0a5acfa 100644 --- a/Schafkopf.Training.Tests/FeatureVectorTests.cs +++ b/Schafkopf.Training.Tests/FeatureVectorTests.cs @@ -1,175 +1,175 @@ -using Schafkopf.Lib; - -namespace Schafkopf.Training.Tests; - -public class FeatureVectorTests -{ - #region HistoryGenerator - - private Turn[] playRandomGame(GameCall call, Hand[] initialHands) - { - var gameRules = new GameRules(); - var handsWithMeta = initialHands - .Select(h => h.CacheTrumpf(call.IsTrumpf)).ToArray(); - - int p_id = 0; - var history = new Turn[8]; - var turn = Turn.InitFirstTurn(0, call); - for (int t_id = 0; t_id < 7; t_id++) - { - for (int i = 0; i < 4; i++) - { - var hand = handsWithMeta[p_id]; - var card = hand.Where(c => gameRules.CanPlayCard(call, c, turn, hand)).First(); - turn = turn.NextCard(card); - handsWithMeta[p_id] = hand.Discard(card); - p_id = (p_id + 1) % 4; - } - history[t_id] = turn; - p_id = turn.WinnerId; - turn = Turn.InitNextTurn(turn); - } - - for (int i = 0; i < 4; i++) - { - var card = handsWithMeta[p_id].First(); - turn = turn.NextCard(card); - p_id = (p_id + 1) % 4; - } - history[7] = turn; - - return history; - } - - private GameLog generateHistoryWithCall(GameCall expCall) - { - var deck = new CardsDeck(); - var callGen = new GameCallGenerator(); - GameCall[] possCalls; - Hand[] initialHands; - - do { - deck.Shuffle(); - initialHands = deck.ToArray(); - possCalls = callGen.AllPossibleCalls(0, initialHands, GameCall.Weiter()).ToArray(); - possCalls.Contains(expCall); - } while (!possCalls.Contains(expCall)); - - var history = playRandomGame(expCall, initialHands); - - return new GameLog() { - Call = expCall, - InitialHands = initialHands, - Turns = history - }; - } - - #endregion HistoryGenerator - - [Fact(Skip = "code is not ready yet")] - public void Test_CanEncodeSauspielCall() - { - // TODO: transform this into a theory with multiple calls - - var serializer = new GameStateSerializer(); - var call = GameCall.Sauspiel(0, 1, CardColor.Schell); - var history = generateHistoryWithCall(call); - - var states = serializer.NewBuffer(); - serializer.Serialize(history, states); - - Assert.True(states.All(x => x.State[0] == 0.25)); - Assert.True(states.All(x => x.State[1] == 0)); - Assert.True(states.All(x => x.State[2] == 0)); - Assert.True(states.All(x => x.State[3] == 0.25)); - Assert.True(states.All(x => x.State[4] == 0.25)); - Assert.True(states.All(x => x.State[5] == 0)); - } - - private void assertValidHandEncoding(GameState state, Hand hand) - { - int p = 6; - var cards = hand.ToArray(); - - for (int i = 0; i < cards.Length; i++) - { - Assert.Equal((double)cards[i].Type / 8, state.State[p++]); - Assert.Equal((double)cards[i].Color / 4, state.State[p++]); - } - - for (int i = cards.Length; i < 8; i++) - { - Assert.Equal(-1, state.State[p++]); - Assert.Equal(-1, state.State[p++]); - } - } - - [Fact(Skip = "code is not ready yet")] - public void Test_CanEncodeHands() - { - var serializer = new GameStateSerializer(); - var call = GameCall.Sauspiel(0, 1, CardColor.Schell); - var history = generateHistoryWithCall(call); - - var states = serializer.NewBuffer(); - serializer.Serialize(history, states); - - foreach ((var hand, var state) in history.UnrollHands().Zip(states)) - assertValidHandEncoding(state, hand); - } - - private void assertValidTurnHistory( - GameState state, ReadOnlySpan history, int t) - { - int p = 22; - - for (int i = 0; i < t; i++) - { - var cardPlayed = history[i].CardPlayed; - Assert.Equal((double)cardPlayed.Type / 8, state.State[p++]); - Assert.Equal((double)cardPlayed.Color / 4, state.State[p++]); - } - - for (int i = t; i < 32; i++) - { - Assert.Equal(-1, state.State[p++]); - Assert.Equal(-1, state.State[p++]); - } - } - - [Fact(Skip = "code is not ready yet")] - public void Test_CanEncodeTurnHistory() - { - var serializer = new GameStateSerializer(); - var call = GameCall.Sauspiel(0, 1, CardColor.Schell); - var history = generateHistoryWithCall(call); - var allActions = history.UnrollActions().ToArray(); - - var states = serializer.NewBuffer(); - serializer.Serialize(history, states); - - foreach ((int t, var state) in Enumerable.Range(0, 33).Zip(states)) - assertValidTurnHistory(state, allActions, t); - } - - private void assertValidAugen(GameState state, int[] augen) - { - for (int i = 0; i < 4; i++) - Assert.Equal((double)augen[i] / 120, state.State[i+86]); - } - - [Fact(Skip = "code is not ready yet")] - public void Test_CanEncodeAugen() - { - var serializer = new GameStateSerializer(); - var call = GameCall.Sauspiel(0, 1, CardColor.Schell); - var history = generateHistoryWithCall(call); - var allAugen = history.UnrollAugen().Select(x => x.ToArray()).ToArray(); - - var states = serializer.NewBuffer(); - serializer.Serialize(history, states); - - foreach ((int t, var state) in Enumerable.Range(0, 33).Zip(states)) - assertValidAugen(state, allAugen[t / 4]); - } -} +// using Schafkopf.Lib; + +// namespace Schafkopf.Training.Tests; + +// public class FeatureVectorTests +// { +// #region HistoryGenerator + +// private Turn[] playRandomGame(GameCall call, Hand[] initialHands) +// { +// var gameRules = new GameRules(); +// var handsWithMeta = initialHands +// .Select(h => h.CacheTrumpf(call.IsTrumpf)).ToArray(); + +// int p_id = 0; +// var history = new Turn[8]; +// var turn = Turn.InitFirstTurn(0, call); +// for (int t_id = 0; t_id < 7; t_id++) +// { +// for (int i = 0; i < 4; i++) +// { +// var hand = handsWithMeta[p_id]; +// var card = hand.Where(c => gameRules.CanPlayCard(call, c, turn, hand)).First(); +// turn = turn.NextCard(card); +// handsWithMeta[p_id] = hand.Discard(card); +// p_id = (p_id + 1) % 4; +// } +// history[t_id] = turn; +// p_id = turn.WinnerId; +// turn = Turn.InitNextTurn(turn); +// } + +// for (int i = 0; i < 4; i++) +// { +// var card = handsWithMeta[p_id].First(); +// turn = turn.NextCard(card); +// p_id = (p_id + 1) % 4; +// } +// history[7] = turn; + +// return history; +// } + +// private GameLog generateHistoryWithCall(GameCall expCall) +// { +// var deck = new CardsDeck(); +// var callGen = new GameCallGenerator(); +// GameCall[] possCalls; +// Hand[] initialHands; + +// do { +// deck.Shuffle(); +// initialHands = deck.ToArray(); +// possCalls = callGen.AllPossibleCalls(0, initialHands, GameCall.Weiter()).ToArray(); +// possCalls.Contains(expCall); +// } while (!possCalls.Contains(expCall)); + +// var history = playRandomGame(expCall, initialHands); + +// return new GameLog() { +// Call = expCall, +// InitialHands = initialHands, +// Turns = history +// }; +// } + +// #endregion HistoryGenerator + +// [Fact(Skip = "code is not ready yet")] +// public void Test_CanEncodeSauspielCall() +// { +// // TODO: transform this into a theory with multiple calls + +// var serializer = new GameStateSerializer(); +// var call = GameCall.Sauspiel(0, 1, CardColor.Schell); +// var history = generateHistoryWithCall(call); + +// var states = serializer.NewBuffer(); +// serializer.Serialize(history, states); + +// Assert.True(states.All(x => x.State[0] == 0.25)); +// Assert.True(states.All(x => x.State[1] == 0)); +// Assert.True(states.All(x => x.State[2] == 0)); +// Assert.True(states.All(x => x.State[3] == 0.25)); +// Assert.True(states.All(x => x.State[4] == 0.25)); +// Assert.True(states.All(x => x.State[5] == 0)); +// } + +// private void assertValidHandEncoding(GameState state, Hand hand) +// { +// int p = 6; +// var cards = hand.ToArray(); + +// for (int i = 0; i < cards.Length; i++) +// { +// Assert.Equal((double)cards[i].Type / 8, state.State[p++]); +// Assert.Equal((double)cards[i].Color / 4, state.State[p++]); +// } + +// for (int i = cards.Length; i < 8; i++) +// { +// Assert.Equal(-1, state.State[p++]); +// Assert.Equal(-1, state.State[p++]); +// } +// } + +// [Fact(Skip = "code is not ready yet")] +// public void Test_CanEncodeHands() +// { +// var serializer = new GameStateSerializer(); +// var call = GameCall.Sauspiel(0, 1, CardColor.Schell); +// var history = generateHistoryWithCall(call); + +// var states = serializer.NewBuffer(); +// serializer.Serialize(history, states); + +// foreach ((var hand, var state) in history.UnrollHands().Zip(states)) +// assertValidHandEncoding(state, hand); +// } + +// private void assertValidTurnHistory( +// GameState state, ReadOnlySpan history, int t) +// { +// int p = 22; + +// for (int i = 0; i < t; i++) +// { +// var cardPlayed = history[i].CardPlayed; +// Assert.Equal((double)cardPlayed.Type / 8, state.State[p++]); +// Assert.Equal((double)cardPlayed.Color / 4, state.State[p++]); +// } + +// for (int i = t; i < 32; i++) +// { +// Assert.Equal(-1, state.State[p++]); +// Assert.Equal(-1, state.State[p++]); +// } +// } + +// [Fact(Skip = "code is not ready yet")] +// public void Test_CanEncodeTurnHistory() +// { +// var serializer = new GameStateSerializer(); +// var call = GameCall.Sauspiel(0, 1, CardColor.Schell); +// var history = generateHistoryWithCall(call); +// var allActions = history.UnrollActions().ToArray(); + +// var states = serializer.NewBuffer(); +// serializer.Serialize(history, states); + +// foreach ((int t, var state) in Enumerable.Range(0, 33).Zip(states)) +// assertValidTurnHistory(state, allActions, t); +// } + +// private void assertValidAugen(GameState state, int[] augen) +// { +// for (int i = 0; i < 4; i++) +// Assert.Equal((double)augen[i] / 120, state.State[i+86]); +// } + +// [Fact(Skip = "code is not ready yet")] +// public void Test_CanEncodeAugen() +// { +// var serializer = new GameStateSerializer(); +// var call = GameCall.Sauspiel(0, 1, CardColor.Schell); +// var history = generateHistoryWithCall(call); +// var allAugen = history.UnrollAugen().Select(x => x.ToArray()).ToArray(); + +// var states = serializer.NewBuffer(); +// serializer.Serialize(history, states); + +// foreach ((int t, var state) in Enumerable.Range(0, 33).Zip(states)) +// assertValidAugen(state, allAugen[t / 4]); +// } +// } diff --git a/Schafkopf.Training/GameState.cs b/Schafkopf.Training/GameState.cs index 26526e6..d88f9fd 100644 --- a/Schafkopf.Training/GameState.cs +++ b/Schafkopf.Training/GameState.cs @@ -26,43 +26,45 @@ public class GameStateSerializer { private const int NO_CARD = -1; - public GameState[] NewBuffer() + public static GameState[] NewBuffer() => Enumerable.Range(0, 36).Select(x => new GameState()).ToArray(); - private GameState[] stateBuffer = Enumerable.Range(0, 36) - .Select(x => new GameState()).ToArray(); + private GameState[] stateBuffer = NewBuffer(); public void SerializeSarsExps( - GameLog completedGame, SarsExp[] exps, Func reward) + GameLog completedGame, SarsExp[] exps, + Func reward) { - Serialize(completedGame, stateBuffer); - var actions = completedGame.UnrollActions().ToArray(); - var p_ids = completedGame.UnrollActions() - .Select(x => x.PlayerId).ToArray(); + serializeHistory(completedGame, stateBuffer); - var order_idx = new byte[4, 8]; - foreach ((byte p_id, int t) in p_ids.Zip(Enumerable.Range(0, 32))) - order_idx[p_id, t / 4] = (byte)(t % 4); - - for (int t = 0; t < 32; t++) + var actions = completedGame.UnrollActions().GetEnumerator(); + for (int t0 = 0; t0 < 32; t0++) { - int p_id = p_ids[t]; - int t_id = t / 4; - bool isTerminal = t > 28; - int p0 = t_id * 4 + order_idx[p_id, t_id]; - int p1 = (t_id + 1) * 4 + order_idx[p_id, t_id + 1]; - p1 = isTerminal ? 32 + p_id : p1; - - exps[t].Action.PlayerId = 0; - exps[t].Action.CardPlayed = actions[t].CardPlayed; - exps[t].StateBefore.LoadFeatures(stateBuffer[p0].State); - exps[t].StateAfter.LoadFeatures(stateBuffer[p1].State); - exps[t].IsTerminal = isTerminal; - exps[t].Reward = reward(stateBuffer[p1]); + actions.MoveNext(); + var card = actions.Current.CardPlayed; + int p_id = actions.Current.PlayerId; + int t_id = t0 / 4; + bool isTerminal = t0 >= 28; + int t1 = playerPosOfTurn(completedGame, t_id+1, p_id); + + exps[t0].Action.PlayerId = 0; + exps[t0].Action.CardPlayed = card; + exps[t0].StateBefore.LoadFeatures(stateBuffer[t0].State); + exps[t0].StateAfter.LoadFeatures(stateBuffer[t1].State); + exps[t0].IsTerminal = isTerminal; + exps[t0].Reward = reward(completedGame, t1); } } - public void Serialize(GameLog completedGame, GameState[] states) + private int playerPosOfTurn(GameLog log, int t_id, int p_id) + => t_id == 8 ? p_id : normPlayerId(p_id, log.Turns[t_id].FirstDrawingPlayerId); + + private void serializeHistory(GameLog completedGame, GameState[] statesCache) { + if (completedGame.TurnCount != 32) + throw new ArgumentException("Can only process finished games!"); + if (statesCache.Length < 36) + throw new ArgumentException(""); + var origCall = completedGame.Call; var hands = completedGame.UnrollHands().GetEnumerator(); var scores = completedGame.UnrollAugen().GetEnumerator(); @@ -83,7 +85,7 @@ public void Serialize(GameLog completedGame, GameState[] states) var hand = hands.Current; var score = scores.Current; - var state = states[t].State; + var state = statesCache[t].State; serializeState(state, normCalls, hand, t++, allActions, score); } } @@ -92,7 +94,7 @@ public void Serialize(GameLog completedGame, GameState[] states) for (; t < 36; t++) serializeState( - states[t].State, normCalls, Hand.EMPTY, + statesCache[t].State, normCalls, Hand.EMPTY, t, allActions, scores.Current); } @@ -228,12 +230,17 @@ private int normPlayerId(int id, int offset) public class GameReward { - public double Reward(GameLog log, int playerId) + public static double Reward(GameLog log, int t) { // intention of this reward system: // - players receive reward 1 as soon as they are in a winning state // - if they are in a losing or undetermined state, they receive reward 0 + // info: t >= 32 relate to the final game outcome + // from the view of the player with p_id = t%4 + int playerId = t >= 32 ? t % 4 : + normPlayerId(log.Turns[t / 4].FirstDrawingPlayerId, t % 4); + // info: players don't know yet who the sauspiel partner is // -> no reward, even if it's already won var currentTurn = log.Turns[log.CardCount / 4]; @@ -242,12 +249,16 @@ public double Reward(GameLog log, int playerId) bool isCaller = log.CallerIds.Contains(playerId); var augen = log.UnrollAugen().Last(); - double callerScore = log.CallerIds.ToArray() - .Select(i => augen[i]).Sum(); + double callerScore = 0; + for (int i = 0; i < log.CallerIds.Length; i++) + callerScore += augen[log.CallerIds[i]]; if (log.Call.Mode != GameMode.Sauspiel && log.Call.IsTout) return isCaller && callerScore == 120 ? 1 : 0; else return (isCaller && callerScore >= 61) || !isCaller ? 1 : 0; } + + private static int normPlayerId(int id, int offset) + => (id - offset + 4) & 0x03; }