From 1954a78d58f364395cb86aa9726a21252bbcd240 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marco=20Tr=C3=B6ster?=
Date: Fri, 13 Dec 2024 23:37:48 +0100
Subject: [PATCH] add cart pole training demo (wip)

---
 RLNetDemo/CartPole.cs                         | 205 ++++++++++++++
 RLNetDemo/Program.cs                          |  41 +++
 RLNetDemo/RLNetDemo.csproj                    |  15 ++
 Schafkopf.AI.sln                              |   6 +
 .../PPOTrainingSessionTests.cs                | 247 ------------------
 Schafkopf.Training/Algos/PPOAgent.cs          | 162 +++++++++++-
 Schafkopf.Training/CardPicker/PPOAgent.cs     |  84 ++++--
 Schafkopf.Training/Common/MDP.cs              |  11 +
 8 files changed, 488 insertions(+), 283 deletions(-)
 create mode 100644 RLNetDemo/CartPole.cs
 create mode 100644 RLNetDemo/Program.cs
 create mode 100644 RLNetDemo/RLNetDemo.csproj
 delete mode 100644 Schafkopf.Training.Tests/PPOTrainingSessionTests.cs

diff --git a/RLNetDemo/CartPole.cs b/RLNetDemo/CartPole.cs
new file mode 100644
index 0000000..4314ed1
--- /dev/null
+++ b/RLNetDemo/CartPole.cs
@@ -0,0 +1,205 @@
+namespace RLNetDemo;
+
+using BackpropNet;
+using Schafkopf.Training;
+
+public record struct CartPoleState(
+    double x,
+    double x_dot,
+    double theta,
+    double theta_dot
+);
+
+public enum CartPoleDirection
+{
+    Right = 1,
+    Left = 0
+}
+
+public record struct CartPoleAction(
+    CartPoleDirection Direction
+);
+
+public class CartPoleEnv : MDPEnv<CartPoleState, CartPoleAction>
+{
+    // OpenAI reference implementation:
+    // https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
+
+    private const double gravity = 9.8;
+    private const double masscart = 1.0;
+    private const double masspole = 0.1;
+    private const double total_mass = masspole + masscart;
+    private const double length = 0.5;
+    private const double polemass_length = masspole * length;
+    private const double force_mag = 10.0;
+    private const double tau = 0.02;
+    private const double theta_threshold_radians = 12.0 * 2.0 * Math.PI / 360.0;
+    private const double x_threshold = 2.4;
+
+    private CartPoleState high = new CartPoleState(
+        x: x_threshold * 2.0,
+        x_dot: float.MaxValue,
+        theta: theta_threshold_radians * 2.0,
+        theta_dot: float.MaxValue
+    );
+
+    private CartPoleState low = new CartPoleState(
+        x: x_threshold * -2.0,
+        x_dot: float.MinValue,
+        theta: theta_threshold_radians * -2.0,
+        theta_dot: float.MinValue
+    );
+
+    private CartPoleState? state = null;
+    private Random rng = new Random(0);
+
+    public void Seed(int seed)
+        => rng = new Random(seed);
+
+    public (CartPoleState, double, bool) Step(CartPoleAction action)
+    {
+        if (state == null)
+            throw new InvalidOperationException("Environment needs to be initialized with Reset()");
+
+        (var x, var x_dot, var theta, var theta_dot) = state.Value;
+
+        var force = action.Direction == CartPoleDirection.Right ? force_mag : -force_mag;
+        var costheta = Math.Cos(theta);
+        var sintheta = Math.Sin(theta);
+
+        var temp = (force + polemass_length * Math.Pow(theta_dot, 2) * sintheta) / total_mass;
+        var thetaacc = (gravity * sintheta - costheta * temp) / (
+            length * (4.0 / 3.0 - masspole * Math.Pow(costheta, 2) / total_mass)
+        );
+        var xacc = temp - polemass_length * thetaacc * costheta / total_mass;
+
+        // explicit Euler integration
+        state = new CartPoleState(
+            x + tau * x_dot,
+            x_dot + tau * xacc,
+            theta + tau * theta_dot,
+            theta_dot + tau * thetaacc
+        );
+
+        var terminated =
+            state.Value.x < -x_threshold
+            || state.Value.x > x_threshold
+            || state.Value.theta < -theta_threshold_radians
+            || state.Value.theta > theta_threshold_radians;
+
+        var reward = 1.0;
+        return (state.Value, reward, terminated);
+    }
+
+    public CartPoleState Reset()
+    {
+        state = new CartPoleState( // small random init near upright, like the gym reference
+            sample(-0.05, 0.05),
+            sample(-0.05, 0.05),
+            sample(-0.05, 0.05),
+            sample(-0.05, 0.05)
+        );
+        return state.Value;
+    }
+
+    private double sample(double low, double high)
+        => rng.NextDouble() * (high - low) + low;
+}
+
+public class CartPolePPOAdapter : IPPOAdapter<CartPoleState, CartPoleAction>
+{
+    public CartPolePPOAdapter(PPOTrainingSettings config)
+    {
+        actionsCache = Enumerable.Range(0, config.NumEnvs)
+            .Select(x => new CartPoleAction()).ToArray();
+    }
+
+    private CartPoleAction[] actionsCache;
+
+    public void EncodeState(CartPoleState s0, Matrix2D buf)
+    {
+        var cache = buf.SliceRowsRaw(0, 1);
+        cache[0] = s0.x;
+        cache[1] = s0.x_dot;
+        cache[2] = s0.theta;
+        cache[3] = s0.theta_dot;
+    }
+
+    public void EncodeAction(CartPoleAction a0, Matrix2D buf)
+    {
+        buf.SliceRowsRaw(0, 1)[0] = (int)a0.Direction;
+    }
+
+    public IList<CartPoleAction> SampleActions(Matrix2D pi)
+    {
+        for (int i = 0; i < pi.NumRows; i++)
+            actionsCache[i].Direction = (CartPoleDirection)(int)pi.At(i, 0);
+        return actionsCache;
+    }
+}
+
+public record CartPoleBenchmarkStats(
+    // TODO: figure out other interesting stats to benchmark
+    double AvgEpSteps,
+    double AvgEpRewards
+);
+
+public class CartPoleBenchmark
+{
+    public CartPoleBenchmark(
+        PPOTrainingSettings config,
+        Func<MDPEnv<CartPoleState, CartPoleAction>> envFactory)
+    {
+        this.config = config;
+        this.envFactory = envFactory;
+    }
+
+    private readonly PPOTrainingSettings config;
+    private readonly Func<MDPEnv<CartPoleState, CartPoleAction>> envFactory;
+
+    public CartPoleBenchmarkStats Benchmark(PPOModel model, int totalEpisodes)
+    {
+        var adapter = new CartPolePPOAdapter(config);
+        var agent = new VecorizedPPOAgent<CartPoleState, CartPoleAction>(
+            adapter.EncodeState, adapter.SampleActions, config
+        );
+        var vecEnv = new VectorizedEnv<CartPoleState, CartPoleAction>(
+            Enumerable.Range(0, config.NumEnvs).Select(i => envFactory()).ToArray()
+        );
+
+        int ep = 0;
+        var rewardCaches = new double[vecEnv.NumEnvs];
+        var stepCaches = new int[vecEnv.NumEnvs];
+        var epSteps = new double[totalEpisodes];
+        var epRewards = new double[totalEpisodes];
+
+        var states = vecEnv.Reset();
+
+        while (ep < totalEpisodes)
+        {
+            var actions = agent.PickActions(model, states);
+            (states, var rewards, var terminals) = vecEnv.Step(actions);
+
+            for (int i = 0; i < vecEnv.NumEnvs; i++)
+            {
+                rewardCaches[i] += rewards[i];
+                stepCaches[i] += 1;
+
+                if (terminals[i])
+                {
+                    epRewards[ep] = rewardCaches[i];
+                    epSteps[ep] = stepCaches[i];
+                    rewardCaches[i] = 0;
+                    stepCaches[i] = 0;
+
+                    if (++ep == totalEpisodes)
+                        break;
+                }
+            }
+        }
+
+        return new CartPoleBenchmarkStats(
+            epSteps.Average(), epRewards.Average()
+        );
+    }
+}
diff --git a/RLNetDemo/Program.cs b/RLNetDemo/Program.cs
new file mode 100644
index 0000000..710d183
--- /dev/null
+++ b/RLNetDemo/Program.cs
@@ -0,0 +1,41 @@
+using RLNetDemo;
+using Schafkopf.Training;
+
+public class Program
+{
+    public static void Main(string[] args)
+    {
+        var config = new PPOTrainingSettings() {
+            NumStateDims = 4,
+            NumActionDims = 2,
+            TotalSteps = 2_000_000,
+            StepsPerUpdate = 2048
+        };
+        var model = new PPOModel(config);
+
+        var adapter = new CartPolePPOAdapter(config);
+        var rolloutBuffer = new PPORolloutBuffer<CartPoleState, CartPoleAction>(config, adapter);
+        var expCollector = new SingleAgentExpCollector<CartPoleState, CartPoleAction>(
+            config, adapter, () => new CartPoleEnv()
+        );
+
+        var benchmark = new CartPoleBenchmark(config, () => new CartPoleEnv());
+
+        Console.WriteLine("benchmark untrained model");
+        var res = benchmark.Benchmark(model, 1_000);
+        Console.WriteLine($"avg. rewards: {res.AvgEpRewards}, avg. steps: {res.AvgEpSteps}");
+
+        for (int ep = 0; ep < config.NumTrainings; ep++)
+        {
+            Console.WriteLine($"starting episode {ep+1}/{config.NumTrainings}");
+            Console.WriteLine("collect rollout buffer");
+            expCollector.Collect(rolloutBuffer, model);
+            Console.WriteLine("train on rollout buffer");
+            model.Train(rolloutBuffer);
+            Console.WriteLine("benchmark model");
+            res = benchmark.Benchmark(model, 1_000);
+            Console.WriteLine($"avg. rewards: {res.AvgEpRewards}, avg. steps: {res.AvgEpSteps}");
+            Console.WriteLine("===============================");
+        }
+    }
+}
diff --git a/RLNetDemo/RLNetDemo.csproj b/RLNetDemo/RLNetDemo.csproj
new file mode 100644
index 0000000..d1ac8ee
--- /dev/null
+++ b/RLNetDemo/RLNetDemo.csproj
@@ -0,0 +1,15 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+
+
+
+
+
+  <PropertyGroup>
+    <OutputType>Exe</OutputType>
+    <TargetFramework>net8.0</TargetFramework>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>enable</Nullable>
+  </PropertyGroup>
+
+</Project>
diff --git a/Schafkopf.AI.sln b/Schafkopf.AI.sln
index 9b657fa..84b2982 100644
--- a/Schafkopf.AI.sln
+++ b/Schafkopf.AI.sln
@@ -13,6 +13,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Schafkopf.Lib.Benchmarks",
 EndProject
 Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Schafkopf.Training.Tests", "Schafkopf.Training.Tests\Schafkopf.Training.Tests.csproj", "{BE812100-DF7C-45F3-B7CC-37A7847588AB}"
 EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RLNetDemo", "RLNetDemo\RLNetDemo.csproj", "{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}"
