Skip to content

Commit

Permalink
include reward into multi agent env
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Dec 6, 2023
1 parent e43323c commit 6b7da03
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 58 deletions.
3 changes: 2 additions & 1 deletion Schafkopf.Training/Algos/PPOAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ private void prepareRewards(GameLog[] states, Matrix2D rewards)
for (int envId = 0; envId < states.Length; envId++)
{
var finalState = states[envId];
foreach ((int t_id, var p_id, var reward) in finalState.UnrollRewards())
foreach ((int t_id, var p_id, var reward)
in CardPickerReward.UnrollRewards(finalState))
{
int rowid = states.Length * 4 * t_id + envId * 4 + p_id;
unsafe { rewards.Data[rowid] = reward; }
Expand Down
95 changes: 56 additions & 39 deletions Schafkopf.Training/Common/GameState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void SerializeSarsExps(GameLog completedGame, SarsExp[] exps)
serializeHistory(completedGame, stateBuffer);

var actions = completedGame.UnrollActions().GetEnumerator();
var rewards = completedGame.UnrollRewards().GetEnumerator();
var rewards = CardPickerReward.UnrollRewards(completedGame).GetEnumerator();
for (int t0 = 0; t0 < 32; t0++)
{
rewards.MoveNext();
Expand Down Expand Up @@ -309,13 +309,18 @@ public static IEnumerable<int[]> UnrollAugen(this GameLog log)
}
}

public static class GameReward
public static class CardPickerReward
{
public static IEnumerable<(int, int, double)> UnrollRewards(this GameLog completeGame)
/// <summary>
/// Computes the reward of player <paramref name="p_id"/> for the current
/// state of a game, by selecting the trick that matches the number of
/// cards played so far and delegating to the private reward helper.
/// </summary>
/// <param name="liveGame">Game log to evaluate; its CardCount, Turns and Call are read.</param>
/// <param name="p_id">Id of the player the reward is computed for (used as index into the Augen array).</param>
/// <returns>The value returned by the private reward(...) helper for that player.</returns>
public static double Reward(GameLog liveGame, int p_id)
{
// trick index derived from the number of cards in the log; clamped to 7
// so that CardCount / 4 == 8 still maps to the last entry
// (presumably a finished game with 8 tricks of 4 cards — confirm)
int t_id = Math.Min(liveGame.CardCount / 4, 7);
// pick the t_id-th element of the Augen sequence; First() throws if the
// sequence has no element at that position
var augen = liveGame.UnrollAugen().Skip(t_id).First();
var turn = liveGame.Turns[t_id];
return reward(liveGame.Call, p_id, augen, turn.AlreadyGsucht);
}

public static IEnumerable<(int, int, double)> UnrollRewards(GameLog completeGame)
{
int callerId = completeGame.Call.CallingPlayerId;
int partnerId = completeGame.Call.PartnerPlayerId;
var oppIds = completeGame.OpponentIds.ToArray();
var augenIter = completeGame.UnrollAugen().GetEnumerator();
augenIter.MoveNext();

Expand All @@ -325,41 +330,53 @@ public static class GameReward
if (t % 4 == 0)
augenIter.MoveNext();
var augen = augenIter.Current;
int t_id = Math.Min(t / 4, 7);
int p_id = action.PlayerId;
var turn = completeGame.Turns[t_id];
bool alreadyGsucht = turn.AlreadyGsucht;
double r = reward(completeGame.Call, p_id, augen, alreadyGsucht);
yield return (t_id, p_id, r);
t++;
}
}

if (completeGame.Call.Mode == GameMode.Sauspiel)
{
int p_id = action.PlayerId;
bool isCaller = p_id == callerId;
bool isPartner = p_id == partnerId;

int ownAugen = augen[p_id];
int partnerAugen;
if (isCaller)
partnerAugen = augen[partnerId];
else if (isPartner)
partnerAugen = augen[callerId];
else if (p_id == oppIds[0])
partnerAugen = augen[oppIds[1]];
else // if (p_id == oppIds[0])
partnerAugen = augen[oppIds[0]];

bool knowsPartner = isPartner || completeGame.Turns[t / 4].AlreadyGsucht;
double reward = rewardSauspiel(
ownAugen, partnerAugen, isCaller || isPartner, knowsPartner);
yield return (t / 4, p_id, reward);
}
else // Wenz or Solo
{
int p_id = action.PlayerId;
int callerAugen = augen[callerId];
int opponentAugen = augen.Sum() - callerAugen;
bool isCaller = p_id == callerId;
bool isTout = completeGame.Call.IsTout;
double reward = rewardSoloWenz(callerAugen, opponentAugen, isCaller, isTout);
yield return (t / 4, p_id, reward);
}
/// <summary>
/// Computes the reward of a single player from the game call and an
/// Augen array indexed by player id. Sauspiel calls are scored via
/// rewardSauspiel, all other modes via rewardSoloWenz.
/// </summary>
/// <param name="call">Game call; Mode, CallingPlayerId, PartnerPlayerId, OpponentIds and IsTout are read.</param>
/// <param name="p_id">Id of the player the reward is computed for.</param>
/// <param name="augen">Augen per player, indexed by player id.</param>
/// <param name="alreadyGsucht">Flag taken from the evaluated turn; together with isPartner it decides whether the player "knows" the partner.</param>
/// <returns>The value returned by rewardSauspiel or rewardSoloWenz.</returns>
private static double reward(
GameCall call, int p_id,
int[] augen, bool alreadyGsucht)
{
int callerId = call.CallingPlayerId;
int partnerId = call.PartnerPlayerId;
var oppIds = call.OpponentIds;

// NOTE(review): `t` is not declared anywhere in this method; this
// statement matches the loop counter increment of UnrollRewards and
// looks like a leftover of the pre-refactoring loop body — confirm
// and remove.
t++;
if (call.Mode == GameMode.Sauspiel)
{
bool isCaller = p_id == callerId;
bool isPartner = p_id == partnerId;

int ownAugen = augen[p_id];
// Augen of the player's teammate: caller and partner are paired,
// the two opponents are paired with each other
// (assumes oppIds holds exactly two ids — TODO confirm)
int partnerAugen;
if (isCaller)
partnerAugen = augen[partnerId];
else if (isPartner)
partnerAugen = augen[callerId];
else if (p_id == oppIds[0])
partnerAugen = augen[oppIds[1]];
else // if (p_id == oppIds[1])
partnerAugen = augen[oppIds[0]];

// the called partner always counts as knowing the teams; everyone
// else only once alreadyGsucht is set
bool knowsPartner = isPartner || alreadyGsucht;
double reward = rewardSauspiel(
ownAugen, partnerAugen, isCaller || isPartner, knowsPartner);
return reward;
}
else // Wenz or Solo
{
// the caller's Augen against the sum of all other players' Augen
int callerAugen = augen[callerId];
int opponentAugen = augen.Sum() - callerAugen;
bool isCaller = p_id == callerId;
bool isTout = call.IsTout;
double reward = rewardSoloWenz(callerAugen, opponentAugen, isCaller, isTout);
return reward;
}
}

Expand Down
37 changes: 19 additions & 18 deletions Schafkopf.Training/Common/MDP.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@ public MultiAgentCardPickerEnv()
{
env = new CardPickerEnv();
threadIds = new int[4];
termBarr = new Barrier(4);
restBarr = new Barrier(4, (b) => { state = env.Reset(); });
gameFinishedBarr = new Barrier(4);
resetBarr = new Barrier(4, (b) => { state = env.Reset(); });
stateModMut = new Mutex();
state = env.Reset();
}

private CardPickerEnv env;
private GameLog state;
private int[] threadIds;
private Barrier termBarr;
private Barrier restBarr;
private Barrier gameFinishedBarr;
private Barrier resetBarr;
private Mutex stateModMut;

private int playerIdByThread()
Expand All @@ -104,45 +104,46 @@ public void Register(int playerId)
/// <summary>
/// Blocks the calling thread until it is its player's turn and then
/// returns the shared game state. The method itself does not call
/// env.Reset(); it only waits on the state that is already present.
/// </summary>
/// <returns>The current value of the state field.</returns>
public GameLog Reset()
{
// player id associated with the calling thread
int playerId = playerIdByThread();
// NOTE(review): two loop headers in a row. The first reads
// state.DrawingPlayerId directly, the second goes through
// isPlayersTurn(). They look like the before/after variants of the
// same wait — confirm which one is meant to stay.
while (state.DrawingPlayerId != playerId)
while (!isPlayersTurn(playerId))
// poll in 1 ms steps until the turn condition changes
Thread.Sleep(1);
return state;
}

/// <summary>
/// Plays <paramref name="cardToPlay"/> on the shared environment and then
/// blocks until the calling agent may continue: on the terminal step it
/// waits on the barriers, otherwise it waits for its next turn.
/// </summary>
/// <param name="cardToPlay">Card that is passed on to env.Step.</param>
/// <returns>
/// A tuple (state, reward, isTerminal). On the terminal step the state is
/// the copy taken before the reset barrier.
/// </returns>
/// <remarks>
/// NOTE(review): this listing carries two variants of several statements
/// (termBarr / gameFinishedBarr, restBarr / resetBarr, the 0.0 returns
/// followed by the reward returns, checkId / isPlayersTurn). As written,
/// everything after the first return of each branch is unreachable —
/// confirm which variant is the intended one.
/// </remarks>
public (GameLog, double, bool) Step(Card cardToPlay)
{
// NOTE(review): t_id is not read again in this method
int t_id = state.CardCount / 4;
// evaluated BEFORE the card is played: true once at least 28 cards are
// in the log (presumably the last trick of a 32-card game — confirm)
bool isTermial = state.CardCount >= 28;
int playerId = playerIdByThread();

// the env step and the replacement of the shared state happen under the mutex;
// the second and third tuple element returned by env.Step are discarded
stateModMut.WaitOne();
(state, var _, var __) = env.Step(cardToPlay);
stateModMut.ReleaseMutex();

if (isTermial)
{
// info: wait until game is in final state
termBarr.SignalAndWait();
// wait for last turn to finish, cache final state
gameFinishedBarr.SignalAndWait();
// keep a local reference: the reset barrier is constructed with a
// post-phase action that overwrites the state field with env.Reset()
var finalState = state;
restBarr.SignalAndWait();

// TODO: include reward computation here ...
return (finalState, 0.0, true);
// start a new game after all agents cached the final state
resetBarr.SignalAndWait();

// reward is computed on the cached final state, not on the freshly reset one
double reward = CardPickerReward.Reward(finalState, playerId);
return (finalState, reward, true);
}
else
{
// info: wait until it's the player's turn again
// NOTE(review): playerId is already declared at the top of the method;
// this second declaration looks like the pre-change variant — confirm
int playerId = playerIdByThread();
while (checkId(playerId))
while (!isPlayersTurn(playerId))
Thread.Sleep(1);

// TODO: include reward computation here ...
return (state, 0.0, false);
double reward = CardPickerReward.Reward(state, playerId);
return (state, reward, false);
}
}

private bool checkId(int playerId)
private bool isPlayersTurn(int playerId)
{
stateModMut.WaitOne();
bool ret = state.DrawingPlayerId != playerId;
bool ret = state.DrawingPlayerId == playerId;
stateModMut.ReleaseMutex();
return ret;
}
Expand Down

0 comments on commit 6b7da03

Please sign in to comment.