From 5d443851dc76c8a2949ac88361fbbfb2659f5dda Mon Sep 17 00:00:00 2001
From: Marco Tröster
Date: Tue, 10 Dec 2024 17:06:04 +0100
Subject: [PATCH] add cartpole ppo training benchmark (wip)

---
 .../PPOTrainingSessionTests.cs | 252 ++++++++++++++++++
 1 file changed, 252 insertions(+)
 create mode 100644 Schafkopf.Training.Tests/PPOTrainingSessionTests.cs

diff --git a/Schafkopf.Training.Tests/PPOTrainingSessionTests.cs b/Schafkopf.Training.Tests/PPOTrainingSessionTests.cs
new file mode 100644
index 0000000..aeba1ff
--- /dev/null
+++ b/Schafkopf.Training.Tests/PPOTrainingSessionTests.cs
@@ -0,0 +1,252 @@
+namespace Schafkopf.Training.Tests;
+
+public class PPOTraining_CartPole_Tests
+{
+    [Fact]
+    public void Inference_WithRandomAgent_DoesNotThrowException()
+    {
+        var rng = new Random(42);
+
+        var env = new CartPoleEnv();
+        env.Reset();
+        for (int i = 0; i < 10_000; i++)
+        {
+            (_, _, var done) = env.Step(
+                new CartPoleAction() { Direction = (CartPoleDirection)rng.Next(0, 2) });
+            // restart the episode once it terminates,
+            // so the benchmark keeps stepping from valid states
+            if (done)
+                env.Reset();
+        }
+
+        Assert.True(true); // just ensure no exception occurs
+    }
+
+    [Fact(Skip="test requires further debugging")]
+    public void Training_WithPPOAgent_CanLearnCartPole()
+    {
+        var config = new PPOTrainingSettings() {
+            NumStateDims = 4, NumActionDims = 2
+        };
+        var model = new PPOModel(config);
+        var envFactory = () => new CartPoleEnv();
+
+        var encodeState = (CartPoleState s0, Matrix2D buf) => {
+            var cache = buf.SliceRowsRaw(0, 1);
+            cache[0] = s0.x;
+            cache[1] = s0.x_dot;
+            cache[2] = s0.theta;
+            cache[3] = s0.theta_dot;
+        };
+        var encodeAction = (CartPoleAction a0, Matrix2D buf) => {
+            buf.SliceRowsRaw(0, 1)[0] = (int)a0.Direction;
+        };
+
+        // TODO: model one-hot sampling as a new class
+        var uniform = new UniformDistribution();
+        var actionsCache = Enumerable.Range(0, config.NumEnvs)
+            .Select(x => new CartPoleAction()).ToArray();
+        var probsCache = new double[config.NumEnvs];
+        var sampleActions = (Matrix2D piOH) => {
+            for (int i = 0; i < piOH.NumRows; i++)
+            {
+                var probDist = piOH.SliceRowsRaw(i, 1);
+                var idx = uniform.Sample(probDist);
+                actionsCache[i].Direction = (CartPoleDirection)idx;
+                probsCache[i] = probDist[idx];
+            }
+            return ((IList<CartPoleAction>)actionsCache, (IList<double>)probsCache);
+        };
+
+        var rollout = new PPORolloutBuffer<CartPoleState, CartPoleAction>(
+            config, encodeState, encodeAction
+        );
+        var exps = new SingleAgentExpCollector<CartPoleState, CartPoleAction>(
+            config, encodeState, sampleActions, envFactory
+        );
+
+        for (int ep = 0; ep < config.NumTrainings; ep++)
+        {
+            exps.Collect(rollout, model);
+            model.Train(rollout);
+        }
+
+        // TODO: make assertion, e.g. avg_steps < x
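+        // hypothetical sketch of such an assertion (helper names like
+        // runGreedyEpisode are assumed, not part of this patch): roll out
+        // a few evaluation episodes and require the policy to balance the
+        // pole for some minimum average number of steps, e.g.
+        //   var lengths = Enumerable.Range(0, 10)
+        //       .Select(_ => runGreedyEpisode(model, new CartPoleEnv()));
+        //   Assert.True(lengths.Average() >= 100);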
+    }
+}
+
+public class SingleAgentExpCollector<TState, TAction>
+    where TState : IEquatable<TState>, new()
+    where TAction : IEquatable<TAction>, new()
+{
+    public SingleAgentExpCollector(
+        PPOTrainingSettings config,
+        Action<TState, Matrix2D> encodeState,
+        Func<Matrix2D, (IList<TAction>, IList<double>)> sampleActions,
+        Func<MDPEnv<TState, TAction>> envFactory)
+    {
+        this.config = config;
+        this.encodeState = encodeState;
+        this.sampleActions = sampleActions;
+
+        var envs = Enumerable.Range(0, config.NumEnvs)
+            .Select(i => envFactory()).ToList();
+        vecEnv = new VectorizedEnv<TState, TAction>(envs);
+        exps = Enumerable.Range(0, config.NumEnvs)
+            .Select(i => new PPOExp<TState, TAction>()).ToArray();
+        s0 = vecEnv.Reset().ToArray();
+
+        s0_enc = Matrix2D.Zeros(config.NumEnvs, config.NumStateDims);
+        v = Matrix2D.Zeros(config.NumEnvs, 1);
+        pi = Matrix2D.Zeros(config.NumEnvs, config.NumActionDims);
+    }
+
+    private readonly PPOTrainingSettings config;
+    private readonly VectorizedEnv<TState, TAction> vecEnv;
+    private readonly Action<TState, Matrix2D> encodeState;
+    private readonly Func<Matrix2D, (IList<TAction>, IList<double>)> sampleActions;
+
+    private TState[] s0;
+    private PPOExp<TState, TAction>[] exps;
+    private Matrix2D s0_enc;
+    private Matrix2D v;
+    private Matrix2D pi;
+
+    public void Collect(PPORolloutBuffer<TState, TAction> buffer, PPOModel model)
+    {
+        for (int t = 0; t < buffer.Steps; t++)
+        {
+            for (int i = 0; i < config.NumEnvs; i++)
+                encodeState(s0[i], s0_enc.SliceRows(i, 1));
+
+            model.Predict(s0_enc, v, pi);
+            (var a0, var p_a0) = sampleActions(pi);
+
+            (var s1, var r1, var t1) = vecEnv.Step(a0);
+
+            for (int i = 0; i < config.NumEnvs; i++)
+            {
+                exps[i].StateBefore = s0[i];
+                exps[i].Action = a0[i];
+                exps[i].Reward = r1[i];
+                exps[i].IsTerminal = t1[i];
+                exps[i].OldProb = p_a0[i];
+                exps[i].OldBaseline = v.SliceRowsRaw(i, 1)[0];
+            }
+
+            for (int i = 0; i < config.NumEnvs; i++)
+                s0[i] = s1[i];
+
+            buffer.AppendStep(exps, t);
+        }
+    }
+}
+
+public record struct CartPoleState(
+    double x,
+    double x_dot,
+    double theta,
+    double theta_dot
+);
+
+public enum CartPoleDirection { Right = 1, Left = 0 }
+
+public record struct CartPoleAction(
+    CartPoleDirection Direction
+);
+
+public class VectorizedEnv<TState, TAction>
+{
+    private readonly IList<MDPEnv<TState, TAction>> envs;
+    private IList<TState> states;
+    private IList<double> rewards;
+    private IList<bool> terminals;
+
+    public VectorizedEnv(IList<MDPEnv<TState, TAction>> envs)
+    {
+        this.envs = envs;
+        states = new TState[envs.Count];
+        rewards = new double[envs.Count];
+        terminals = new bool[envs.Count];
+    }
+
+    public IList<TState> Reset()
+    {
+        // reset every env, not just the first one
+        for (int i = 0; i < envs.Count; i++)
+            states[i] = envs[i].Reset();
+        return states;
+    }
+
+    public (IList<TState>, IList<double>, IList<bool>) Step(IList<TAction> actions)
+    {
+        for (int i = 0; i < envs.Count; i++)
+        {
+            (var s1, var r1, var t1) = envs[i].Step(actions[i]);
+            s1 = t1 ? envs[i].Reset() : s1;
+            states[i] = s1;
+            rewards[i] = r1;
+            terminals[i] = t1;
+        }
+
+        return (states, rewards, terminals);
+    }
+}
+
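+// the cart-pole dynamics below are ported from Gym's classic-control
+// CartPole task (after Barto, Sutton & Anderson, 1983): +1 reward per
+// step, the episode ends once the cart leaves the track or the pole
+// tips more than ~12 degrees off vertical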
+public class CartPoleEnv : MDPEnv<CartPoleState, CartPoleAction>
+{
+    private const double gravity = 9.8;
+    private const double masscart = 1.0;
+    private const double masspole = 0.1;
+    private const double total_mass = masspole + masscart;
+    private const double length = 0.5;
+    private const double polemass_length = masspole * length;
+    private const double force_mag = 10.0;
+    private const double tau = 0.02;
+    private const double theta_threshold_radians = 12 * 2 * Math.PI / 360;
+    private const double x_threshold = 2.4;
+
+    private CartPoleState? state = null;
+    private Random rng = new Random(0);
+
+    public (CartPoleState, double, bool) Step(CartPoleAction action)
+    {
+        if (state == null)
+            throw new InvalidOperationException("Environment needs to be initialized with Reset()");
+
+        (var x, var x_dot, var theta, var theta_dot) = state.Value;
+
+        var force = action.Direction == CartPoleDirection.Right ? force_mag : -force_mag;
+        var costheta = Math.Cos(theta);
+        var sintheta = Math.Sin(theta);
+
+        var temp = (force + polemass_length * Math.Pow(theta_dot, 2) * sintheta) / total_mass;
+        var thetaacc = (gravity * sintheta - costheta * temp) / (
+            length * (4.0 / 3.0 - masspole * Math.Pow(costheta, 2) / total_mass)
+        );
+        var xacc = temp - polemass_length * thetaacc * costheta / total_mass;
+
+        // explicit Euler integration
+        state = new CartPoleState(
+            x + tau * x_dot,
+            x_dot + tau * xacc,
+            theta + tau * theta_dot,
+            theta_dot + tau * thetaacc
+        );
+
+        // evaluate termination on the updated state, not the stale one
+        var terminated =
+            state.Value.x < -x_threshold
+            || state.Value.x > x_threshold
+            || state.Value.theta < -theta_threshold_radians
+            || state.Value.theta > theta_threshold_radians;
+
+        var reward = 1.0;
+        return (state.Value, reward, terminated);
+    }
+
+    public CartPoleState Reset()
+    {
+        // start near the upright center position; sampling positions from
+        // the full ±2x threshold range would spawn states that are
+        // already terminal
+        state = new CartPoleState(
+            sample(-0.05, 0.05),
+            sample(-0.05, 0.05),
+            sample(-0.05, 0.05),
+            sample(-0.05, 0.05)
+        );
+        return state.Value;
+    }
+
+    private double sample(double low, double high)
+        => rng.NextDouble() * (high - low) + low;
+}
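+
+// hedged sketch for the one-hot sampling TODO above (not part of the
+// original patch; class and method names are assumed): draws an index
+// from a discrete probability row via inverse transform sampling
+public class CategoricalDist
+{
+    private readonly Random rng;
+    public CategoricalDist(int seed) => rng = new Random(seed);
+
+    // returns index i with probability probs[i]; assumes the row sums to ~1
+    public int Sample(ReadOnlySpan<double> probs)
+    {
+        double u = rng.NextDouble(), cum = 0.0;
+        for (int i = 0; i < probs.Length; i++)
+            if ((cum += probs[i]) >= u)
+                return i;
+        return probs.Length - 1; // guard against rounding residue
+    }
+}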