Skip to content

Commit

Permalink
add cart pole training demo (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Dec 13, 2024
1 parent 069571c commit 1954a78
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 283 deletions.
203 changes: 203 additions & 0 deletions RLNetDemo/CartPole.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
namespace RLNetDemo;

using BackpropNet;
using Schafkopf.Training;

public record struct CartPoleState(
double x,
double x_dot,
double theta,
double theta_dot
);

public enum CartPoleDirection
{
Right = 1,
Left = 0
}

public record struct CartPoleAction(
CartPoleDirection Direction
);

public class CartPoleEnv : MDPEnv<CartPoleState, CartPoleAction>
{
// OpenAI reference implementation:
// https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py

private const double gravity = 9.8;
private const double masscart = 1.0;
private const double masspole = 0.1;
private const double total_mass = masspole + masscart;
private const double length = 0.5;
private const double polemass_length = masspole * length;
private const double force_mag = 10.0;
private const double tau = 0.02;
private const double theta_threshold_radians = 12.0 * 2.0 * Math.PI / 360.0;
private const double x_threshold = 2.4;

private CartPoleState high = new CartPoleState(
x: x_threshold * 2.0,
x_dot: float.MaxValue,
theta: theta_threshold_radians * 2.0,
theta_dot: float.MaxValue
);

private CartPoleState low = new CartPoleState(
x: x_threshold * -2.0,
x_dot: float.MinValue,
theta: theta_threshold_radians * -2.0,
theta_dot: float.MinValue
);

private CartPoleState? state = null;
private Random rng = new Random(0);

public void Seed(int seed)
=> rng = new Random(seed);

public (CartPoleState, double, bool) Step(CartPoleAction action)
{
if (state == null)
throw new InvalidOperationException("Environment needs to be initialized with Reset()");

(var x, var x_dot, var theta, var theta_dot) = state.Value;

var force = action.Direction == CartPoleDirection.Right ? force_mag : -force_mag;
var costheta = Math.Cos(theta);
var sintheta = Math.Sin(theta);

var temp = (force + polemass_length * Math.Pow(theta_dot, 2) * sintheta) / total_mass;
var thetaacc = (gravity * sintheta - costheta * temp) / (
length * (4.0 / 3.0 - masspole * Math.Pow(costheta, 2) / total_mass)
);
var xacc = temp - polemass_length * thetaacc * costheta / total_mass;

// Euler interpolation
state = new CartPoleState(
x + tau * x_dot,
x_dot + tau * xacc,
theta + tau * theta_dot,
theta_dot + tau * thetaacc
);

var terminated =
x < -x_threshold
|| x > x_threshold
|| theta < -theta_threshold_radians
|| theta > theta_threshold_radians;

var reward = 1.0;
return (state.Value, reward, terminated);
}

public CartPoleState Reset()
{
state = new CartPoleState(
sample(low.x, high.x),
sample(low.x_dot, high.x_dot),
sample(low.theta, high.theta),
sample(low.theta_dot, high.theta_dot)
);
return state.Value;
}

private double sample(double low, double high)
=> rng.NextDouble() * (high - low) + low;
}

public class CartPolePPOAdapter : IPPOAdapter<CartPoleState, CartPoleAction>
{
public CartPolePPOAdapter(PPOTrainingSettings config)
{
actionsCache = Enumerable.Range(0, config.NumEnvs)
.Select(x => new CartPoleAction()).ToArray();
}

private CartPoleAction[] actionsCache;

public void EncodeState(CartPoleState s0, Matrix2D buf)
{
var cache = buf.SliceRowsRaw(0, 1);
cache[0] = s0.x;
cache[1] = s0.x_dot;
cache[2] = s0.theta;
cache[3] = s0.theta_dot;
}

public void EncodeAction(CartPoleAction a0, Matrix2D buf)
{
buf.SliceRowsRaw(0, 1)[0] = (int)a0.Direction;
}

public IList<CartPoleAction> SampleActions(Matrix2D pi)
{
for (int i = 0; i < pi.NumRows; i++)
actionsCache[i].Direction = (CartPoleDirection)(int)pi.At(i, 0);
return actionsCache;
}
}

public record CartPoleBenchmarkStats(
// TODO: figure out other interesting stats to benchmark
double AvgEpSteps,
double AvgEpRewards
);

public class CartPoleBenchmark
{
public CartPoleBenchmark(
PPOTrainingSettings config,
Func<CartPoleEnv> envFactory)
{
this.config = config;
this.envFactory = envFactory;
}

private readonly PPOTrainingSettings config;
private readonly Func<CartPoleEnv> envFactory;

public CartPoleBenchmarkStats Benchmark(PPOModel model, int totalEpisodes)
{
var adapter = new CartPolePPOAdapter(config);
var agent = new VecorizedPPOAgent<CartPoleState, CartPoleAction>(
adapter.EncodeState, adapter.SampleActions, config
);
var vecEnv = new VectorizedEnv<CartPoleState, CartPoleAction>(
Enumerable.Range(0, config.NumEnvs).Select(i => envFactory()).ToArray()
);

int ep = 0;
var rewardCaches = new double[vecEnv.NumEnvs];
var stepCaches = new int[vecEnv.NumEnvs];
var epSteps = new double[totalEpisodes];
var epRewards = new double[totalEpisodes];

var states = vecEnv.Reset();

while (ep < totalEpisodes)
{
var actions = agent.PickActions(model, states);
(states, var rewards, var terminals) = vecEnv.Step(actions);

for (int i = 0; i < vecEnv.NumEnvs; i++)
{
rewardCaches[i] += rewards[i];
stepCaches[i] += 1;

if (terminals[i])
{
epRewards[ep] = rewardCaches[i];
epSteps[ep] = stepCaches[i];

if (++ep == totalEpisodes)
break;
}
}
}

return new CartPoleBenchmarkStats(
epSteps.Average(), epRewards.Average()
);
}
}
41 changes: 41 additions & 0 deletions RLNetDemo/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using RLNetDemo;
using Schafkopf.Training;

public class Program
{
public static void Main(string[] args)
{
var config = new PPOTrainingSettings() {
NumStateDims = 4,
NumActionDims = 2,
TotalSteps = 2_000_000,
StepsPerUpdate = 2048
};
var model = new PPOModel(config);

var adapter = new CartPolePPOAdapter(config);
var rolloutBuffer = new PPORolloutBuffer<CartPoleState, CartPoleAction>(config, adapter);
var expCollector = new SingleAgentExpCollector<CartPoleState, CartPoleAction>(
config, adapter, () => new CartPoleEnv()
);

var benchmark = new CartPoleBenchmark(config, () => new CartPoleEnv());

Console.WriteLine("benchmark untrained model");
var res = benchmark.Benchmark(model, 1_000);
Console.WriteLine($"avg. rewards: {res.AvgEpRewards}, avg. steps: {res.AvgEpSteps}");

for (int ep = 0; ep < config.NumTrainings; ep++)
{
Console.WriteLine($"starting episode {ep+1}/{config.NumTrainings}");
Console.WriteLine("collect rollout buffer");
expCollector.Collect(rolloutBuffer, model);
Console.WriteLine("train on rollout buffer");
model.Train(rolloutBuffer);
Console.WriteLine("benchmark model");
res = benchmark.Benchmark(model, 1_000);
Console.WriteLine($"avg. rewards: {res.AvgEpRewards}, avg. steps: {res.AvgEpSteps}");
Console.WriteLine("===============================");
}
}
}
15 changes: 15 additions & 0 deletions RLNetDemo/RLNetDemo.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<Project Sdk="Microsoft.NET.Sdk">

<ItemGroup>
<ProjectReference Include="..\Schafkopf.Training\Schafkopf.Training.csproj" />
<ProjectReference Include="..\BackpropNet\BackpropNet\BackpropNet.csproj" />
</ItemGroup>

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>

</Project>
6 changes: 6 additions & 0 deletions Schafkopf.AI.sln
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Schafkopf.Lib.Benchmarks",
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Schafkopf.Training.Tests", "Schafkopf.Training.Tests\Schafkopf.Training.Tests.csproj", "{BE812100-DF7C-45F3-B7CC-37A7847588AB}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RLNetDemo", "RLNetDemo\RLNetDemo.csproj", "{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -42,5 +44,9 @@ Global
{BE812100-DF7C-45F3-B7CC-37A7847588AB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{BE812100-DF7C-45F3-B7CC-37A7847588AB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{BE812100-DF7C-45F3-B7CC-37A7847588AB}.Release|Any CPU.Build.0 = Release|Any CPU
{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}.Debug|Any CPU.Build.0 = Debug|Any CPU
{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3BBB9EC8-2DB2-4691-988F-5FFE2DFE8AD5}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
EndGlobal
Loading

0 comments on commit 1954a78

Please sign in to comment.