Skip to content

Commit

Permalink
Add an env-ID layer to order the training experiences
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Nov 24, 2023
1 parent 6ee9ea2 commit 9ccecdb
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions Schafkopf.Training/Algos/PPOAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,22 @@ private ReadOnlySpan<double> normProbDist(
}
}

public class VectorizedCardPicker : ICardPicker
/// <summary>
/// Adapter that implements <see cref="ICardPicker"/> on behalf of a single
/// environment by forwarding every call to a shared
/// <see cref="VectorizedCardPicker"/> together with a fixed environment id.
/// </summary>
public class EnvCardPicker : ICardPicker
{
    /// <summary>
    /// Creates a picker bound to one environment id.
    /// </summary>
    /// <param name="envId">
    /// Environment id passed unchanged to the vectorized picker on each call.
    /// NOTE(review): presumably must be a valid index into the vectorized
    /// picker's per-env buffers; it is not range-checked here.
    /// </param>
    /// <param name="vecAgent">Shared picker that all calls are delegated to.</param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="vecAgent"/> is null.
    /// </exception>
    public EnvCardPicker(int envId, VectorizedCardPicker vecAgent)
    {
        this.envId = envId;
        // Fail fast here instead of with a NullReferenceException on the first ChooseCard call.
        this.vecAgent = vecAgent ?? throw new ArgumentNullException(nameof(vecAgent));
    }

    // Both fields are fixed for the lifetime of the instance.
    private readonly int envId;
    private readonly VectorizedCardPicker vecAgent;

    /// <summary>
    /// Delegates the card choice to the shared vectorized picker, tagging the
    /// request with this instance's environment id.
    /// </summary>
    /// <param name="log">Game log forwarded to the vectorized picker.</param>
    /// <param name="possibleCards">Candidate cards forwarded to the vectorized picker.</param>
    /// <returns>The card returned by the vectorized picker.</returns>
    public Card ChooseCard(GameLog log, ReadOnlySpan<Card> possibleCards)
        => vecAgent.ChooseCard(envId, log, possibleCards);
}

public class VectorizedCardPicker
{
public VectorizedCardPicker(
PPOTrainingSettings config,
Expand All @@ -288,10 +303,7 @@ public VectorizedCardPicker(
for (int i = 0; i < inputs.Length; i++)
inputs[i].Item2 = new Card[8];
outputs = new Card[config.BatchSize];

count = 0;
barr = new Barrier(config.BatchSize, (b) => predictBatched());
mut = new Mutex();

s0 = Matrix2D.Zeros(config.BatchSize, 90);
a0 = Matrix2D.Zeros(config.BatchSize, 1);
Expand All @@ -313,26 +325,16 @@ public VectorizedCardPicker(
private Matrix2D s0, a0, piOnehot, piSparse, V;
private PPOTrainBatch expCache;

private int count;
private Barrier barr;
private Mutex mut = new Mutex();

public Card ChooseCard(
GameLog log, ReadOnlySpan<Card> possibleCards)
int envId, GameLog log, ReadOnlySpan<Card> possibleCards)
{
mut.WaitOne();

// TODO: don't draw random i, use env id
int i = count;
inputs[i].Item1 = log;
possibleCards.CopyTo(inputs[i].Item2);
inputs[i].Item3 = possibleCards.Length;
count++;

mut.ReleaseMutex();
inputs[envId].Item1 = log;
possibleCards.CopyTo(inputs[envId].Item2);
inputs[envId].Item3 = possibleCards.Length;
barr.SignalAndWait();

return outputs[i];
return outputs[envId];
}

private void predictBatched()
Expand Down Expand Up @@ -378,7 +380,6 @@ private void predictBatched()
// Matrix2D.CopyData(cacheOnlyLastStep.Rewards, expCache.Rewards);
// Matrix2D.CopyData(cacheOnlyLastStep.Terminals, expCache.Terminals);


memory.AppendStep(expCache, t++);
}
}
Expand All @@ -393,14 +394,17 @@ public void Train()
var ppoModel = new PPOModel(config);
var cardPicker = new VectorizedCardPicker(config, ppoModel, memory);
var heuristicGameCaller = new HeuristicAgent();
var agent = new ComposedAgent(heuristicGameCaller, cardPicker);
var envProxies = Enumerable.Range(0, config.NumEnvs)
.Select(i => new EnvCardPicker(i, cardPicker)).ToArray();
var agents = Enumerable.Range(0, config.NumEnvs)
.Select(i => new ComposedAgent(heuristicGameCaller, envProxies[i])).ToArray();

var tables = Enumerable.Range(0, config.NumEnvs)
.Select(i => new Table(
new Player(0, agent),
new Player(1, agent),
new Player(2, agent),
new Player(3, agent)
new Player(0, agents[i]),
new Player(1, agents[i]),
new Player(2, agents[i]),
new Player(3, agents[i])
)).ToArray();
var sessions = Enumerable.Range(0, config.NumEnvs)
.Select(i => new GameSession(tables[i], new CardsDeck())).ToArray();
Expand Down

0 comments on commit 9ccecdb

Please sign in to comment.