Skip to content

Commit

Permalink
abstract ppo training with generic state and action
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Dec 8, 2024
1 parent 1ceee28 commit ce5c8c4
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 253 deletions.
6 changes: 5 additions & 1 deletion Schafkopf.Lib/Card.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public enum CardColor
Eichel
}

public readonly struct Card
public readonly struct Card : IEquatable<Card>
{
public const byte EXISTING_FLAG = 0x20;
public const byte TRUMPF_FLAG = 0x40;
Expand Down Expand Up @@ -60,6 +60,9 @@ public Card(CardType type, CardColor color, bool exists, bool isTrumpf)
public override bool Equals([NotNullWhen(true)] object? obj)
=> obj is Card c && (c.Id & ORIG_CARD_MASK) == (this.Id & ORIG_CARD_MASK);

public bool Equals(Card other)
=> Equals((object?)other);

public override int GetHashCode() => Id & ORIG_CARD_MASK;

public static bool operator ==(Card a, Card b)
Expand All @@ -71,5 +74,6 @@ public override bool Equals([NotNullWhen(true)] object? obj)

public override string ToString()
=> $"{Color} {Type}{(IsTrumpf ? " (trumpf)" : "")}";

// TODO: add an emoji format
}
2 changes: 1 addition & 1 deletion Schafkopf.Training.Tests/FeatureVectorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public void Test_CanSerializeCompleteGame()
var call = GameCall.Sauspiel(0, 1, CardColor.Schell);
var history = generateHistoryWithCall(call);

var newExp = () => new SarsExp() { StateBefore = new GameState() };
var newExp = () => new SchafkopfSarsExp() { StateBefore = new GameState() };
var states = Enumerable.Range(0, 32).Select(i => newExp()).ToArray();
serializer.SerializeSarsExps(history, states);

Expand Down
263 changes: 27 additions & 236 deletions Schafkopf.Training/Algos/PPOAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,70 +25,6 @@ public class PPOTrainingSettings
public int NumTrainings => TrainSteps / StepsPerUpdate;
}

public class PPOTrainingSession
{
public PPOModel Train(PPOTrainingSettings config)
{
var model = new PPOModel(config);
var rollout = new PPORolloutBuffer(config);
var exps = new CardPickerExpCollector();
var benchmark = new RandomPlayBenchmark();
var agent = new PPOAgent(model);

for (int ep = 0; ep < config.NumTrainings; ep++)
{
Console.WriteLine($"epoch {ep+1}");
exps.Collect(rollout, model);
model.Train(rollout);

model.RecompileCache(batchSize: 1);
double winRate = benchmark.Benchmark(agent);
model.RecompileCache(batchSize: config.BatchSize);

Console.WriteLine($"win rate vs. random agents: {winRate}");
Console.WriteLine("--------------------------------------");
}

return model;
}
}

public class PPOAgent : ISchafkopfAIAgent
{
public PPOAgent(PPOModel model)
{
this.model = model;
}

private PPOModel model;
private HeuristicAgent heuristicAgent = new HeuristicAgent();
private GameStateSerializer stateSerializer = new GameStateSerializer();
private PossibleCardPicker sampler = new PossibleCardPicker();

private Matrix2D s0 = Matrix2D.Zeros(1, 90);
private Matrix2D piOh = Matrix2D.Zeros(1, 32);
private Matrix2D V = Matrix2D.Zeros(1, 1);

public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
{
var state = stateSerializer.SerializeState(log);
state.ExportFeatures(s0.SliceRowsRaw(0, 1));
model.Predict(s0, piOh, V);
var predDist = piOh.SliceRowsRaw(0, 1);
return sampler.PickCard(possibleCards, predDist);
}

public bool CallKontra(GameLog log) => heuristicAgent.CallKontra(log);
public bool CallRe(GameLog log) => heuristicAgent.CallRe(log);
public bool IsKlopfer(int position, ReadOnlySpan<Card> firstFourCards)
=> heuristicAgent.IsKlopfer(position, firstFourCards);
public GameCall MakeCall(
ReadOnlySpan<GameCall> possibleCalls,
int position, Hand hand, int klopfer)
=> heuristicAgent.MakeCall(possibleCalls, position, hand, klopfer);
public void OnGameFinished(GameLog final) => heuristicAgent.OnGameFinished(final);
}

public class PPOModel
{
public PPOModel(PPOTrainingSettings config)
Expand Down Expand Up @@ -139,7 +75,9 @@ public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV)
Matrix2D.CopyData(predV, outV);
}

