Skip to content


add cartpole ppo training benchmark (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Dec 10, 2024
1 parent ce5c8c4 commit 5d44385
Showing 1 changed file with 252 additions and 0 deletions.
252 changes: 252 additions & 0 deletions Schafkopf.Training.Tests/PPOTrainingSessionTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
namespace Schafkopf.Training.Tests;

/// <summary>
/// Tests that exercise the CartPole environment on its own and together
/// with the PPO rollout collection components.
/// </summary>
public class PPOTraining_CartPole_Tests
{
    /// <summary>
    /// Steps a single CartPole environment with uniformly random actions
    /// and expects that none of the calls throws.
    /// </summary>
    [Fact]
    public void Inference_WithRandomAgent_DoesNotThrowException()
    {
        var rng = new Random(42);
        var env = new CartPoleEnv();

        // Step() throws InvalidOperationException until Reset() was called once
        env.Reset();

        for (int i = 0; i < 10_000; i++)
        {
            var action = new CartPoleAction() {
                Direction = (CartPoleDirection)rng.Next(0, 2)
            };
            var (_, _, isTerminal) = env.Step(action);

            // begin a new episode instead of integrating a terminated one further
            if (isTerminal)
                env.Reset();
        }

        Assert.True(true); // just ensure no exception occurs
    }

    /// <summary>
    /// Collects PPO rollouts from CartPole environments for
    /// <c>config.NumTrainings</c> iterations. Currently skipped; it only runs
    /// the experience collection and has no assertion yet (see TODO below).
    /// </summary>
    [Fact(Skip="test requires further debugging")]
    public void Training_WithPPOAgent_CanLearnCartPole()
    {
        var config = new PPOTrainingSettings() {
            NumStateDims = 4, NumActionDims = 2
        };
        var model = new PPOModel(config);
        var envFactory = () => new CartPoleEnv();

        // writes the 4 state components into the first row of the given buffer
        var encodeState = (CartPoleState s0, Matrix2D buf) => {
            var cache = buf.SliceRowsRaw(0, 1);
            cache[0] = s0.x;
            cache[1] = s0.x_dot;
            cache[2] = s0.theta;
            cache[3] = s0.theta_dot;
        };

        // writes the numeric value of the direction enum (0 or 1) into the buffer
        var encodeAction = (CartPoleAction a0, Matrix2D buf) => {
            buf.SliceRowsRaw(0, 1)[0] = (int)a0.Direction;
        };

        // TODO: model one-hot sampling as a new class
        var uniform = new UniformDistribution();
        var actionsCache = Enumerable.Range(0, config.NumEnvs)
            .Select(x => new CartPoleAction()).ToArray();
        var probsCache = new double[config.NumEnvs];

        // samples one action per row of the policy output; the returned lists are
        // the caches above, so they get overwritten by the next invocation
        var sampleActions = (Matrix2D piOH) => {
            for (int i = 0; i < piOH.NumRows; i++)
            {
                var probDist = piOH.SliceRowsRaw(i, 1);
                var idx = uniform.Sample(probDist);
                actionsCache[i].Direction = (CartPoleDirection)idx;
                probsCache[i] = probDist[idx];
            }
            return ((IList<CartPoleAction>)actionsCache, (IList<double>)probsCache);
        };

        var rollout = new PPORolloutBuffer<CartPoleState, CartPoleAction>(
            config, encodeState, encodeAction
        );
        var exps = new SingleAgentExpCollector<CartPoleState, CartPoleAction>(
            config, encodeState, sampleActions, envFactory
        );

        for (int ep = 0; ep < config.NumTrainings; ep++)
            exps.Collect(rollout, model);

        // TODO: make assertion, e.g. avg_steps < x
    }
}

/// <summary>
/// Collects PPO experiences by stepping <c>config.NumEnvs</c> environments
/// in lockstep and appending one experience per environment and time step
/// to a rollout buffer.
/// </summary>
/// <typeparam name="TState">State type of the environments.</typeparam>
/// <typeparam name="TAction">Action type of the environments.</typeparam>
public class SingleAgentExpCollector<TState, TAction>
    where TState : IEquatable<TState>, new()
    where TAction : IEquatable<TAction>, new()
{
    private readonly PPOTrainingSettings config;
    private readonly VectorizedEnv<TState, TAction> vecEnv;
    private readonly Action<TState, Matrix2D> encodeState;
    private readonly Func<Matrix2D, (IList<TAction>, IList<double>)> sampleActions;

    // current state of every environment, carried over between Collect() calls
    private TState[] s0;
    // one reusable experience slot per environment
    private PPOExp<TState, TAction>[] exps;
    // model input / output buffers with one row per environment
    private Matrix2D s0_enc;
    private Matrix2D v;
    private Matrix2D pi;

    /// <summary>
    /// Creates the environments via <paramref name="envFactory"/>, resets them
    /// and allocates the per-environment buffers.
    /// </summary>
    /// <param name="config">Settings providing NumEnvs, NumStateDims and NumActionDims.</param>
    /// <param name="encodeState">Writes a state into the given matrix slice.</param>
    /// <param name="sampleActions">
    /// Maps the policy output to one action and its probability per environment.
    /// </param>
    /// <param name="envFactory">Invoked once per environment to create it.</param>
    public SingleAgentExpCollector(
        PPOTrainingSettings config,
        Action<TState, Matrix2D> encodeState,
        Func<Matrix2D, (IList<TAction>, IList<double>)> sampleActions,
        Func<MDPEnv<TState, TAction>> envFactory)
    {
        this.config = config;
        this.encodeState = encodeState;
        this.sampleActions = sampleActions;

        var envList = new List<MDPEnv<TState, TAction>>(config.NumEnvs);
        for (int k = 0; k < config.NumEnvs; k++)
            envList.Add(envFactory());
        vecEnv = new VectorizedEnv<TState, TAction>(envList);

        exps = new PPOExp<TState, TAction>[config.NumEnvs];
        for (int k = 0; k < exps.Length; k++)
            exps[k] = new PPOExp<TState, TAction>();

        // copy the initial states, the vectorized env hands out its internal list
        s0 = vecEnv.Reset().ToArray();

        s0_enc = Matrix2D.Zeros(config.NumEnvs, config.NumStateDims);
        v = Matrix2D.Zeros(config.NumEnvs, 1);
        pi = Matrix2D.Zeros(config.NumEnvs, config.NumActionDims);
    }

    /// <summary>
    /// Runs <c>buffer.Steps</c> time steps on all environments and appends the
    /// resulting experiences to <paramref name="buffer"/>.
    /// </summary>
    /// <param name="buffer">Rollout buffer receiving one batch per time step.</param>
    /// <param name="model">Model predicting baseline values and action probabilities.</param>
    public void Collect(PPORolloutBuffer<TState, TAction> buffer, PPOModel model)
    {
        for (int step = 0; step < buffer.Steps; step++)
        {
            // encode the current state of every environment into its own row
            for (int env = 0; env < config.NumEnvs; env++)
                encodeState(s0[env], s0_enc.SliceRows(env, 1));

            model.Predict(s0_enc, v, pi);
            var (actions, actionProbs) = sampleActions(pi);

            var (nextStates, rewards, terminals) = vecEnv.Step(actions);

            for (int env = 0; env < config.NumEnvs; env++)
            {
                ref var exp = ref exps[env];
                exp.StateBefore = s0[env];
                exp.Action = actions[env];
                exp.Reward = rewards[env];
                exp.IsTerminal = terminals[env];
                exp.OldProb = actionProbs[env];
                exp.OldBaseline = v.SliceRowsRaw(env, 1)[0];

                // the successor state becomes the current state of the next step
                s0[env] = nextStates[env];
            }

            buffer.AppendStep(exps, step);
        }
    }
}

/// <summary>Observable state of the cart-pole system.</summary>
/// <param name="x">Position of the cart.</param>
/// <param name="x_dot">Velocity of the cart.</param>
/// <param name="theta">Angle of the pole in radians.</param>
/// <param name="theta_dot">Angular velocity of the pole.</param>
public record struct CartPoleState(double x, double x_dot, double theta, double theta_dot);

/// <summary>
/// Direction of the force applied to the cart. CartPoleEnv.Step applies a
/// positive force for <see cref="Right"/> and a negative force otherwise.
/// </summary>
public enum CartPoleDirection
{
    /// <summary>Numeric value 0.</summary>
    Left = 0,

    /// <summary>Numeric value 1.</summary>
    Right = 1
}

/// <summary>Action of the cart-pole environment.</summary>
/// <param name="Direction">Direction in which the cart is pushed.</param>
public record struct CartPoleAction(CartPoleDirection Direction);

/// <summary>
/// Steps a fixed list of environments together, one action per environment.
/// Environments that terminate are reset automatically within the same step.
/// </summary>
/// <remarks>
/// The lists returned by <see cref="Reset"/> and <see cref="Step"/> are internal
/// buffers that get overwritten by the next call; copy them to keep the values.
/// </remarks>
/// <typeparam name="TState">State type of the environments.</typeparam>
/// <typeparam name="TAction">Action type of the environments.</typeparam>
public class VectorizedEnv<TState, TAction>
{
    private readonly IList<MDPEnv<TState, TAction>> envs;
    private readonly IList<TState> states;
    private readonly IList<double> rewards;
    private readonly IList<bool> terminals;

    /// <summary>Wraps the given environments and allocates the result buffers.</summary>
    /// <param name="envs">Environments to be stepped together.</param>
    public VectorizedEnv(IList<MDPEnv<TState, TAction>> envs)
    {
        this.envs = envs;
        states = new TState[envs.Count];
        rewards = new double[envs.Count];
        terminals = new bool[envs.Count];
    }

    /// <summary>Resets every environment.</summary>
    /// <returns>The initial state of each environment, in environment order.</returns>
    public IList<TState> Reset()
    {
        // every environment needs its own reset (previously only envs[0] was reset)
        for (int i = 0; i < envs.Count; i++)
            states[i] = envs[i].Reset();
        return states;
    }

    /// <summary>Applies <c>actions[i]</c> to the i-th environment.</summary>
    /// <param name="actions">One action per environment, in environment order.</param>
    /// <returns>
    /// Per environment: the next state, the reward and the terminal flag. For a
    /// terminated environment the state is the initial state of its new episode.
    /// </returns>
    public (IList<TState>, IList<double>, IList<bool>) Step(IList<TAction> actions)
    {
        for (int i = 0; i < envs.Count; i++)
        {
            (var s1, var r1, var t1) = envs[i].Step(actions[i]);

            // auto-reset, the terminal flag still tells the caller the episode ended
            s1 = t1 ? envs[i].Reset() : s1;

            states[i] = s1;
            rewards[i] = r1;
            terminals[i] = t1;
        }

        return (states, rewards, terminals);
    }
}

/// <summary>
/// Cart-pole balancing environment: a pole is attached to a cart that is
/// pushed left or right with a fixed force at every time step.
/// </summary>
/// <remarks>
/// Constants and dynamics appear to mirror the classic CartPole benchmark
/// (e.g. the Gymnasium implementation) - verify when changing them.
/// </remarks>
public class CartPoleEnv : MDPEnv<CartPoleState, CartPoleAction>
{
    private const double gravity = 9.8;
    private const double masscart = 1.0;
    private const double masspole = 0.1;
    private const double total_mass = masspole + masscart;
    private const double length = 0.5;
    private const double polemass_length = masspole * length;
    private const double force_mag = 10.0;
    private const double tau = 0.02; // integration time step
    private const double theta_threshold_radians = 12 * 2 * Math.PI / 360; // 12 degrees
    private const double x_threshold = 2.4;
    private const double init_range = 0.05; // bound of the initial state components

    private readonly Random rng;
    private CartPoleState? state = null;

    /// <summary>Creates an environment whose random generator is seeded with 0.</summary>
    public CartPoleEnv() : this(0)
    {
    }

    /// <summary>Creates an environment with the given random seed.</summary>
    /// <param name="seed">Seed for sampling the initial states in <see cref="Reset"/>.</param>
    public CartPoleEnv(int seed)
    {
        rng = new Random(seed);
    }

    /// <summary>Advances the simulation by one time step.</summary>
    /// <param name="action">Direction in which the cart is pushed.</param>
    /// <returns>
    /// The state after the step, a reward of 1.0 and whether that state is
    /// terminal (cart position or pole angle beyond its threshold).
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown when <see cref="Reset"/> has not been called before.
    /// </exception>
    public (CartPoleState, double, bool) Step(CartPoleAction action)
    {
        if (state is not CartPoleState current)
            throw new InvalidOperationException("Environment needs to be initialized with Reset()");

        (var x, var x_dot, var theta, var theta_dot) = current;

        var force = action.Direction == CartPoleDirection.Right ? force_mag : -force_mag;
        var costheta = Math.Cos(theta);
        var sintheta = Math.Sin(theta);

        var temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / total_mass;
        var thetaacc = (gravity * sintheta - costheta * temp) / (
            length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass)
        );
        var xacc = temp - polemass_length * thetaacc * costheta / total_mass;

        // explicit Euler integration
        var next = new CartPoleState(
            x + tau * x_dot,
            x_dot + tau * xacc,
            theta + tau * theta_dot,
            theta_dot + tau * thetaacc
        );
        state = next;

        // evaluate termination on the state that is returned, not on the previous
        // one; otherwise the terminal flag is raised one step too late
        var terminated =
            next.x < -x_threshold
            || next.x > x_threshold
            || next.theta < -theta_threshold_radians
            || next.theta > theta_threshold_radians;

        var reward = 1.0;
        return (next, reward, terminated);
    }

    /// <summary>Starts a new episode.</summary>
    /// <returns>
    /// The initial state; each component is drawn uniformly from
    /// [-0.05, 0.05), i.e. well inside the termination thresholds.
    /// </returns>
    public CartPoleState Reset()
    {
        // sampling position / angle from twice the termination thresholds would
        // start most episodes in an already terminal state
        var initial = new CartPoleState(
            sample(-init_range, init_range),
            sample(-init_range, init_range),
            sample(-init_range, init_range),
            sample(-init_range, init_range)
        );
        state = initial;
        return initial;
    }

    /// <summary>Draws a uniform sample from [low, high).</summary>
    private double sample(double low, double high)
        => rng.NextDouble() * (high - low) + low;
}

0 comments on commit 5d44385

Please sign in to comment.