From 094b2f99c6a566cbeae0baf37f9e3e4825ae570f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Tr=C3=B6ster?= Date: Tue, 28 Nov 2023 01:28:12 +0100 Subject: [PATCH] unroll rewards from game history --- Schafkopf.Lib/GameLog.cs | 47 --------- Schafkopf.Lib/GameResult.cs | 5 +- Schafkopf.Training/Dataset.cs | 3 +- Schafkopf.Training/GameState.cs | 165 +++++++++++++++++++++++++------- 4 files changed, 137 insertions(+), 83 deletions(-) diff --git a/Schafkopf.Lib/GameLog.cs b/Schafkopf.Lib/GameLog.cs index 9ce4a0c..ea07b32 100644 --- a/Schafkopf.Lib/GameLog.cs +++ b/Schafkopf.Lib/GameLog.cs @@ -151,53 +151,6 @@ public Turn NextCard(Card card) else return Turns[t_id]; } - - public IEnumerable UnrollActions() - { - var turnCache = new Card[4]; - var action = new GameAction(); - - foreach (var turn in Turns) - { - int p_id = turn.FirstDrawingPlayerId; - turn.CopyCards(turnCache); - - for (int i = 0; i < turn.CardsCount; i++) - { - var card = turnCache[p_id]; - action.PlayerId = (byte)p_id; - action.CardPlayed = card; - yield return action; - p_id = (p_id + 1) % 4; - } - } - } - - public IEnumerable UnrollHands() - { - int i = 0; - var hands = InitialHands.ToArray(); - foreach (var action in UnrollActions()) - { - if (i++ >= CardCount) - break; - yield return hands[action.PlayerId]; - hands[action.PlayerId] = hands[action.PlayerId].Discard(action.CardPlayed); - } - if (CardCount == 32) - yield return Hand.EMPTY; - } - - public IEnumerable UnrollAugen() - { - var augen = new int[4]; - foreach (var turn in Turns) - { - yield return augen; - augen[turn.WinnerId] += turn.Augen; - } - yield return augen; - } } public struct GameAction : IEquatable diff --git a/Schafkopf.Lib/GameResult.cs b/Schafkopf.Lib/GameResult.cs index 3af3db3..55fb23c 100644 --- a/Schafkopf.Lib/GameResult.cs +++ b/Schafkopf.Lib/GameResult.cs @@ -56,7 +56,10 @@ public class GameScoreEvaluation { public GameScoreEvaluation(GameLog log) { - var augen = log.UnrollAugen().Last(); + var augen = new int[4]; + 
foreach (var turn in log.Turns) + augen[turn.WinnerId] += turn.Augen; + ScoreCaller = log.CallerIds.ToArray().Select(id => augen[id]).Sum(); ScoreOpponents = 120 - ScoreCaller; diff --git a/Schafkopf.Training/Dataset.cs b/Schafkopf.Training/Dataset.cs index b20602b..11297b5 100644 --- a/Schafkopf.Training/Dataset.cs +++ b/Schafkopf.Training/Dataset.cs @@ -61,8 +61,7 @@ private static IEnumerable generateExperiences( if (log.Call.Mode == GameMode.Weiter) continue; - serializer.SerializeSarsExps( - log, expBuffer, GameReward.Reward); + serializer.SerializeSarsExps(log, expBuffer); for (int i = 0; i < 32; i++) if (numExamples == null || p++ < numExamples) diff --git a/Schafkopf.Training/GameState.cs b/Schafkopf.Training/GameState.cs index 9309e0d..185350e 100644 --- a/Schafkopf.Training/GameState.cs +++ b/Schafkopf.Training/GameState.cs @@ -36,17 +36,17 @@ public static GameState[] NewBuffer() => Enumerable.Range(0, 36).Select(x => new GameState()).ToArray(); private GameState[] stateBuffer = NewBuffer(); - public void SerializeSarsExps( - GameLog completedGame, SarsExp[] exps, - Func reward) + public void SerializeSarsExps(GameLog completedGame, SarsExp[] exps) { if (completedGame.CardCount != 32) throw new ArgumentException("Can only process finished games!"); serializeHistory(completedGame, stateBuffer); var actions = completedGame.UnrollActions().GetEnumerator(); + var rewards = completedGame.UnrollRewards().GetEnumerator(); for (int t0 = 0; t0 < 32; t0++) { + rewards.MoveNext(); actions.MoveNext(); var card = actions.Current.CardPlayed; int p_id = actions.Current.PlayerId; @@ -58,7 +58,7 @@ public void SerializeSarsExps( exps[t0].StateBefore.LoadFeatures(stateBuffer[t0].State); exps[t0].StateAfter.LoadFeatures(stateBuffer[t1].State); exps[t0].IsTerminal = isTerminal; - exps[t0].Reward = reward(completedGame, t1); + exps[t0].Reward = rewards.Current.Item2; } } @@ -246,37 +246,136 @@ public static unsafe int EncodeOnehot(double* stateArr, bool flag) } } -public 
class GameReward +public static class GameLogEx { - public static double Reward(GameLog log, int t) + public static IEnumerable UnrollActions(this GameLog log) { - // intention of this reward system: - // - players receive reward 1 as soon as they are in a winning state - // - if they are in a losing or undetermined state, they receive reward 0 - - // info: t >= 32 relate to the final game outcome - // from the view of the player with p_id = t%4 - int playerId = t >= 32 ? t % 4 : - normPlayerId(log.Turns[t / 4].FirstDrawingPlayerId, t % 4); - - // info: players don't know yet who the sauspiel partner is - // -> no reward, even if it's already won - bool alreadyGsucht = t < 28 ? log.Turns[t / 4].AlreadyGsucht : true; - if (log.Call.Mode == GameMode.Sauspiel && !alreadyGsucht) - return 0; - - bool isCaller = log.CallerIds.Contains(playerId); - var augen = log.UnrollAugen().Last(); - double callerScore = 0; - for (int i = 0; i < log.CallerIds.Length; i++) - callerScore += augen[log.CallerIds[i]]; - - if (log.Call.Mode != GameMode.Sauspiel && log.Call.IsTout) - return isCaller && callerScore == 120 ? 1 : 0; - else - return (isCaller && callerScore >= 61) || !isCaller ? 
1 : 0; + var turnCache = new Card[4]; + var action = new GameAction(); + + foreach (var turn in log.Turns) + { + int p_id = turn.FirstDrawingPlayerId; + turn.CopyCards(turnCache); + + for (int i = 0; i < turn.CardsCount; i++) + { + var card = turnCache[p_id]; + action.PlayerId = (byte)p_id; + action.CardPlayed = card; + yield return action; + p_id = (p_id + 1) % 4; + } + } } - private static int normPlayerId(int id, int offset) - => (id - offset + 4) & 0x03; + public static IEnumerable UnrollActingPlayers(this GameLog log) + { + foreach (var turn in log.Turns) + { + int p_id = turn.FirstDrawingPlayerId; + + for (int i = 0; i < turn.CardsCount; i++) + { + yield return p_id; + p_id = (p_id + 1) % 4; + } + } + } + + public static IEnumerable UnrollHands(this GameLog log) + { + int i = 0; + var hands = log.InitialHands.ToArray(); + foreach (var action in log.UnrollActions()) + { + if (i++ >= log.CardCount) + break; + yield return hands[action.PlayerId]; + hands[action.PlayerId] = hands[action.PlayerId].Discard(action.CardPlayed); + } + if (log.CardCount == 32) + yield return Hand.EMPTY; + } + + public static IEnumerable UnrollAugen(this GameLog log) + { + var augen = new int[4]; + foreach (var turn in log.Turns) + { + yield return augen; + augen[turn.WinnerId] += turn.Augen; + } + yield return augen; + } +} + +public static class GameReward +{ + public static IEnumerable<(int, double)> UnrollRewards(this GameLog completeGame) + { + int callerId = completeGame.Call.CallingPlayerId; + int partnerId = completeGame.Call.PartnerPlayerId; + var oppIds = completeGame.OpponentIds.ToArray(); + var augenIter = completeGame.UnrollAugen().GetEnumerator(); + augenIter.MoveNext(); + + int t = 0; + foreach (var action in completeGame.UnrollActions()) + { + if (t % 4 == 0) + augenIter.MoveNext(); + var augen = augenIter.Current; + + if (completeGame.Call.Mode == GameMode.Sauspiel) + { + int p_id = action.PlayerId; + bool isCaller = p_id == callerId; + bool isPartner = p_id == 
partnerId; + + int ownAugen = augen[p_id]; + int partnerAugen; + if (isCaller) + partnerAugen = augen[partnerId]; + else if (isPartner) + partnerAugen = augen[callerId]; + else if (p_id == oppIds[0]) + partnerAugen = augen[oppIds[1]]; + else // if (p_id == oppIds[1]) + partnerAugen = augen[oppIds[0]]; + + bool knowsPartner = isPartner || completeGame.Turns[t / 4].AlreadyGsucht; + double reward = rewardSauspiel( + ownAugen, partnerAugen, isCaller || isPartner, knowsPartner); + yield return (p_id, reward); + } + else // Wenz or Solo + { + int p_id = action.PlayerId; + int callerAugen = augen[callerId]; + int opponentAugen = augen.Sum() - callerAugen; + bool isCaller = p_id == callerId; + bool isTout = completeGame.Call.IsTout; + double reward = rewardSoloWenz(callerAugen, opponentAugen, isCaller, isTout); + yield return (p_id, reward); + } + + t++; + } + } + + private static double rewardSoloWenz( + int callerAugen, int opponentAugen, + bool isCaller, bool isTout) + => isTout + ? (isCaller && callerAugen == 120 ? 1 : 0) + : ((isCaller && callerAugen >= 61) || (!isCaller && opponentAugen >= 60) ? 1 : 0); + + private static double rewardSauspiel( + int ownAugen, int partnerAugen, + bool isCaller, bool knowsPartner) + { + int effAugen = knowsPartner ? ownAugen + partnerAugen : ownAugen; + return (isCaller && effAugen >= 61) || (!isCaller && effAugen >= 60) ? 1 : 0; + } }