+EndProject
 Global
 	GlobalSection(SolutionConfigurationPlatforms) = preSolution
 		Debug|Any CPU = Debug|Any CPU
@@ -42,5 +44,9 @@ Global
 		{BE812100-DF7C-45F3-B7CC-37A7847588AB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
 		{BE812100-DF7C-45F3-B7CC-37A7847588AB}.Debug|Any CPU.Build.0 = Debug|Any CPU
 		{BE812100-DF7C-45F3-B7CC-37A7847588AB}.Release|Any CPU.ActiveCfg = Release|Any CPU
 		{BE812100-DF7C-45F3-B7CC-37A7847588AB}.Release|Any CPU.Build.0 = Release|Any CPU
+		{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+		{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}.Debug|Any CPU.Build.0 = Debug|Any CPU
+		{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}.Release|Any CPU.ActiveCfg = Release|Any CPU
+		{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}.Release|Any CPU.Build.0 = Release|Any CPU
 	EndGlobalSection
 EndGlobal
diff --git a/Schafkopf.Training.Tests/PPOTrainingSessionTests.cs b/Schafkopf.Training.Tests/PPOTrainingSessionTests.cs
deleted file mode 100644
index b519e15..0000000
--- a/Schafkopf.Training.Tests/PPOTrainingSessionTests.cs
+++ /dev/null
@@ -1,247 +0,0 @@
-namespace Schafkopf.Training.Tests;
-
-public class PPOTraining_CartPole_Tests
-{
-    [Fact]
-    public void Inference_WithRandomAgent_DoesNotThrowException()
-    {
-        var rng = new Random(42);
-
-        var env = new CartPoleEnv();
-        env.Reset();
-        for (int i = 0; i < 10_000; i++)
-            env.Step(new CartPoleAction() { Direction = (CartPoleDirection)rng.Next(0, 2) });
-
-        Assert.True(true); // just ensure no exception occurs
-    }
-
-    [Fact(Skip="test requires further debugging")]
-    public void Training_WithPPOAgent_CanLearnCartPole()
-    {
-        var config = new
PPOTrainingSettings() { - NumStateDims = 4, NumActionDims = 2 - }; - var model = new PPOModel(config); - var envFactory = () => new CartPoleEnv(); - - var encodeState = (CartPoleState s0, Matrix2D buf) => { - var cache = buf.SliceRowsRaw(0, 1); - cache[0] = s0.x; - cache[1] = s0.x_dot; - cache[2] = s0.theta; - cache[3] = s0.theta_dot; - }; - var encodeAction = (CartPoleAction a0, Matrix2D buf) => { - buf.SliceRowsRaw(0, 1)[0] = (int)a0.Direction; - }; - - var actionsCache = Enumerable.Range(0, config.NumEnvs) - .Select(x => new CartPoleAction()).ToArray(); - var sampleActions = (Matrix2D pi) => { - for (int i = 0; i < pi.NumRows; i++) - actionsCache[i].Direction = (CartPoleDirection)(int)pi.At(i, 0); - return (IList)actionsCache; - }; - - var rollout = new PPORolloutBuffer( - config, encodeState, encodeAction - ); - var exps = new SingleAgentExpCollector( - config, encodeState, sampleActions, envFactory - ); - - for (int ep = 0; ep < config.NumTrainings; ep++) - { - exps.Collect(rollout, model); - model.Train(rollout); - } - - // TODO: make assertion, e.g. avg_steps < x - } -} - -public class SingleAgentExpCollector - where TState : IEquatable, new() - where TAction : IEquatable, new() -{ - public SingleAgentExpCollector( - PPOTrainingSettings config, - Action encodeState, - Func> sampleActions, - Func> envFactory) - { - this.config = config; - this.encodeState = encodeState; - this.sampleActions = sampleActions; - - var envs = Enumerable.Range(0, config.NumEnvs) - .Select(i => envFactory()).ToList(); - vecEnv = new VectorizedEnv(envs); - exps = Enumerable.Range(0, config.NumEnvs) - .Select(i => new PPOExp()).ToArray(); - s0 = vecEnv.Reset().ToArray(); - - s0_enc = Matrix2D.Zeros(config.NumEnvs, config.NumStateDims); - v = Matrix2D.Zeros(config.NumEnvs, 1); - pi = Matrix2D.Zeros(config.NumEnvs, 1); - piProbs = Matrix2D.Zeros(config.NumEnvs, 1); - } - - private readonly PPOTrainingSettings config; - private readonly VectorizedEnv vecEnv; - private readonly Action encodeState; - private readonly Func> sampleActions; - - private TState[] s0; - private PPOExp[] exps; - private Matrix2D s0_enc; - private Matrix2D v; - private Matrix2D pi; - private Matrix2D piProbs; - - public void Collect(PPORolloutBuffer buffer, PPOModel model) - { - for (int t = 0; t < buffer.Steps; t++) - { - for (int i = 0; i < config.NumEnvs; i++) - encodeState(s0[i], s0_enc.SliceRows(i, 1)); - - model.Predict(s0_enc, pi, piProbs, v); - var a0 = sampleActions(pi); - - (var s1, var r1, var t1) = vecEnv.Step(a0); - - for (int i = 0; i < config.NumEnvs; i++) - { - exps[i].StateBefore = s0[i]; - exps[i].Action = a0[i]; - exps[i].Reward = r1[i]; - exps[i].IsTerminal = t1[i]; - exps[i].OldProb = piProbs.At(i, 0); - exps[i].OldBaseline = v.At(i, 0); - } - - for (int i = 0; i < config.NumEnvs; i++) - s0[i] = s1[i]; - - buffer.AppendStep(exps, t); - } - } -} - -public record struct CartPoleState( - double x, - double x_dot, - double theta, - double theta_dot -); - -public enum CartPoleDirection { Right = 1, Left = 0 } - -public record struct CartPoleAction( - CartPoleDirection Direction -); - -public class VectorizedEnv -{ - private readonly IList> envs; - private IList states; - private IList rewards; - private IList terminals; - - public VectorizedEnv(IList> envs) - { - this.envs = envs; - states = new TState[envs.Count]; - rewards = new double[envs.Count]; - terminals = new bool[envs.Count]; - } - - public IList Reset() - { - for (int i = 0; i < envs.Count; i++) - states[i] = envs[i].Reset(); - return states; - } - - public 
(IList, IList, IList) Step(IList actions) - { - for (int i = 0; i < envs.Count; i++) - { - (var s1, var r1, var t1) = envs[i].Step(actions[i]); - s1 = t1 ? envs[i].Reset() : s1; - states[i] = s1; - rewards[i] = r1; - terminals[i] = t1; - } - - return (states, rewards, terminals); - } -} - -public class CartPoleEnv : MDPEnv -{ - private const double gravity = 9.8; - private const double masscart = 1.0; - private const double masspole = 0.1; - private const double total_mass = masspole + masscart; - private const double length = 0.5; - private const double polemass_length = masspole * length; - private const double force_mag = 10.0; - private const double tau = 0.02; - private const double theta_threshold_radians = 12 * 2 * Math.PI / 360; - private const double x_threshold = 2.4; - - private CartPoleState? state = null; - private Random rng = new Random(0); - - public (CartPoleState, double, bool) Step(CartPoleAction action) - { - if (state == null) - throw new InvalidOperationException("Environment needs to be initialized with Reset()"); - - (var x, var x_dot, var theta, var theta_dot) = state.Value; - - var force = action.Direction == CartPoleDirection.Right ? force_mag : -force_mag; - var costheta = Math.Cos(theta); - var sintheta = Math.Sin(theta); - - var temp = (force + polemass_length * Math.Pow(theta_dot, 2) * sintheta) / total_mass; - var thetaacc = (gravity * sintheta - costheta * temp) / ( - length * (4.0 / 3.0 - masspole * Math.Pow(costheta, 2) / total_mass) - ); - var xacc = temp - polemass_length * thetaacc * costheta / total_mass; - - // Euler interpolation - state = new CartPoleState( - x + tau * x_dot, - x_dot + tau * xacc, - theta + tau * theta_dot, - theta_dot + tau * thetaacc - ); - - // TODO: check if this condition is correct - var terminated = - x > -x_threshold - && x < x_threshold - && theta > -theta_threshold_radians - && theta < theta_threshold_radians; - - var reward = 1.0; - return (state.Value, reward, terminated); - } - - public CartPoleState Reset() - { - state = new CartPoleState( - sample(x_threshold * -2, x_threshold * 2), - sample(-10.0, 10.0), - sample(theta_threshold_radians * -2, theta_threshold_radians * 2), - sample(-Math.PI, Math.PI) - ); - return state.Value; - } - - private double sample(double low, double high) - => rng.NextDouble() * (high - low) + low; -} diff --git a/Schafkopf.Training/Algos/PPOAgent.cs b/Schafkopf.Training/Algos/PPOAgent.cs index cb9aed3..593b9a3 100644 --- a/Schafkopf.Training/Algos/PPOAgent.cs +++ b/Schafkopf.Training/Algos/PPOAgent.cs @@ -25,6 +25,150 @@ public class PPOTrainingSettings public int NumTrainings => TrainSteps / StepsPerUpdate; } +public class VecorizedPPOAgent +{ + public VecorizedPPOAgent( + Action encodeState, + Func> sampleActions, + PPOTrainingSettings config) + { + this.encodeState = encodeState; + this.sampleActions = sampleActions; + + s0 = Matrix2D.Zeros(config.NumEnvs, config.NumStateDims); + v = Matrix2D.Zeros(config.NumEnvs, 1); + pi = Matrix2D.Zeros(config.NumEnvs, 1); + piProbs = Matrix2D.Zeros(config.NumEnvs, 1); + } + + private readonly Action encodeState; + private readonly Func> sampleActions; + + private readonly Matrix2D s0; + private readonly Matrix2D pi; + private readonly Matrix2D piProbs; + private readonly Matrix2D v; + + public IList PickActions(PPOModel model, IList states) + { + for (int i = 0; i < states.Count; i++) + encodeState(states[i], s0.SliceRows(i, 1)); + + model.Predict(s0, pi, piProbs, v); + return sampleActions(pi); + } + + public (IList, Matrix2D, Matrix2D) 
PickActionsWithMeta(
+        PPOModel model, IList<TState> states)
+    {
+        for (int i = 0; i < states.Count; i++)
+            encodeState(states[i], s0.SliceRows(i, 1));
+
+        model.Predict(s0, pi, piProbs, v);
+        return (sampleActions(pi), piProbs, v);
+    }
+}
+
+public interface IPPOAdapter<TState, TAction>
+{
+    void EncodeState(TState s0, Matrix2D buf);
+    void EncodeAction(TAction a0, Matrix2D buf);
+    IList<TAction> SampleActions(Matrix2D pi);
+}
+
+public class SingleAgentExpCollector<TState, TAction>
+    where TState : IEquatable<TState>, new()
+    where TAction : IEquatable<TAction>, new()
+{
+    public SingleAgentExpCollector(
+        PPOTrainingSettings config,
+        IPPOAdapter<TState, TAction> adapter,
+        Func<MDPEnv<TState, TAction>> envFactory)
+    {
+        this.config = config;
+
+        var envs = Enumerable.Range(0, config.NumEnvs)
+            .Select(i => envFactory()).ToList();
+        vecEnv = new VectorizedEnv<TState, TAction>(envs);
+        exps = Enumerable.Range(0, config.NumEnvs)
+            .Select(i => new PPOExp<TState, TAction>()).ToArray();
+        s0 = vecEnv.Reset().ToArray();
+        agent = new VecorizedPPOAgent<TState, TAction>(
+            adapter.EncodeState, adapter.SampleActions, config
+        );
+    }
+
+    private readonly PPOTrainingSettings config;
+    private readonly VecorizedPPOAgent<TState, TAction> agent;
+    private readonly VectorizedEnv<TState, TAction> vecEnv;
+
+    private TState[] s0;
+    private PPOExp<TState, TAction>[] exps;
+
+    public void Collect(PPORolloutBuffer<TState, TAction> buffer, PPOModel model)
+    {
+        for (int t = 0; t < buffer.Steps; t++)
+        {
+            (var a0, var piProbs, var v) = agent.PickActionsWithMeta(model, s0);
+            (var s1, var r1, var t1) = vecEnv.Step(a0);
+
+            for (int i = 0; i < config.NumEnvs; i++)
+            {
+                exps[i].StateBefore = s0[i];
+                exps[i].Action = a0[i];
+                exps[i].Reward = r1[i];
+                exps[i].IsTerminal = t1[i];
+                exps[i].OldProb = piProbs.At(i, 0);
+                exps[i].OldBaseline = v.At(i, 0);
+            }
+
+            for (int i = 0; i < config.NumEnvs; i++)
+                s0[i] = s1[i];
+
+            buffer.AppendStep(exps, t);
+        }
+    }
+}
+
+public class VectorizedEnv<TState, TAction>
+{
+    public VectorizedEnv(IList<MDPEnv<TState, TAction>> envs)
+    {
+        this.envs = envs;
+        states = new TState[envs.Count];
+        rewards = new double[envs.Count];
+        terminals = new bool[envs.Count];
+    }
+
+    private readonly IList<MDPEnv<TState, TAction>> envs;
+    private IList<TState> states;
+    private IList<double> rewards;
+    private IList<bool> terminals;
+
+    public int NumEnvs => envs.Count;
+
+    public IList<TState> Reset()
+    {
+        for (int i = 0; i < envs.Count; i++)
+            states[i] = envs[i].Reset();
+        return states;
+    }
+
+    public (IList<TState>, IList<double>, IList<bool>) Step(IList<TAction> actions)
+    {
+        for (int i = 0; i < envs.Count; i++)
+        {
+            (var s1, var r1, var t1) = envs[i].Step(actions[i]);
+            s1 = t1 ? envs[i].Reset() : s1;
+            states[i] = s1;
+            rewards[i] = r1;
+            terminals[i] = t1;
+        }
+
+        return (states, rewards, terminals);
+    }
+}
+
 public class PPOModel
 {
     public PPOModel(PPOTrainingSettings config)
@@ -246,17 +390,14 @@ public PPOTrainBatch SliceRows(int rowid, int length)
 }
 
 public class PPORolloutBuffer<TState, TAction>
-    where TState : IEquatable<TState>, new()
-    where TAction : IEquatable<TAction>, new()
+        where TState : IEquatable<TState>, new()
+        where TAction : IEquatable<TAction>, new()
 {
     public PPORolloutBuffer(
         PPOTrainingSettings config,
-        Action<TState, Matrix2D> encodeState,
-        Action<TAction, Matrix2D> encodeAction)
+        IPPOAdapter<TState, TAction> adapter)
     {
-        this.encodeState = encodeState;
-        this.encodeAction = encodeAction;
-
+        this.adapter = adapter;
         NumEnvs = config.NumEnvs;
         Steps = config.StepsPerUpdate;
         gamma = config.RewardDiscount;
@@ -278,8 +419,7 @@ public PPORolloutBuffer(
         permCache = Perm.Identity(size);
     }
 
-    private Action<TState, Matrix2D> encodeState;
-    private Action<TAction, Matrix2D> encodeAction;
+    private IPPOAdapter<TState, TAction> adapter;
     public int NumEnvs;
     public int Steps;
     private double gamma;
@@ -310,9 +450,9 @@ public void AppendStep(PPOExp<TState, TAction>[] exps, int t)
         unsafe
         {
             var s0Dest = buffer.StatesBefore.SliceRows(i, 1);
-            encodeState(exp.StateBefore, s0Dest);
+            adapter.EncodeState(exp.StateBefore, s0Dest);
             var a0Dest = buffer.Actions.SliceRows(i, 1);
-            encodeAction(exp.Action, a0Dest);
+            adapter.EncodeAction(exp.Action, a0Dest);
             buffer.Rewards.Data[i] = exp.Reward;
             buffer.Terminals.Data[i] = exp.IsTerminal ? 1 : 0;
             buffer.OldProbs.Data[i] = exp.OldProb;
diff --git a/Schafkopf.Training/CardPicker/PPOAgent.cs b/Schafkopf.Training/CardPicker/PPOAgent.cs
index 2fb8e6a..ef80877 100644
--- a/Schafkopf.Training/CardPicker/PPOAgent.cs
+++ b/Schafkopf.Training/CardPicker/PPOAgent.cs
@@ -1,37 +1,71 @@
 namespace Schafkopf.Training;
 
+// public class CardPickerPPOAdapter : IPPOAdapter<GameState, Card>
+// {
+//     public CardPickerPPOAdapter(PPOTrainingSettings config)
+//         => actionsCache = Enumerable.Range(0, config.NumEnvs)
+//             .Select(i => new Card(0)).ToArray();
+
+//     private readonly Card[] actionsCache;
+
+//     public void EncodeAction(Card a0, Matrix2D buf)
+//     {
+//         throw new NotImplementedException();
+//     }
+
+//     public void EncodeState(GameState s0, Matrix2D buf)
+//     {
+//         s0.ExportFeatures(buf.SliceRowsRaw(0, 1));
+//     }
+
+//     public IList<Card> SampleActions(Matrix2D pi)
+//     {
+//         for (int i = 0; i < pi.NumRows; i++)
+//             actionsCache[i].Id = (int)pi.At(i, 0);
+//         return actionsCache;
+//     }
+// }
+
 public class SchafkopfPPOTrainingSession
 {
     public PPOModel Train(PPOTrainingSettings config)
     {
-        var model = new PPOModel(config);
-        var rollout = new PPORolloutBuffer<GameState, Card>(
-            config,
-            (s0, buf) => s0.ExportFeatures(buf.SliceRowsRaw(0, 1)),
-            (a0, buf) => buf.SliceRowsRaw(0, 1)[0] = a0.Id % 32
-        );
-        var exps = new CardPickerExpCollector();
-        var benchmark = new RandomPlayBenchmark();
-        var agent = new SchafkopfPPOAgent(model);
-
-        for (int ep = 0; ep < config.NumTrainings; ep++)
-        {
-            Console.WriteLine($"epoch {ep+1}");
-            exps.Collect(rollout, model);
-            model.Train(rollout);
-
-            model.RecompileCache(batchSize: 1);
-            double winRate = benchmark.Benchmark(agent);
-            model.RecompileCache(batchSize: config.BatchSize);
-
-            Console.WriteLine($"win rate vs.
random agents: {winRate}"); - Console.WriteLine("--------------------------------------"); - } - - return model; + // TODO: implement training procedure + throw new NotImplementedException(); } } +// public class SchafkopfPPOTrainingSession +// { +// public PPOModel Train(PPOTrainingSettings config) +// { +// var model = new PPOModel(config); +// var rollout = new PPORolloutBuffer( +// config, + +// ); +// var exps = new CardPickerExpCollector(); +// var benchmark = new RandomPlayBenchmark(); +// var agent = new SchafkopfPPOAgent(model); + +// for (int ep = 0; ep < config.NumTrainings; ep++) +// { +// Console.WriteLine($"epoch {ep+1}"); +// exps.Collect(rollout, model); +// model.Train(rollout); + +// model.RecompileCache(batchSize: 1); +// double winRate = benchmark.Benchmark(agent); +// model.RecompileCache(batchSize: config.BatchSize); + +// Console.WriteLine($"win rate vs. random agents: {winRate}"); +// Console.WriteLine("--------------------------------------"); +// } + +// return model; +// } +// } + public class SchafkopfPPOAgent : ISchafkopfAIAgent { public SchafkopfPPOAgent(PPOModel model) diff --git a/Schafkopf.Training/Common/MDP.cs b/Schafkopf.Training/Common/MDP.cs index 891f275..7355e9d 100644 --- a/Schafkopf.Training/Common/MDP.cs +++ b/Schafkopf.Training/Common/MDP.cs @@ -2,6 +2,7 @@ namespace Schafkopf.Training; public interface MDPEnv { + void Seed(int seed); StateT Reset(); (StateT, double, bool) Step(ActionT cardToPlay); } @@ -16,6 +17,11 @@ public class CardPickerEnv : MDPEnv private GameLog log; private Hand[] initialHandsCache = new Hand[4]; + public void Seed(int seed) + { + // nothing to do here ... environment is deterministic + } + public GameLog Reset() { kommtRaus = (kommtRaus + 1) % 4; @@ -101,6 +107,11 @@ public void Register(int playerId) threadIds[playerId] = Environment.CurrentManagedThreadId; } + public void Seed(int seed) + { + // nothing to do here ... environment is deterministic + } + public GameLog Reset() { int playerId = playerIdByThread();