Skip to content

Commit

Permalink
add first draft for an actor critic agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Nov 20, 2023
1 parent 93cf4f9 commit b69d209
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 6 deletions.
130 changes: 130 additions & 0 deletions Schafkopf.Training/Algos/PPOAgent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@

// TODO: train a policy to predict the likelihood
// of selecting an action in a given state

public class UniformDistribution
{
private static readonly Random rng = new Random();

public static int Sample(ReadOnlySpan<double> probs)
{
double p = rng.NextDouble();
double sum = 0;
for (int i = 0; i < probs.Length - 1; i++)
{
sum += probs[i];
if (p < sum)
return i;
}
return probs.Length - 1;
}
}

public class PPOAgent : ISchafkopfAIAgent
{
private FFModel valueModel = new FFModel(new ILayer[] {
new DenseLayer(64),
new ReLULayer(),
new DenseLayer(64),
new ReLULayer(),
new DenseLayer(1),
new FlattenLayer()
});

private FFModel strategyModel =
new FFModel(
new ILayer[] {
new DenseLayer(64),
new ReLULayer(),
new DenseLayer(64),
new ReLULayer(),
new DenseLayer(1),
new FlattenLayer(),
new SoftmaxLayer()
});

private GameStateSerializer stateSerializer = new GameStateSerializer();
private Matrix2D featureCache = Matrix2D.Zeros(8, 92);
public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
{
var x = featureCache;
var s0 = stateSerializer.SerializeState(log);

int p = 0;
for (int i = 0; i < possibleCards.Length; i++)
{
unsafe
{
var card = possibleCards[i];
x.Data[p++] = GameEncoding.Encode(card.Type);
x.Data[p++] = GameEncoding.Encode(card.Color);
s0.ExportFeatures(x.Data + p);
p += GameState.NUM_FEATURES;
}
}

var probDist = strategyModel.PredictBatch(featureCache);
ReadOnlySpan<double> probDistSlice;
unsafe { probDistSlice = new Span<double>(probDist.Data, possibleCards.Length); }
int id = UniformDistribution.Sample(probDistSlice);
return possibleCards[id];
}

public void OnGameFinished(GameLog final)
{
throw new NotImplementedException();
}

#region Misc

public bool CallKontra(GameLog log) => false;

public bool CallRe(GameLog log) => false;

public bool IsKlopfer(int position, ReadOnlySpan<Card> firstFourCards) => false;

private HeuristicGameCaller caller =
new HeuristicGameCaller(new GameMode[] { GameMode.Sauspiel });
public GameCall MakeCall(
ReadOnlySpan<GameCall> possibleCalls,
int position, Hand hand, int klopfer)
=> caller.MakeCall(possibleCalls, position, hand, klopfer);

#endregion Misc
}

public class PPOTrainingSettings
{
public int NumObsFeatures { get; set; }
public int TotalSteps = 10_000_000;
public double LearnRate = 3e-4;
public double RewardDiscount = 0.99;
public double GAEDiscount = 0.95;
public double ProbClip = 0.2;
public double ValueClip = 0.2;
public double VFCoef = 0.5;
public double EntCoef = 0.01;
public bool NormAdvantages = true;
public bool ClipValues = true;
public int BatchSize = 64;
public int NumEnvs = 32;
public int StepsPerUpdate = 512;
public int UpdateEpochs = 4;
public int NumModelSnapshots = 20;

public int TrainSteps => TotalSteps / NumEnvs;
public int ModelSnapshotInterval => TrainSteps / NumModelSnapshots;
}

public class PPOTrainingSession
{
public void Train()
{

}
}

public class PPORolloutBuffer
{
//
}
18 changes: 12 additions & 6 deletions Schafkopf.Training/GameState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public void SerializeSarsExps(
GameLog completedGame, SarsExp[] exps,
Func<GameLog, int, double> reward)
{
if (completedGame.CardCount != 32)
throw new ArgumentException("Can only process finished games!");
serializeHistory(completedGame, stateBuffer);

var actions = completedGame.UnrollActions().GetEnumerator();
Expand All @@ -61,16 +63,18 @@ public void SerializeSarsExps(
}
}

public GameState SerializeState(GameLog liveGame)
{
serializeHistory(liveGame, stateBuffer);
return stateBuffer[liveGame.CardCount - 1];
}

