Skip to content

Commit

Permalink
unroll rewards from game history
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Nov 28, 2023
1 parent d71e7b9 commit 094b2f9
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 83 deletions.
47 changes: 0 additions & 47 deletions Schafkopf.Lib/GameLog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,53 +151,6 @@ public Turn NextCard(Card card)
else
return Turns[t_id];
}

/// <summary>
/// Lazily replays the logged turns as a flat sequence of card plays.
/// </summary>
/// <remarks>
/// For every turn, enumeration starts at the turn's first drawing player and
/// advances the seat by one (modulo 4) per card, emitting as many actions as
/// the turn reports via <c>CardsCount</c>. The card for a seat is read from
/// the buffer filled by <c>CopyCards</c>, indexed by that seat.
/// </remarks>
/// <returns>One <see cref="GameAction"/> per card found in the turns.</returns>
public IEnumerable<GameAction> UnrollActions()
{
    // scratch buffers, allocated once and reused for every turn
    var cardsOfTurn = new Card[4];
    var current = new GameAction();

    foreach (var turn in Turns)
    {
        int seat = turn.FirstDrawingPlayerId;
        turn.CopyCards(cardsOfTurn);

        int played = 0;
        while (played < turn.CardsCount)
        {
            current.PlayerId = (byte)seat;
            current.CardPlayed = cardsOfTurn[seat];
            yield return current;

            // hand over to the next seat at the table
            seat = (seat + 1) % 4;
            played++;
        }
    }
}

/// <summary>
/// Lazily replays the hand of the acting player for every card played so far.
/// </summary>
/// <remarks>
/// Each yielded hand is the acting player's hand right before the card of the
/// corresponding action is discarded from it. At most <c>CardCount</c> actions
/// are replayed. When <c>CardCount</c> equals 32, one additional
/// <c>Hand.EMPTY</c> is yielded at the end.
/// </remarks>
public IEnumerable<Hand> UnrollHands()
{
    var remaining = InitialHands.ToArray();
    int replayed = 0;

    foreach (var move in UnrollActions())
    {
        if (replayed >= CardCount)
            break;
        replayed++;

        int owner = move.PlayerId;
        yield return remaining[owner];
        remaining[owner] = remaining[owner].Discard(move.CardPlayed);
    }

    if (CardCount == 32)
        yield return Hand.EMPTY;
}

/// <summary>
/// Lazily replays the running score (Augen) per player index.
/// </summary>
/// <remarks>
/// Yields the totals before each turn and once more after the last turn, so
/// the sequence has one element more than there are turns.
/// The very same array instance is yielded every time and is mutated between
/// steps; read it immediately or copy it if a snapshot is needed.
/// </remarks>
public IEnumerable<int[]> UnrollAugen()
{
    var scores = new int[4];

    foreach (var turn in Turns)
    {
        yield return scores;

        // credit the turn's points to whoever won it
        int winner = turn.WinnerId;
        scores[winner] += turn.Augen;
    }

    yield return scores;
}
}

