Skip to content

Commit

Permalink
restructure code
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Nov 30, 2023
1 parent 0b2e226 commit c24f281
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 308 deletions.
72 changes: 36 additions & 36 deletions Schafkopf.Training.Tests/ReplayMemoryTests.cs
Original file line number Diff line number Diff line change
@@ -1,47 +1,47 @@
namespace Schafkopf.Training.Tests;
// namespace Schafkopf.Training.Tests;

public class ReplayMemoryTests
{
[Fact]
public void Test_CanFillCacheUntilOverflow()
{
var memory = new ReplayMemory(100);
// public class ReplayMemoryTests
// {
// [Fact]
// public void Test_CanFillCacheUntilOverflow()
// {
// var memory = new ReplayMemory(100);

for (int i = 0; i < 50; i++)
memory.Append(new SarsExp());
// for (int i = 0; i < 50; i++)
// memory.Append(new SarsExp());

Assert.Equal(50, memory.Size);
// Assert.Equal(50, memory.Size);

for (int i = 0; i < 50; i++)
memory.Append(new SarsExp());
// for (int i = 0; i < 50; i++)
// memory.Append(new SarsExp());

Assert.Equal(100, memory.Size);
}
// Assert.Equal(100, memory.Size);
// }

[Fact]
public void Test_CanInsertIntoOverflowingCache()
{
var memory = new ReplayMemory(100);
// [Fact]
// public void Test_CanInsertIntoOverflowingCache()
// {
// var memory = new ReplayMemory(100);

for (int i = 0; i < 200; i++)
memory.Append(new SarsExp());
// for (int i = 0; i < 200; i++)
// memory.Append(new SarsExp());

Assert.Equal(100, memory.Size);
}
// Assert.Equal(100, memory.Size);
// }

[Fact(Skip = "requires the states to be initialized with unique data")]
public void Test_CanReplaceOverflowingDataWithNewData()
{
var memory = new ReplayMemory(100);
var overflowingData = Enumerable.Range(0, 50).Select(x => new SarsExp()).ToArray();
var insertedData = Enumerable.Range(0, 100).Select(x => new SarsExp()).ToArray();
// [Fact(Skip = "requires the states to be initialized with unique data")]
// public void Test_CanReplaceOverflowingDataWithNewData()
// {
// var memory = new ReplayMemory(100);
// var overflowingData = Enumerable.Range(0, 50).Select(x => new SarsExp()).ToArray();
// var insertedData = Enumerable.Range(0, 100).Select(x => new SarsExp()).ToArray();

foreach (var exp in overflowingData)
memory.Append(exp);
foreach (var exp in insertedData)
memory.Append(exp);
// foreach (var exp in overflowingData)
// memory.Append(exp);
// foreach (var exp in insertedData)
// memory.Append(exp);

Assert.True(overflowingData.All(exp => !memory.Contains(exp)));
Assert.True(insertedData.All(exp => memory.Contains(exp)));
}
}
// Assert.True(overflowingData.All(exp => !memory.Contains(exp)));
// Assert.True(insertedData.All(exp => memory.Contains(exp)));
// }
// }
24 changes: 24 additions & 0 deletions Schafkopf.Training/Algos/Distributions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
namespace Schafkopf.Training;

public class UniformDistribution
{
public UniformDistribution(int? seed = null)
=> rng = seed != null ? new Random(seed.Value) : new Random();

private Random rng;

public int Sample(ReadOnlySpan<double> probs)
{
double p = rng.NextDouble();
double sum = 0;
for (int i = 0; i < probs.Length - 1; i++)
{
sum += probs[i];
if (p < sum)
return i;
}
return probs.Length - 1;
}

public int Sample(int numClasses) => rng.Next(0, numClasses);
}
37 changes: 35 additions & 2 deletions Schafkopf.Training/Algos/HeuristicAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,57 @@ namespace Schafkopf.Training;

public class HeuristicAgent : ISchafkopfAIAgent
{
private Random rng = new Random();
private HeuristicGameCaller caller =
new HeuristicGameCaller(new GameMode[] { GameMode.Sauspiel });
private HeuristicCardPicker cardPicker = new HeuristicCardPicker();

public GameCall MakeCall(
ReadOnlySpan<GameCall> possibleCalls,
int position, Hand hand, int klopfer)
=> caller.MakeCall(possibleCalls, position, hand, klopfer);

public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
=> possibleCards[rng.Next(0, possibleCards.Length)];
=> cardPicker.ChooseCard(log, possibleCards);

public bool IsKlopfer(int position, ReadOnlySpan<Card> firstFourCards) => false;
public bool CallKontra(GameLog log) => false;
public bool CallRe(GameLog log) => false;
public void OnGameFinished(GameLog final) => throw new NotImplementedException();
}

public class HeuristicCardPicker
{
private Random rng = new Random();

public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
{
if (log.Call.Mode == GameMode.Solo)
return chooseCardForSolo(log, possibleCards);
else if (log.Call.Mode == GameMode.Wenz)
return chooseCardForWenz(log, possibleCards);
else // if (log.Call.Mode == GameMode.Sauspiel)
return chooseCardForSauspiel(log, possibleCards);
}

private Card chooseCardForSolo(GameLog log, ReadOnlySpan<Card> possibleCards)
{
// TODO: implement heuristic
return possibleCards[rng.Next(0, possibleCards.Length)];
}

private Card chooseCardForWenz(GameLog log, ReadOnlySpan<Card> possibleCards)
{
// TODO: implement heuristic
return possibleCards[rng.Next(0, possibleCards.Length)];
}

private Card chooseCardForSauspiel(GameLog log, ReadOnlySpan<Card> possibleCards)
{
// TODO: implement heuristic
return possibleCards[rng.Next(0, possibleCards.Length)];
}
}