private int playerPosOfTurn(GameLog log, int t_id, int p_id)
=> t_id == 8 ? p_id : normPlayerId(p_id, log.Turns[t_id].FirstDrawingPlayerId);

private void serializeHistory(GameLog completedGame, GameState[] statesCache)
{
if (completedGame.CardCount != 32)
throw new ArgumentException("Can only process finished games!");
if (statesCache.Length < 36)
throw new ArgumentException("");

int timesteps = completedGame.CardCount;
var origCall = completedGame.Call;
var hands = completedGame.UnrollHands().GetEnumerator();
var scores = completedGame.UnrollAugen().GetEnumerator();
Expand All @@ -81,7 +85,7 @@ private void serializeHistory(GameLog completedGame, GameState[] statesCache)
};

int t = 0;
foreach (var turn in completedGame.Turns)
for (int t_id = 0; t_id < Math.Ceiling((double)timesteps / 4); t_id++)
{
scores.MoveNext();

Expand All @@ -93,6 +97,8 @@ private void serializeHistory(GameLog completedGame, GameState[] statesCache)
var score = scores.Current;
var state = statesCache[t].State;
serializeState(state, normCalls, hand, t++, allActions, score);

if (t == timesteps) return;
}
}

Expand Down
59 changes: 59 additions & 0 deletions Schafkopf.Training/NeuralNet/Layers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,62 @@ public void ApplyGrads()
// info: layer has no trainable params
}
}

public class FlattenLayer : ILayer
{
public FlattenLayer(int axis = 1)
=> this.axis = axis;

private int axis;

public LayerCache Cache { get; private set; }

public int InputDims { get; private set; }

public int OutputDims { get; private set; }

public void Compile(int inputDims)
{
InputDims = inputDims;
}

public void CompileCache(Matrix2D inputs, Matrix2D deltasOut)
{
int flatDims = inputs.NumRows * inputs.NumCols;
OutputDims = axis == 0 ? flatDims : 1;
int batchSize = axis == 0 ? 1 : flatDims;

Cache = new LayerCache() {
Input = inputs,
Output = Matrix2D.Zeros(batchSize, OutputDims),
DeltasIn = Matrix2D.Zeros(batchSize, OutputDims),
DeltasOut = deltasOut,
Gradients = Matrix2D.Null()
};
}

public void Forward()
{
unsafe
{
int dataLen = Cache.Input.NumRows * Cache.Input.NumCols;
for (int i = 0; i < dataLen; i++)
Cache.Output.Data[i] = Cache.Input.Data[i];
}
}

public void Backward()
{
unsafe
{
int dataLen = Cache.DeltasIn.NumRows * Cache.DeltasIn.NumCols;
for (int i = 0; i < dataLen; i++)
Cache.DeltasOut.Data[i] = Cache.DeltasIn.Data[i];
}
}

public void ApplyGrads()
{
// info: layer has no trainable params
}
}
24 changes: 24 additions & 0 deletions Schafkopf.Training/RandomAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ public GameCall MakeCall(
ReadOnlySpan<GameCall> possibleCalls,
int position, Hand hand, int klopfer)
{
if (allowedModes.Contains(GameMode.Solo))
{
var call = canCallSolo(possibleCalls, position, hand, klopfer);
if (call.Mode == GameMode.Solo)
return call;
}

if (allowedModes.Contains(GameMode.Wenz))
{
var call = canCallWenz(possibleCalls, position, hand, klopfer);
if (call.Mode == GameMode.Wenz)
return call;
}

if (allowedModes.Contains(GameMode.Sauspiel))
{
var call = canCallSauspiel(possibleCalls, hand);
Expand Down Expand Up @@ -57,6 +71,16 @@ private GameCall canCallSauspiel(

return sauspielCalls.OrderBy(x => hand.FarbeCount(x.GsuchteFarbe)).First();
}

private GameCall canCallSolo(
ReadOnlySpan<GameCall> possibleCalls,
int position, Hand hand, int klopfer)
=> GameCall.Weiter(); // TODO: implement logic for solo decision

private GameCall canCallWenz(
ReadOnlySpan<GameCall> possibleCalls,
int position, Hand hand, int klopfer)
=> GameCall.Weiter(); // TODO: implement logic for wenz decision
}

public class RandomAgent : ISchafkopfAIAgent
Expand Down

0 comments on commit b69d209

Please sign in to comment.