public void Train(PPORolloutBuffer memory)
public void Train<TState, TAction>(PPORolloutBuffer<TState, TAction> memory)
where TState : IEquatable<TState>, new()
where TAction : IEquatable<TAction>, new()
{
int numBatches = memory.NumBatches(
config.BatchSize, config.UpdateEpochs);
Expand Down Expand Up @@ -228,175 +166,13 @@ public void RecompileCache(int batchSize)
}
}

public class CardPickerExpCollector
{
public void Collect(PPORolloutBuffer buffer, PPOModel strategy)
{
int numGames = buffer.Steps / 8;
int numSessions = buffer.NumEnvs / 4;
var envs = Enumerable.Range(0, numSessions)
.Select(i => new MultiAgentCardPickerEnv()).ToArray();

var vecAgent = new VectorizedCardPickerAgent(strategy, numSessions);
var agents = Enumerable.Range(0, buffer.NumEnvs)
.Select(i => new AsyncCardPickerAgent(vecAgent)).ToArray();

var expCache = new PPOExp[buffer.NumEnvs];
int t = 0;
var barr = new Barrier(buffer.NumEnvs, (b) => {
buffer.AppendStep(expCache, t++);
Console.Write($"\rcollecting ppo data {t} / {buffer.Steps} ");
});

var collectTasks = Enumerable.Range(0, buffer.NumEnvs)
.Select(i => Task.Run(() => {
var agent = agents[i];
var env = envs[i / 4];
foreach (var exp in agent.PlaySteps(i % 4, env, buffer.Steps))
{
barr.SignalAndWait();
expCache[i] = exp;
}
}))
.ToArray();

Task.WaitAll(collectTasks);
Console.WriteLine();
}
}

public class VectorizedCardPickerAgent
{
public VectorizedCardPickerAgent(PPOModel strategy, int numSessions)
{
states = Matrix2D.Zeros(numSessions, GameState.NUM_FEATURES);
predPi = Matrix2D.Zeros(numSessions, 32);
predV = Matrix2D.Zeros(numSessions, 1);

samplers = Enumerable.Range(0, numSessions)
.Select(i => new PossibleCardPicker()).ToArray();

threadIds = new int[numSessions];
barr = new Barrier(numSessions, (b) => strategy.Predict(states, predPi, predV));
}

private int[] threadIds;
private Barrier barr;

private Matrix2D states;
private Matrix2D predPi;
private Matrix2D predV;

private PossibleCardPicker[] samplers;

private int sessionIdByThread()
{
int threadId = Environment.CurrentManagedThreadId;
for (int i = 0; i < threadIds.Length; i++)
if (threadIds[i] == threadId)
return i;
throw new InvalidOperationException("Unregistered thread!");
}

public void Register(int sessionId)
{
threadIds[sessionId] = Environment.CurrentManagedThreadId;
}

public (Card, double, double) Predict(
GameState state, ReadOnlySpan<Card> possCards)
{
int sessionId = sessionIdByThread();
var s0Slice = states.SliceRowsRaw(sessionId, 1);
state.ExportFeatures(s0Slice);

barr.SignalAndWait();

var predPiDistr = predPi.SliceRowsRaw(sessionId, 1);
var card = samplers[sessionId].PickCard(possCards, predPiDistr);
double pi = predPiDistr[card.Id % 32];
double V = predV.At(sessionId, 0);

return (card, pi, V);
}
}

public class AsyncCardPickerAgent
{
public AsyncCardPickerAgent(VectorizedCardPickerAgent vecAgent)
{
this.vecAgent = vecAgent;
}

private VectorizedCardPickerAgent vecAgent;
private Card[] cardCache = new Card[8];
private GameRules rules = new GameRules();
private GameStateSerializer stateSerializer = new GameStateSerializer();

public IEnumerable<PPOExp> PlaySteps(
int playerId, MultiAgentCardPickerEnv env, int steps)
{
var exp = new PPOExp();
env.Register(playerId);
var state = env.Reset();

for (int i = 0; i < steps; i++)
{
(GameState s0, Card a0, double pi, double V) = predict(state);
(state, double r1, bool t1) = env.Step(a0);
if (t1)
state = env.Reset();

exp.StateBefore = s0;
exp.Action = a0;
exp.Reward = r1;
exp.IsTerminal = t1;
exp.OldProb = pi;
exp.OldBaseline = V;
yield return exp;
}
}

private (GameState, Card, double, double) predict(GameLog state)
{
var possCards = rules.PossibleCards(state, cardCache);
var encState = stateSerializer.SerializeState(state);
(var a0, var pi, var V) = vecAgent.Predict(encState, possCards);
return (encState, a0, pi, V);
}
}