public struct GameAction : IEquatable<GameAction>
Expand Down
5 changes: 4 additions & 1 deletion Schafkopf.Lib/GameResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ public class GameScoreEvaluation
{
public GameScoreEvaluation(GameLog log)
{
var augen = log.UnrollAugen().Last();
var augen = new int[4];
foreach (var turn in log.Turns)
augen[turn.WinnerId] += turn.Augen;

ScoreCaller = log.CallerIds.ToArray().Select(id => augen[id]).Sum();
ScoreOpponents = 120 - ScoreCaller;

Expand Down
3 changes: 1 addition & 2 deletions Schafkopf.Training/Dataset.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ private static IEnumerable<SarsExp> generateExperiences(
if (log.Call.Mode == GameMode.Weiter)
continue;

serializer.SerializeSarsExps(
log, expBuffer, GameReward.Reward);
serializer.SerializeSarsExps(log, expBuffer);

for (int i = 0; i < 32; i++)
if (numExamples == null || p++ < numExamples)
Expand Down
165 changes: 132 additions & 33 deletions Schafkopf.Training/GameState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ public static GameState[] NewBuffer()
=> Enumerable.Range(0, 36).Select(x => new GameState()).ToArray();

private GameState[] stateBuffer = NewBuffer();
public void SerializeSarsExps(
GameLog completedGame, SarsExp[] exps,
Func<GameLog, int, double> reward)
public void SerializeSarsExps(GameLog completedGame, SarsExp[] exps)
{
if (completedGame.CardCount != 32)
throw new ArgumentException("Can only process finished games!");
serializeHistory(completedGame, stateBuffer);

var actions = completedGame.UnrollActions().GetEnumerator();
var rewards = completedGame.UnrollRewards().GetEnumerator();
for (int t0 = 0; t0 < 32; t0++)
{
rewards.MoveNext();
actions.MoveNext();
var card = actions.Current.CardPlayed;
int p_id = actions.Current.PlayerId;
Expand All @@ -58,7 +58,7 @@ public void SerializeSarsExps(
exps[t0].StateBefore.LoadFeatures(stateBuffer[t0].State);
exps[t0].StateAfter.LoadFeatures(stateBuffer[t1].State);
exps[t0].IsTerminal = isTerminal;
exps[t0].Reward = reward(completedGame, t1);
exps[t0].Reward = rewards.Current.Item2;
}
}

Expand Down Expand Up @@ -246,37 +246,136 @@ public static unsafe int EncodeOnehot(double* stateArr, bool flag)
}
}

public class GameReward
public static class GameLogEx
{
public static double Reward(GameLog log, int t)
public static IEnumerable<GameAction> UnrollActions(this GameLog log)
{
// intention of this reward system:
// - players receive reward 1 as soon as they are in a winning state
// - if they are in a losing or undetermined state, they receive reward 0

// info: t >= 32 relate to the final game outcome
// from the view of the player with p_id = t%4
int playerId = t >= 32 ? t % 4 :
normPlayerId(log.Turns[t / 4].FirstDrawingPlayerId, t % 4);

// info: players don't know yet who the sauspiel partner is
// -> no reward, even if it's already won
bool alreadyGsucht = t < 28 ? log.Turns[t / 4].AlreadyGsucht : true;
if (log.Call.Mode == GameMode.Sauspiel && !alreadyGsucht)
return 0;

bool isCaller = log.CallerIds.Contains(playerId);
var augen = log.UnrollAugen().Last();
double callerScore = 0;
for (int i = 0; i < log.CallerIds.Length; i++)
callerScore += augen[log.CallerIds[i]];

if (log.Call.Mode != GameMode.Sauspiel && log.Call.IsTout)
return isCaller && callerScore == 120 ? 1 : 0;
else
return (isCaller && callerScore >= 61) || !isCaller ? 1 : 0;
var turnCache = new Card[4];
var action = new GameAction();

foreach (var turn in log.Turns)
{
int p_id = turn.FirstDrawingPlayerId;
turn.CopyCards(turnCache);

for (int i = 0; i < turn.CardsCount; i++)
{
var card = turnCache[p_id];
action.PlayerId = (byte)p_id;
action.CardPlayed = card;
yield return action;
p_id = (p_id + 1) % 4;
}
}
}

private static int normPlayerId(int id, int offset)
=> (id - offset + 4) & 0x03;
/// <summary>
/// Lazily enumerates the id of the acting player for every card in the log.
/// </summary>
/// <remarks>
/// For each turn, starts at the turn's first drawing player and advances the
/// seat by one (modulo 4) per card, for as many cards as the turn reports.
/// </remarks>
/// <param name="log">The game log whose turns are replayed.</param>
public static IEnumerable<int> UnrollActingPlayers(this GameLog log)
{
    foreach (var turn in log.Turns)
    {
        int seat = turn.FirstDrawingPlayerId;
        int played = 0;

        while (played < turn.CardsCount)
        {
            yield return seat;
            seat = (seat + 1) % 4;
            played++;
        }
    }
}

/// <summary>
/// Lazily replays the hand of the acting player for every card played so far.
/// </summary>
/// <remarks>
/// Each yielded hand is the acting player's hand right before the card of the
/// corresponding action is discarded from it. At most <c>log.CardCount</c>
/// actions are replayed. When <c>log.CardCount</c> equals 32, one additional
/// <c>Hand.EMPTY</c> is yielded at the end.
/// </remarks>
/// <param name="log">The game log to replay.</param>
public static IEnumerable<Hand> UnrollHands(this GameLog log)
{
    var remaining = log.InitialHands.ToArray();
    int replayed = 0;

    foreach (var move in log.UnrollActions())
    {
        if (replayed >= log.CardCount)
            break;
        replayed++;

        int owner = move.PlayerId;
        yield return remaining[owner];
        remaining[owner] = remaining[owner].Discard(move.CardPlayed);
    }

    if (log.CardCount == 32)
        yield return Hand.EMPTY;
}

/// <summary>
/// Lazily replays the running score (Augen) per player index.
/// </summary>
/// <remarks>
/// Yields the totals before each turn and once more after the last turn, so
/// the sequence has one element more than there are turns.
/// The very same array instance is yielded every time and is mutated between
/// steps; read it immediately or copy it if a snapshot is needed.
/// </remarks>
/// <param name="log">The game log whose turns are replayed.</param>
public static IEnumerable<int[]> UnrollAugen(this GameLog log)
{
    var scores = new int[4];

    foreach (var turn in log.Turns)
    {
        yield return scores;

        // credit the turn's points to whoever won it
        int winner = turn.WinnerId;
        scores[winner] += turn.Augen;
    }

    yield return scores;
}
}

/// <summary>
/// Derives a binary per-action reward signal from a game log.
/// </summary>
public static class GameReward
{
    /// <summary>
    /// Lazily yields one <c>(playerId, reward)</c> pair per action of the game,
    /// in the order produced by <c>UnrollActions</c>.
    /// </summary>
    /// <remarks>
    /// The reward of an action is computed from the running scores including
    /// the turn the action belongs to: <c>UnrollAugen</c> yields the totals
    /// before each turn and once after the last one, so the initial all-zero
    /// entry is skipped and the score iterator is advanced at the first
    /// action of every turn.
    /// In a Sauspiel the acting player's own score is combined with the
    /// teammate's score only if the player is the partner or the turn is
    /// flagged <c>AlreadyGsucht</c>; otherwise only the own score counts.
    /// In any other mode the caller's score is compared against the summed
    /// score of everyone else.
    /// </remarks>
    /// <param name="completeGame">The game log to evaluate.</param>
    /// <returns>Pairs of acting player id and reward (0 or 1).</returns>
    public static IEnumerable<(int, double)> UnrollRewards(this GameLog completeGame)
    {
        int callerId = completeGame.Call.CallingPlayerId;
        int partnerId = completeGame.Call.PartnerPlayerId;
        var oppIds = completeGame.OpponentIds.ToArray();

        // the call does not change while replaying, so look it up only once
        bool isSauspiel = completeGame.Call.Mode == GameMode.Sauspiel;
        bool isTout = !isSauspiel && completeGame.Call.IsTout;

        // enumerators are disposable; release the score iterator even if the
        // consumer stops enumerating early
        using (var augenIter = completeGame.UnrollAugen().GetEnumerator())
        {
            // skip the all-zero scores that precede the first turn
            augenIter.MoveNext();

            int t = 0;
            foreach (var action in completeGame.UnrollActions())
            {
                // first action of a turn -> advance to the scores after that turn
                if (t % 4 == 0)
                    augenIter.MoveNext();
                var augen = augenIter.Current;

                int p_id = action.PlayerId;
                double reward;

                if (isSauspiel)
                {
                    bool isCaller = p_id == callerId;
                    bool isPartner = p_id == partnerId;
                    int teammateId = teammateOf(p_id, callerId, partnerId, oppIds);

                    bool knowsPartner = isPartner
                        || completeGame.Turns[t / 4].AlreadyGsucht;
                    reward = rewardSauspiel(
                        augen[p_id], augen[teammateId],
                        isCaller || isPartner, knowsPartner);
                }
                else // Wenz or Solo
                {
                    int callerAugen = augen[callerId];
                    int opponentAugen = augen.Sum() - callerAugen;
                    reward = rewardSoloWenz(
                        callerAugen, opponentAugen, p_id == callerId, isTout);
                }

                yield return (p_id, reward);
                t++;
            }
        }
    }

    // Resolves the id of the player teamed up with p_id in a Sauspiel.
    private static int teammateOf(
            int p_id, int callerId, int partnerId, int[] oppIds)
    {
        if (p_id == callerId)
            return partnerId;
        if (p_id == partnerId)
            return callerId;
        if (p_id == oppIds[0])
            return oppIds[1];
        return oppIds[0]; // remaining case: p_id == oppIds[1]
    }

    // Reward for Wenz / Solo: with tout only a caller holding all 120 points
    // is rewarded; otherwise the caller needs at least 61 points and a
    // non-caller is rewarded once the non-caller total reaches 60.
    private static double rewardSoloWenz(
        int callerAugen, int opponentAugen,
        bool isCaller, bool isTout)
        => isTout
            ? (isCaller && callerAugen == 120 ? 1 : 0)
            : ((isCaller && callerAugen >= 61) || (!isCaller && opponentAugen >= 60) ? 1 : 0);

    // Reward for Sauspiel: isCaller is true for both members of the calling
    // team (caller and partner). The teammate's points only count when the
    // player knows who the teammate is. The calling team needs at least 61
    // points, the other team at least 60.
    private static double rewardSauspiel(
        int ownAugen, int partnerAugen,
        bool isCaller, bool knowsPartner)
    {
        int effAugen = knowsPartner ? ownAugen + partnerAugen : ownAugen;
        return (isCaller && effAugen >= 61) || (!isCaller && effAugen >= 60) ? 1 : 0;
    }
}

0 comments on commit 094b2f9

Please sign in to comment.