Skip to content

Commit

Permalink
fill ppo rollout buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Nov 29, 2023
1 parent 094b2f9 commit 0b2e226
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 49 deletions.
165 changes: 129 additions & 36 deletions Schafkopf.Training/Algos/MDP.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,60 +14,153 @@ public CardPickerExpCollector(
private PPOModel strategy;
private PossibleCardPicker cardSampler;

/// <summary>
/// Per-turn prediction caches: for each of the 4 player seats one matrix of
/// serialized states (s0), picked actions (a0), full policies (pi),
/// sampled action probabilities (pi sparse) and value baselines (v),
/// each with one row per parallel session.
/// </summary>
private struct TurnBatches
{
    public TurnBatches(int numSessions)
    {
        s0Batches = allocPerSeat(numSessions, 90);
        a0Batches = allocPerSeat(numSessions, 1);
        piBatches = allocPerSeat(numSessions, 32);
        piSparseBatches = allocPerSeat(numSessions, 32);
        vBatches = allocPerSeat(numSessions, 1);
    }

    // one zeroed (numSessions x numCols) matrix per player seat
    private static Matrix2D[] allocPerSeat(int numSessions, int numCols)
    {
        var batches = new Matrix2D[4];
        for (int seat = 0; seat < 4; seat++)
            batches[seat] = Matrix2D.Zeros(numSessions, numCols);
        return batches;
    }

    public Matrix2D[] s0Batches { get; set; }
    public Matrix2D[] a0Batches { get; set; }
    public Matrix2D[] piBatches { get; set; }
    public Matrix2D[] piSparseBatches { get; set; }
    public Matrix2D[] vBatches { get; set; }
}

/// <summary>
/// Fills the given PPO rollout buffer with experiences by playing
/// Schafkopf games in parallel sessions. Each game occupies 4 env slots
/// (one per player) and 8 rollout steps (one per card each player plays).
/// </summary>
/// <param name="buffer">The rollout buffer to be filled.</param>
/// <exception cref="ArgumentException">
/// Thrown when NumEnvs is not divisible by 4 or Steps is not divisible by 8.
/// </exception>
public void Collect(PPORolloutBuffer buffer)
{
    if (buffer.NumEnvs % 4 != 0)
        throw new ArgumentException("The number of envs needs to be "
            + "divisible by 4 because 4 agents are playing the game!");
    if (buffer.Steps % 8 != 0)
        throw new ArgumentException("The number of steps needs to be "
            + "divisible by 8 because each agent plays 8 cards per game!");

    int numGames = buffer.Steps / 8;
    int numSessions = buffer.NumEnvs / 4;
    var envs = Enumerable.Range(0, numSessions)
        .Select(i => new CardPickerEnv()).ToArray();
    var states = envs.Select(env => env.Reset()).ToArray();

    // prediction caches for all 8 turns of a game
    var batchesOfTurns = Enumerable.Range(0, 8)
        .Select(i => new TurnBatches(numSessions)).ToArray();
    // one reward row per turn, one column per (session, player) slot
    var rewards = Matrix2D.Zeros(8, buffer.NumEnvs);

    // play one game more than fits into the buffer; the surplus step
    // data fills the bootstrap row appended after the regular steps
    for (int gameId = 0; gameId < numGames + 1; gameId++)
    {
        playGame(envs, states, batchesOfTurns);
        prepareRewards(states, rewards);
        fillBuffer(gameId, buffer, states, batchesOfTurns, rewards);
        // NOTE(review): states are not explicitly reset between games —
        // confirm CardPickerEnv.Step() starts a new game after turn 8.
    }
}

strategy.Predict(s0Batch, piBatch, vBatch);
/// <summary>
/// Copies the cached per-turn predictions and the evaluated rewards of one
/// finished game into the rollout buffer rows of steps gameId*8 .. gameId*8+7.
/// Returns early once the buffer (incl. the bootstrap row) is full.
/// </summary>
private void fillBuffer(
    int gameId, PPORolloutBuffer buffer, GameLog[] states,
    TurnBatches[] batchesOfTurns, Matrix2D rewards)
{
    for (int t_id = 0; t_id < 8; t_id++)
    {
        // view into the buffer rows of this global rollout step;
        // null once all steps (incl. bootstrap row) are written
        var expBufNull = buffer.SliceStep(gameId * 8 + t_id);
        if (expBufNull == null) return;
        var expBuf = expBufNull.Value;

        var batches = batchesOfTurns[t_id];
        var r1Batch = rewards.SliceRows(t_id, 1);

        for (int envId = 0; envId < states.Length; envId++)
        {
            // map the 4 actions of this turn (in acting order i)
            // back to the acting player's id p_id
            var p_ids = states[envId].UnrollActingPlayers()
                .Skip(t_id * 4).Take(4).Zip(Enumerable.Range(0, 4));
            foreach ((int p_id, int i) in p_ids)
            {
                var s0Batch = batches.s0Batches[i];
                var a0Batch = batches.a0Batches[i];
                var vBatch = batches.vBatches[i];
                var piSparseBatch = batches.piSparseBatches[i];

                // buffer rows are laid out player-major within each session
                int rowid = envId * 4 + p_id;
                Matrix2D.CopyData(
                    s0Batch.SliceRows(envId, 1),
                    expBuf.StatesBefore.SliceRows(rowid, 1));

                unsafe
                {
                    expBuf.Actions.Data[rowid] = a0Batch.Data[envId];
                    // bugfix: read this turn's reward row (r1Batch) at the
                    // player slot offset; rewards.Data[envId] addressed
                    // row 0 of the 8 x NumEnvs reward matrix instead
                    expBuf.Rewards.Data[rowid] = r1Batch.Data[rowid];
                    expBuf.Terminals.Data[rowid] = t_id == 7 ? 1 : 0;
                    expBuf.OldProbs.Data[rowid] = piSparseBatch.Data[envId];
                    expBuf.OldBaselines.Data[rowid] = vBatch.Data[envId];
                }
            }
        }
    }
}

for (int envId = 0; envId < numGames; envId++)
/// <summary>
/// Evaluates the finished games and writes each (turn, player) reward into
/// the 8 x NumEnvs reward matrix (row = turn, column = envId * 4 + playerId).
/// </summary>
private void prepareRewards(GameLog[] states, Matrix2D rewards)
{
    for (int envId = 0; envId < states.Length; envId++)
    {
        var finalState = states[envId];
        foreach ((int t_id, var p_id, var reward) in finalState.UnrollRewards())
        {
            // row-major offset into the 8 x (states.Length * 4) matrix
            int rowid = states.Length * 4 * t_id + envId * 4 + p_id;
            unsafe { rewards.Data[rowid] = reward; }
        }
    }
}

// scratch buffer for the possible-cards lookup (8 hand cards max)
private Card[] cardsCache = new Card[8];

/// <summary>
/// Plays one complete game (8 turns of 4 cards) in all sessions at once,
/// caching the model predictions of each turn/seat in batchesOfTurns so
/// they can later be written into the rollout buffer.
/// </summary>
private void playGame(CardPickerEnv[] envs, GameLog[] states, TurnBatches[] batchesOfTurns)
{
    for (int t_id = 0; t_id < 8; t_id++)
    {
        var batches = batchesOfTurns[t_id];

        // 4 cards per turn, one prediction batch per acting seat
        for (int i = 0; i < 4; i++)
        {
            var s0Batch = batches.s0Batches[i];
            var a0Batch = batches.a0Batches[i];
            var piBatch = batches.piBatches[i];
            var vBatch = batches.vBatches[i];
            var piSparseBatch = batches.piSparseBatches[i];

            // serialize the current state of each session (90 features/row)
            for (int envId = 0; envId < states.Length; envId++)
            {
                var s0 = stateSerializer.SerializeState(states[envId]);
                unsafe { s0.ExportFeatures(s0Batch.Data + envId * 90); }
            }

            strategy.Predict(s0Batch, piBatch, vBatch);

            // sample a legal card from each session's policy and cache
            // the chosen action and its probability
            var actions = a0Batch.SliceRowsRaw(0, envs.Length);
            var selProbs = piSparseBatch.SliceRowsRaw(0, envs.Length);
            for (int envId = 0; envId < envs.Length; envId++)
            {
                var piSlice = piBatch.SliceRowsRaw(envId, 1);
                var possCards = rules.PossibleCards(states[envId], cardsCache);
                var card = cardSampler.PickCard(possCards, piSlice);
                int action = card.Id % 32;
                actions[envId] = action;
                selProbs[envId] = piSlice[action];
            }

            for (int envId = 0; envId < envs.Length; envId++)
            {
                // info: rewards and terminals are determined after the
                // game is over (see prepareRewards), so the step results
                // are intentionally discarded here
                (var newState, double reward, bool isTerminal) =
                    envs[envId].Step(new Card((byte)actions[envId]));
                states[envId] = newState;
            }
        }
    }
}
Expand Down
12 changes: 2 additions & 10 deletions Schafkopf.Training/Algos/PPOAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -389,19 +389,11 @@ public PPORolloutBuffer(PPOTrainingSettings config)

// true whenever t completed a full rollout of `Steps` steps (t = Steps, 2*Steps, ...)
public bool IsReadyForModelUpdate(int t) => t > 0 && t % Steps == 0;

/// <summary>
/// Returns a writable view into the cache rows of rollout step t
/// (NumEnvs rows per step). t == Steps addresses the extra bootstrap row
/// appended after the regular steps; t &gt; Steps yields null because the
/// buffer is already completely filled.
/// </summary>
public PPOTrainBatch? SliceStep(int t)
{
    int offset = IsReadyForModelUpdate(t)
        ? Steps * NumEnvs : (t % Steps) * NumEnvs;

    return (t > Steps) ? null : cache.SliceRows(offset, NumEnvs);
}

public IEnumerable<PPOTrainBatch> SampleDataset(int batchSize, int epochs = 1)
Expand Down
6 changes: 3 additions & 3 deletions Schafkopf.Training/GameState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ public static IEnumerable<int[]> UnrollAugen(this GameLog log)

public static class GameReward
{
public static IEnumerable<(int, double)> UnrollRewards(this GameLog completeGame)
public static IEnumerable<(int, int, double)> UnrollRewards(this GameLog completeGame)
{
int callerId = completeGame.Call.CallingPlayerId;
int partnerId = completeGame.Call.PartnerPlayerId;
Expand Down Expand Up @@ -347,7 +347,7 @@ public static class GameReward
bool knowsPartner = isPartner || completeGame.Turns[t / 4].AlreadyGsucht;
double reward = rewardSauspiel(
ownAugen, partnerAugen, isCaller || isPartner, knowsPartner);
yield return (p_id, reward);
yield return (t / 4, p_id, reward);
}
else // Wenz or Solo
{
Expand All @@ -357,7 +357,7 @@ public static class GameReward
bool isCaller = p_id == callerId;
bool isTout = completeGame.Call.IsTout;
double reward = rewardSoloWenz(callerAugen, opponentAugen, isCaller, isTout);
yield return (p_id, reward);
yield return (t / 4, p_id, reward);
}

t++;
Expand Down

0 comments on commit 0b2e226

Please sign in to comment.