public class PossibleCardPicker
{
private UniformDistribution uniform = new UniformDistribution();

public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
=> possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

private double[] probDistCache = new double[8];
private ReadOnlySpan<double> normProbDist(
ReadOnlySpan<double> probDistAll, ReadOnlySpan<Card> possibleCards)
{
double probSum = 0;
for (int i = 0; i < possibleCards.Length; i++)
probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK];
for (int i = 0; i < possibleCards.Length; i++)
probSum += probDistCache[i];
double scale = 1 / probSum;
for (int i = 0; i < possibleCards.Length; i++)
probDistCache[i] *= scale;

return probDistCache.AsSpan().Slice(0, possibleCards.Length);
}
}

public struct PPOTrainBatch
{
public PPOTrainBatch(int size, int numStateDims)
public PPOTrainBatch(int size, int numStateDims, int numActionDims)
{
Size = size;
StatesBefore = Matrix2D.Zeros(size, numStateDims);
Actions = Matrix2D.Zeros(size, 1);
Actions = Matrix2D.Zeros(size, numActionDims);
Rewards = Matrix2D.Zeros(size, 1);
Terminals = Matrix2D.Zeros(size, 1);
Returns = Matrix2D.Zeros(size, 1);
Expand Down Expand Up @@ -464,10 +240,18 @@ public PPOTrainBatch SliceRows(int rowid, int length)
};
}

public class PPORolloutBuffer
public class PPORolloutBuffer<TState, TAction>
where TState : IEquatable<TState>, new()
where TAction : IEquatable<TAction>, new()
{
public PPORolloutBuffer(PPOTrainingSettings config)
public PPORolloutBuffer(
PPOTrainingSettings config,
Action<TState, Matrix2D> encodeState,
Action<TAction, Matrix2D> encodeAction)
{
this.encodeState = encodeState;
this.encodeAction = encodeAction;

NumEnvs = config.NumEnvs;
Steps = config.StepsPerUpdate;
gamma = config.RewardDiscount;
Expand All @@ -478,13 +262,19 @@ public PPORolloutBuffer(PPOTrainingSettings config)

int size = Steps * NumEnvs;
int sizeWithExtraStep = (Steps + 1) * NumEnvs;
cache = new PPOTrainBatch(sizeWithExtraStep, config.NumStateDims);
cache = new PPOTrainBatch(
sizeWithExtraStep,
config.NumStateDims,
config.NumActionDims
);
cacheWithoutLastStep = cache.SliceRows(0, size);
cacheOnlyFirstStep = cache.SliceRows(0, NumEnvs);
cacheOnlyLastStep = cache.SliceRows(size, NumEnvs);
permCache = Perm.Identity(size);
}

private Action<TState, Matrix2D> encodeState;
private Action<TAction, Matrix2D> encodeAction;
public int NumEnvs;
public int Steps;
private double gamma;
Expand All @@ -500,7 +290,7 @@ public PPORolloutBuffer(PPOTrainingSettings config)
public int NumBatches(int batchSize, int epochs = 1)
=> cacheWithoutLastStep.Size / batchSize * epochs;

public void AppendStep(PPOExp[] exps, int t)
public void AppendStep(PPOExp<TState, TAction>[] exps, int t)
{
if (exps.Length != NumEnvs)
throw new ArgumentException("Invalid amount of experiences!");
Expand All @@ -514,9 +304,10 @@ public void AppendStep(PPOExp[] exps, int t)
var exp = exps[i];
unsafe
{
var s0Dest = buffer.StatesBefore.SliceRowsRaw(i, 1);
exp.StateBefore.ExportFeatures(s0Dest);
buffer.Actions.Data[i] = exp.Action.Id % 32;
var s0Dest = buffer.StatesBefore.SliceRows(i, 1);
encodeState(exp.StateBefore, s0Dest);
var a0Dest = buffer.Actions.SliceRows(i, 1);
encodeAction(exp.Action, a0Dest);
buffer.Rewards.Data[i] = exp.Reward;
buffer.Terminals.Data[i] = exp.IsTerminal ? 1 : 0;
buffer.OldProbs.Data[i] = exp.OldProb;
Expand Down
5 changes: 5 additions & 0 deletions Schafkopf.Training/CardPicker/Experience.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace Schafkopf.Training;

public class SchafkopfSarsExp : SarsExp<GameState, Card> { }

public class SchafkopfPPOExp : PPOExp<GameState, Card> { }
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public static GameState[] NewBuffer()
=> Enumerable.Range(0, 36).Select(x => new GameState()).ToArray();

private GameState[] stateBuffer = NewBuffer();
public void SerializeSarsExps(GameLog completedGame, SarsExp[] exps)
public void SerializeSarsExps(GameLog completedGame, SchafkopfSarsExp[] exps)
{
if (completedGame.CardCount != 32)
throw new ArgumentException("Can only process finished games!");
Expand Down
Loading

0 comments on commit ce5c8c4

Please sign in to comment.