public class HeuristicGameCaller
{
public HeuristicGameCaller(IEnumerable<GameMode> modes)
Expand Down
102 changes: 70 additions & 32 deletions Schafkopf.Training/Algos/MDP.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,24 @@ namespace Schafkopf.Training;

public class CardPickerExpCollector
{
public CardPickerExpCollector(
PPOModel strategy, PossibleCardPicker cardSampler)
public CardPickerExpCollector(PPOModel strategy)
{
this.strategy = strategy;
this.cardSampler = cardSampler;
}

private GameRules rules = new GameRules();
private GameStateSerializer stateSerializer = new GameStateSerializer();
private PPOModel strategy;
private PossibleCardPicker cardSampler;

private struct TurnBatches
{
public TurnBatches(int numSessions)
{
s0Batches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 90)).ToArray();
a0Batches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
piBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
piSparseBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
vBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
}

public Matrix2D[] s0Batches { get; set; }
public Matrix2D[] a0Batches { get; set; }
public Matrix2D[] piBatches { get; set; }
public Matrix2D[] piSparseBatches { get; set; }
public Matrix2D[] vBatches { get; set; }
}
private PossibleCardPicker cardSampler = new PossibleCardPicker();

public void Collect(PPORolloutBuffer buffer)
{
if (buffer.NumEnvs % 4 != 0)
throw new ArgumentException("The number of envs needs to be "
+ "divisible by 4 because 4 agents are playing the game!");
if (buffer.Steps % 8 != 0)
throw new ArgumentException("The number of steps needs to be "
+ "divisible by 8 because each agent plays 8 cards per game!");
// if (buffer.Steps % 8 != 0)
// throw new ArgumentException("The number of steps needs to be "
// + "divisible by 8 because each agent plays 8 cards per game!");

int numGames = buffer.Steps / 8;
int numSessions = buffer.NumEnvs / 4;
Expand Down Expand Up @@ -95,7 +70,7 @@ private void fillBuffer(
unsafe
{
expBuf.Actions.Data[rowid] = a0Batch.Data[envId];
expBuf.Rewards.Data[rowid] = rewards.Data[envId];
expBuf.Rewards.Data[rowid] = r1Batch.Data[envId];
expBuf.Terminals.Data[rowid] = t_id == 7 ? 1 : 0;
expBuf.OldProbs.Data[rowid] = piSparseBatch.Data[envId];
expBuf.OldBaselines.Data[rowid] = vBatch.Data[envId];
Expand Down Expand Up @@ -136,7 +111,7 @@ private void playGame(CardPickerEnv[] envs, GameLog[] states, TurnBatches[] batc
for (int envId = 0; envId < states.Length; envId++)
{
var s0 = stateSerializer.SerializeState(states[envId]);
unsafe { s0.ExportFeatures(s0Batch.Data + envId * 90); }
s0.ExportFeatures(s0Batch.SliceRowsRaw(envId, 1));
}

strategy.Predict(s0Batch, piBatch, vBatch);
Expand Down Expand Up @@ -164,6 +139,69 @@ private void playGame(CardPickerEnv[] envs, GameLog[] states, TurnBatches[] batc
}
}
}

private struct TurnBatches
{
public TurnBatches(int numSessions)
{
s0Batches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 90)).ToArray();
a0Batches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
piBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
piSparseBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 32)).ToArray();
vBatches = Enumerable.Range(0, 4)
.Select(i => Matrix2D.Zeros(numSessions, 1)).ToArray();
}

public Matrix2D[] s0Batches { get; set; }
public Matrix2D[] a0Batches { get; set; }
public Matrix2D[] piBatches { get; set; }
public Matrix2D[] piSparseBatches { get; set; }
public Matrix2D[] vBatches { get; set; }
}

private class PossibleCardPicker
{
private UniformDistribution uniform = new UniformDistribution();

public Card PickCard(
ReadOnlySpan<Card> possibleCards,
ReadOnlySpan<double> predPi,
Card sampledCard)
=> canPlaySampledCard(possibleCards, sampledCard) ? sampledCard
: possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

public Card PickCard(ReadOnlySpan<Card> possibleCards, ReadOnlySpan<double> predPi)
=> possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))];

private bool canPlaySampledCard(
ReadOnlySpan<Card> possibleCards, Card sampledCard)
{
foreach (var card in possibleCards)
if (card == sampledCard)
return true;
return false;
}

private double[] probDistCache = new double[8];
private ReadOnlySpan<double> normProbDist(
ReadOnlySpan<double> probDistAll, ReadOnlySpan<Card> possibleCards)
{
double probSum = 0;
for (int i = 0; i < possibleCards.Length; i++)
probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK];
for (int i = 0; i < possibleCards.Length; i++)
probSum += probDistCache[i];
double scale = 1 / probSum;
for (int i = 0; i < possibleCards.Length; i++)
probDistCache[i] *= scale;

return probDistCache.AsSpan().Slice(0, possibleCards.Length);
}
}
}

public class CardPickerEnv
Expand Down
Loading

0 comments on commit c24f281

Please sign in to comment.