set up ppo training session (wip)
Bonifatius94 committed Nov 30, 2023
1 parent c24f281 commit f84ee37
Showing 5 changed files with 140 additions and 97 deletions.
40 changes: 0 additions & 40 deletions Schafkopf.Training/Algos/MDP.cs
@@ -162,46 +162,6 @@ public TurnBatches(int numSessions)
public Matrix2D[] piSparseBatches { get; set; }
public Matrix2D[] vBatches { get; set; }
}

private class PossibleCardPicker
{
private UniformDistribution uniform = new UniformDistribution();

public Card PickCard(
ReadOnlySpan<Card> possibleCards,
ReadOnlySpan<double> predPi,
Card sampledCard)
=> canPlaySampledCard(possibleCards, sampledCard) ? sampledCard
: possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
=> possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

private bool canPlaySampledCard(
ReadOnlySpan<Card> possibleCards, Card sampledCard)
{
foreach (var card in possibleCards)
if (card == sampledCard)
return true;
return false;
}

private double[] probDistCache = new double[8];
private ReadOnlySpan<double> normProbDist(
ReadOnlySpan<double> probDistAll, ReadOnlySpan<Card> possibleCards)
{
double probSum = 0;
for (int i = 0; i < possibleCards.Length; i++)
probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK];
for (int i = 0; i < possibleCards.Length; i++)
probSum += probDistCache[i];
double scale = 1 / probSum;
for (int i = 0; i < possibleCards.Length; i++)
probDistCache[i] *= scale;

return probDistCache.AsSpan().Slice(0, possibleCards.Length);
}
}
}

public class CardPickerEnv
145 changes: 122 additions & 23 deletions Schafkopf.Training/Algos/PPOAgent.cs
@@ -1,13 +1,95 @@

namespace Schafkopf.Training;

public class PPOTrainingSettings
{
public int TotalSteps = 10_000_000;
public double LearnRate = 3e-4;
public double RewardDiscount = 0.99;
public double GAEDiscount = 0.95;
public double ProbClip = 0.2;
public double ValueClip = 0.2;
public double VFCoef = 0.5;
public double EntCoef = 0.01;
public bool NormAdvantages = true;
public bool ClipValues = true;
public int BatchSize = 64;
public int NumEnvs = 32;
public int NumStateDims = 90;
public int NumActionDims = 32;
public int StepsPerUpdate = 512;
public int UpdateEpochs = 10;
public int NumModelSnapshots = 20;

public int TrainSteps => TotalSteps / NumEnvs;
public int ModelSnapshotInterval => TrainSteps / NumModelSnapshots;
public int NumTrainings => TrainSteps / StepsPerUpdate;
}
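
For orientation, the derived properties resolve as follows with the defaults above (integer division, so values are truncated):

// TrainSteps            = 10_000_000 / 32 = 312_500   steps per environment
// ModelSnapshotInterval = 312_500 / 20    = 15_625
// NumTrainings          = 312_500 / 512   = 610       rollout/update cycles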

public class PPOTrainingSession
{
public void Train()
public PPOModel Train(PPOTrainingSettings config)
{

var model = new PPOModel(config);
var rollout = new PPORolloutBuffer(config);
var exps = new CardPickerExpCollector(model);
var benchmark = new RandomPlayBenchmark();
var agent = new PPOAgent(model);

for (int ep = 0; ep < config.NumTrainings; ep++)
{
exps.Collect(rollout);
model.Train(rollout);

if ((ep + 1) % 10 == 0)
{
double winRate = benchmark.Benchmark(agent);
Console.WriteLine($"epoch {ep}: win rate vs. random agents is {winRate}");
}
}

return model;
}
}

public class PPOAgent : ISchafkopfAIAgent
{
public PPOAgent(PPOModel model)
{
this.model = model;
}

private PPOModel model;
private HeuristicAgent heuristicAgent = new HeuristicAgent();
private GameStateSerializer stateSerializer = new GameStateSerializer();
private PossibleCardPicker sampler = new PossibleCardPicker();
private UniformDistribution uniform = new UniformDistribution();

private Matrix2D s0 = Matrix2D.Zeros(1, 90);
private Matrix2D piOh = Matrix2D.Zeros(1, 32);
private Matrix2D V = Matrix2D.Zeros(1, 1);

public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
{
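// flow of this method (descriptive comments):
// 1) serialize the current game log into the 90-dim feature row s0
// 2) run the actor-critic forward pass: piOh = pi(s0) over the 32 card ids, V = V(s0)
// 3) sample a card id from the full policy; PickCard keeps it if it is legal,
//    otherwise it resamples from the distribution renormalized over the legal cards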
var state = stateSerializer.SerializeState(log);
state.ExportFeatures(s0.SliceRowsRaw(0, 1));
model.Predict(s0, piOh, V);
var predDist = piOh.SliceRowsRaw(0, 1);
var card = new Card((byte)uniform.Sample(predDist));
return sampler.PickCard(possibleCards, predDist, card);
}

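// only card play is learned; calling, Kontra/Re and Klopfer decisions
// are delegated to the rule-based HeuristicAgent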
public bool CallKontra(GameLog log) => heuristicAgent.CallKontra(log);
public bool CallRe(GameLog log) => heuristicAgent.CallRe(log);
public bool IsKlopfer(int position, ReadOnlySpan<Card> firstFourCards)
=> heuristicAgent.IsKlopfer(position, firstFourCards);
public GameCall MakeCall(
ReadOnlySpan<GameCall> possibleCalls,
int position, Hand hand, int klopfer)
=> heuristicAgent.MakeCall(possibleCalls, position, hand, klopfer);
public void OnGameFinished(GameLog final) => heuristicAgent.OnGameFinished(final);
}

public class PPOModel
{
public PPOModel(PPOTrainingSettings config)
@@ -65,11 +65,13 @@ public void Train(PPORolloutBuffer memory)

private void updateModels(PPOTrainBatch batch)
{
// update strategy pi(s)
var predPi = strategy.PredictBatch(batch.StatesBefore);
var policyDeltas = strategy.Layers.Last().Cache.DeltasIn;
computePolicyDeltas(batch, predPi, policyDeltas);
strategy.FitBatch(policyDeltas, strategyOpt);

// update baseline V(s)
var predV = valueFunc.PredictBatch(batch.StatesBefore);
var valueDeltas = valueFunc.Layers.Last().Cache.DeltasIn;
computeValueDeltas(batch, predV, valueDeltas);
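// (context, not part of this hunk) computePolicyDeltas presumably backpropagates the
// clipped PPO surrogate -min(r*A, clip(r, 1-ProbClip, 1+ProbClip)*A) and
// computeValueDeltas the value loss scaled by VFCoef; both helpers sit below this hunk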
@@ -140,29 +224,44 @@ private IEnumerable<int> onehotIndices(Matrix2D sparseClassIds, int numClasses)
}
}

public class PPOTrainingSettings
public class PossibleCardPicker
{
public int NumObsFeatures { get; set; }
public int TotalSteps = 10_000_000;
public double LearnRate = 3e-4;
public double RewardDiscount = 0.99;
public double GAEDiscount = 0.95;
public double ProbClip = 0.2;
public double ValueClip = 0.2;
public double VFCoef = 0.5;
public double EntCoef = 0.01;
public bool NormAdvantages = true;
public bool ClipValues = true;
public int BatchSize = 64;
public int NumEnvs = 32;
public int NumStateDims = 90;
public int NumActionDims = 32;
public int StepsPerUpdate = 512;
public int UpdateEpochs = 4;
public int NumModelSnapshots = 20;
private UniformDistribution uniform = new UniformDistribution();

public int TrainSteps => TotalSteps / NumEnvs;
public int ModelSnapshotInterval => TrainSteps / NumModelSnapshots;
public Card PickCard(
ReadOnlySpan<Card> possibleCards,
ReadOnlySpan<double> predPi,
Card sampledCard)
=> canPlaySampledCard(possibleCards, sampledCard) ? sampledCard
: possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
=> possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

private bool canPlaySampledCard(
ReadOnlySpan<Card> possibleCards, Card sampledCard)
{
foreach (var card in possibleCards)
if (card == sampledCard)
return true;
return false;
}

private double[] probDistCache = new double[8];
private ReadOnlySpan<double> normProbDist(
ReadOnlySpan<double> probDistAll, ReadOnlySpan<Card> possibleCards)
{
double probSum = 0;
for (int i = 0; i < possibleCards.Length; i++)
probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK];
for (int i = 0; i < possibleCards.Length; i++)
probSum += probDistCache[i];
double scale = 1 / probSum;
for (int i = 0; i < possibleCards.Length; i++)
probDistCache[i] *= scale;

return probDistCache.AsSpan().Slice(0, possibleCards.Length);
}
}
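
One caveat in normProbDist above: if the policy assigns (near-)zero probability to every legal card, probSum is 0 and the scaling yields NaN values and undefined sampling behavior. A minimal guarded variant (a sketch only, not part of this commit; the uniform fallback and the epsilon threshold are assumptions):

private ReadOnlySpan<double> normProbDistSafe(
    ReadOnlySpan<double> probDistAll, ReadOnlySpan<Card> possibleCards)
{
    double probSum = 0;
    for (int i = 0; i < possibleCards.Length; i++)
        probSum += probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK];

    if (probSum <= 1e-8)
    {
        // degenerate policy output -> fall back to uniform over the legal cards
        double uniformProb = 1.0 / possibleCards.Length;
        for (int i = 0; i < possibleCards.Length; i++)
            probDistCache[i] = uniformProb;
    }
    else
    {
        double scale = 1 / probSum;
        for (int i = 0; i < possibleCards.Length; i++)
            probDistCache[i] *= scale;
    }

    return probDistCache.AsSpan().Slice(0, possibleCards.Length);
}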

public struct PPOTrainBatch
19 changes: 9 additions & 10 deletions Schafkopf.Training/GameState.cs
@@ -85,7 +85,7 @@ private void serializeHistory(GameLog history, GameState[] statesCache, int skip
};

int t = 0;
for (int t_id = 0; t_id < Math.Ceiling((double)timesteps / 4); t_id++)
for (int t_id = 0; t_id < Math.Min(timesteps / 4 + 1, 8); t_id++)
{
scores.MoveNext();

@@ -95,13 +95,15 @@ private void serializeHistory(GameLog history, GameState[] statesCache, int skip

if (t >= skip)
{
int actingPlayer = t == history.CardCount
? history.DrawingPlayerId : allActions[t].PlayerId;
var hand = hands.Current;
var score = scores.Current;
var state = statesCache[t].State;
serializeState(state, normCalls, hand, t++, allActions, score);
serializeState(state, normCalls, hand, t, actingPlayer, allActions, score);
}

if (t == timesteps) return;
if (t++ == history.CardCount) return; // TODO: check this condition
}
}

@@ -110,12 +112,12 @@ private void serializeHistory(GameLog history, GameState[] statesCache, int skip
for (; t < 36; t++)
serializeState(
statesCache[t].State, normCalls, Hand.EMPTY,
t, allActions, scores.Current);
t, t % 4, allActions, scores.Current);
}

private unsafe void serializeState(
double[] state, ReadOnlySpan<GameCall> normCalls, Hand hand, int t,
ReadOnlySpan<GameAction> turnHistory, int[] augen)
int actingPlayer, ReadOnlySpan<GameAction> turnHistory, int[] augen)
{
if (state.Length < 90)
throw new IndexOutOfRangeException("Memory overflow");
@@ -126,7 +128,7 @@ private unsafe void serializeState(
// - turn history (64 floats)
// - augen (4 floats)

int actingPlayer = t < 32 ? turnHistory[t].PlayerId : t & 0x3;
// int actingPlayer = t < 32 ? (t == 0 ? kommtRaus : turnHistory[t].PlayerId) : t % 4;
var call = normCalls[actingPlayer];

fixed (double* stateArr = &state[0])
@@ -250,17 +252,14 @@ public static class GameLogEx
{
public static IEnumerable<GameAction> UnrollActions(this GameLog log)
{
var turnCache = new Card[4];
var action = new GameAction();

foreach (var turn in log.Turns)
{
int p_id = turn.FirstDrawingPlayerId;
turn.CopyCards(turnCache);

for (int i = 0; i < turn.CardsCount; i++)
foreach (var card in turn.AllCards)
{
var card = turnCache[p_id];
action.PlayerId = (byte)p_id;
action.CardPlayed = card;
yield return action;
24 changes: 3 additions & 21 deletions Schafkopf.Training/Program.cs
@@ -2,28 +2,10 @@

public class Program
{
private static FFModel createModel()
=> new FFModel(new ILayer[] {
new DenseLayer(64),
new ReLULayer(),
new DenseLayer(64),
new ReLULayer(),
new DenseLayer(1),
});

public static void Main(string[] args)
{
var model = createModel();
var dataset = SupervisedSchafkopfDataset.GenerateDataset(
trainSize: 1_000_000, testSize: 10_000);
var optimizer = new AdamOpt(learnRate: 0.002);
var lossFunc = new MeanSquaredError();

var session = new SupervisedTrainingSession(
model, optimizer, lossFunc, dataset);
Console.WriteLine("Training started!");
Console.WriteLine($"loss before: loss={session.Eval()}");
session.Train(5, true, (ep, l) => Console.WriteLine($"loss ep. {ep}: loss={l}"));
Console.WriteLine("Training finished!");
var config = new PPOTrainingSettings();
var session = new PPOTrainingSession();
session.Train(config);
}
}
9 changes: 6 additions & 3 deletions Schafkopf.Training/RandomPlayBenchmark.cs
@@ -2,7 +2,7 @@ namespace Schafkopf.Training

public class RandomPlayBenchmark
{
public void Benchmark(ISchafkopfAIAgent agentToEval, int epochs = 10_000)
public double Benchmark(ISchafkopfAIAgent agentToEval, int epochs = 10_000)
{
var gameCaller = new HeuristicGameCaller(
new GameMode[] { GameMode.Sauspiel, GameMode.Wenz, GameMode.Solo });
@@ -22,13 +22,16 @@ public void Benchmark(ISchafkopfAIAgent agentToEval, int epochs = 10_000)
for (int i = 0; i < epochs; i++)
{
var log = session.ProcessGame();

// info: only evaluate games where cards were played
if (log.Call.Mode == GameMode.Weiter) { i--; continue; }

var eval = new GameScoreEvaluation(log);
bool isCaller = log.CallerIds.Contains(0);
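// player id 0 (the evaluated agent) wins if it is in the calling party and
// the callers win, or in the defending party and the callers lose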
bool isWin = !eval.DidCallerWin ^ isCaller;
wins += isWin ? 1 : 0;
}

double winRate = (double)wins / epochs;
Console.WriteLine($"agent scored a win rate of {winRate}!");
return (double)wins / epochs; // win rate
}
}
