From da33f11442874b450474bd82d772931423f54284 Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Mon, 11 Nov 2024 17:30:38 +0800 Subject: [PATCH] Got the gui working. Time for gradient descent by grad student Signed-off-by: Arjo Chakravarty --- .../simple_cart_pole/README.md | 7 +++ .../simple_cart_pole/cart_pole_env.py | 23 ++++---- python/CMakeLists.txt | 2 + python/src/gz/sim/Gui.cc | 52 +++++++++++++++++++ python/src/gz/sim/Gui.hh | 39 ++++++++++++++ python/src/gz/sim/_gz_sim_pybind11.cc | 2 + 6 files changed, 115 insertions(+), 10 deletions(-) create mode 100644 python/src/gz/sim/Gui.cc create mode 100644 python/src/gz/sim/Gui.hh diff --git a/examples/scripts/reinforcement_learning/simple_cart_pole/README.md b/examples/scripts/reinforcement_learning/simple_cart_pole/README.md index 2e8f396809..53b2dae9e1 100644 --- a/examples/scripts/reinforcement_learning/simple_cart_pole/README.md +++ b/examples/scripts/reinforcement_learning/simple_cart_pole/README.md @@ -19,6 +19,12 @@ Lets install our dependencies ``` pip install stable-baselines3[extra] ``` +For visualization to work you will also need to: +``` +pip uninstall opencv-python +pip install opencv-python-headless +``` +This is because `opencv-python` brings in Qt5 by default. In the same terminal you should add your gazebo python install directory to the `PYTHONPATH` If you built gazebo from source in the current working directory this would be: @@ -32,6 +38,7 @@ mis-matches. export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python ``` + ## Exploring the environment You can see the environment by using `gz sim cart_pole.sdf`. 
\ No newline at end of file diff --git a/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py b/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py index f94792d314..31d336b817 100644 --- a/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py +++ b/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py @@ -4,14 +4,15 @@ import numpy as np from gz.common6 import set_verbosity -from gz.sim9 import TestFixture, World, world_entity, Model, Link +from gz.sim9 import TestFixture, World, world_entity, Model, Link, run_gui from gz.math8 import Vector3d from gz.transport14 import Node from gz.msgs11.world_control_pb2 import WorldControl from gz.msgs11.world_reset_pb2 import WorldReset from gz.msgs11.boolean_pb2 import Boolean -from stable_baselines3 import A2C +from stable_baselines3 import PPO +import time file_path = os.path.dirname(os.path.realpath(__file__)) @@ -36,9 +37,9 @@ def on_pre_update(self, info, ecm): self.chassis = Link(self.chassis_entity) self.chassis.enable_velocity_checks(ecm) if self.command == 1: - self.chassis.add_world_force(Vector3d(0, 100, 0)) + self.chassis.add_world_force(ecm, Vector3d(2000, 0, 0)) elif self.command == 0: - self.chassis.add_world_force(Vector3d(0, -100, 0)) + self.chassis.add_world_force(ecm, Vector3d(-2000, 0, 0)) def on_post_update(self, info, ecm): pole_pose = self.pole.world_pose(ecm).rot().euler().y() @@ -58,7 +59,7 @@ def on_post_update(self, info, ecm): self.state = np.array([cart_pose, cart_vel, pole_pose, pole_angular_vel], dtype=np.float32) if not self.terminated: - self.terminated = pole_pose > 0.24 or pole_pose < -0.24 or cart_pose > 4.8 or cart_pose < -4.8 + self.terminated = pole_pose > 0.48 or pole_pose < -0.48 or cart_pose > 4.8 or cart_pose < -4.8 if self.terminated: self.reward = 0.0 @@ -66,8 +67,8 @@ def on_post_update(self, info, ecm): self.reward = 1.0 def step(self, action, paused=False): - self.action = action - 
self.server.run(True, 1, paused) + self.command = action + self.server.run(True, 5, paused) obs = self.state reward = self.reward return obs, reward, self.terminated, False, {} @@ -100,12 +101,14 @@ def step(self, action): env = CustomCartPole({}) -model = A2C("MlpPolicy", env, verbose=1) -model.learn(total_timesteps=10_000) +model = PPO("MlpPolicy", env, verbose=1) +model.learn(total_timesteps=25_000) vec_env = model.get_env() obs = vec_env.reset() -for i in range(5000): +run_gui() +time.sleep(10) +for i in range(50000): action, _state = model.predict(obs, deterministic=True) obs, reward, done, info = vec_env.step(action) # Nice to have spawn a gz sim client diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 98bbe66650..d5ed701042 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -42,10 +42,12 @@ pybind11_add_module(${BINDINGS_MODULE_NAME} MODULE src/gz/sim/UpdateInfo.cc src/gz/sim/Util.cc src/gz/sim/World.cc + src/gz/sim/Gui.cc ) target_link_libraries(${BINDINGS_MODULE_NAME} PRIVATE ${PROJECT_LIBRARY_TARGET_NAME} + ${PROJECT_LIBRARY_TARGET_NAME}-gui gz-common${GZ_COMMON_VER}::gz-common${GZ_COMMON_VER} ) diff --git a/python/src/gz/sim/Gui.cc b/python/src/gz/sim/Gui.cc new file mode 100644 index 0000000000..c0a04abd3c --- /dev/null +++ b/python/src/gz/sim/Gui.cc @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +#include + +#include +#include + +#include "Server.hh" + +namespace gz +{ +namespace sim +{ +namespace python +{ +void defineGuiClient(pybind11::module &_module) +{ + + _module.def("run_gui", [](){ + auto pid = fork(); + if (pid == -1) + { + gzerr << "Failed to instantiate new process"; + return; + } + if (pid != 0) + { + return; + } + int zero = 0; + gz::sim::gui::runGui(zero, nullptr, nullptr); + }, + "Run the gui"); +} +} // namespace python +} // namespace sim +} // namespace gz diff --git a/python/src/gz/sim/Gui.hh b/python/src/gz/sim/Gui.hh new file mode 100644 index 0000000000..1856f7b18f --- /dev/null +++ b/python/src/gz/sim/Gui.hh @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +#ifndef GZ_SIM_PYTHON___HH_ +#define GZ_SIM_PYTHON__SERVER_HH_ + +#include + +namespace gz +{ +namespace sim +{ +namespace python +{ +/// Define a pybind11 binding (`run_gui`) that launches the gz::sim GUI client +/** + * \param[in] _module a pybind11 module to add the definition to + */ +void +defineGuiClient(pybind11::module &_module); +} // namespace python +} // namespace sim +} // namespace gz + +#endif // GZ_SIM_PYTHON__SERVER_HH_ \ No newline at end of file diff --git a/python/src/gz/sim/_gz_sim_pybind11.cc b/python/src/gz/sim/_gz_sim_pybind11.cc index acf9373dc5..9645cd66bf 100644 --- a/python/src/gz/sim/_gz_sim_pybind11.cc +++ b/python/src/gz/sim/_gz_sim_pybind11.cc @@ -32,6 +32,7 @@ #include "UpdateInfo.hh" #include "Util.hh" #include "World.hh" +#include "Gui.hh" PYBIND11_MODULE(BINDINGS_MODULE_NAME, m) { m.doc() = "Gazebo Sim Python Library."; @@ -50,4 +51,5 @@ PYBIND11_MODULE(BINDINGS_MODULE_NAME, m) { gz::sim::python::defineSimUpdateInfo(m); gz::sim::python::defineSimWorld(m); gz::sim::python::defineSimUtil(m); + gz::sim::python::defineGuiClient(m); }