From 2a658f9aee9e3713e31abb11767a9e2fb64b2c87 Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Fri, 11 Oct 2024 13:20:52 +0800 Subject: [PATCH] Adds support for Reset in test fixture This PR adds support for the Reset API to the test fixture. As `TestFixture` is one of the main ways one can get access to the ECM in python when trying to write some scripts for Deep Reinforcement Learning I realized that without `Reset` supported in the `TestFixture` API, end users would have a very hard time using our python APIs (which are actually quite nice). For reference I'm hacking a demo template here: https://github.com/arjo129/gz_deep_rl_experiments/tree/ionic Signed-off-by: Arjo Chakravarty --- include/gz/sim/TestFixture.hh | 6 ++++++ python/src/gz/sim/TestFixture.cc | 11 +++++++++++ src/TestFixture.cc | 29 ++++++++++++++++++++++++++++- 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/include/gz/sim/TestFixture.hh b/include/gz/sim/TestFixture.hh index 24fb2298a4..6041062cad 100644 --- a/include/gz/sim/TestFixture.hh +++ b/include/gz/sim/TestFixture.hh @@ -96,6 +96,12 @@ class GZ_SIM_VISIBLE TestFixture public: TestFixture &OnPostUpdate(std::function _cb); + /// \brief Wrapper around a system's update callback + /// \param[in] _cb Function to be called every update + /// \return Reference to self. + public: TestFixture &OnReset(std::function _cb); + /// \brief Finalize all the functions and add fixture to server. /// Finalize must be called before running the server, otherwise none of the /// `On*` functions will be called. diff --git a/python/src/gz/sim/TestFixture.cc b/python/src/gz/sim/TestFixture.cc index 826558fbf4..1fa05d23c3 100644 --- a/python/src/gz/sim/TestFixture.cc +++ b/python/src/gz/sim/TestFixture.cc @@ -83,6 +83,17 @@ defineSimTestFixture(pybind11::object module) ), pybind11::return_value_policy::reference, "Wrapper around a system's post-update callback" + ) + .def( + "on_reset", WrapCallbacks( + [](TestFixture* self, std::function _cb) + { + self->OnReset(_cb); + } + ), + pybind11::return_value_policy::reference, + "Wrapper around a system's post-update callback" ); // TODO(ahcorde): This method is not compiling for the following reason: // The EventManager class has an unordered_map which holds a unique_ptr diff --git a/src/TestFixture.cc b/src/TestFixture.cc index 1d02a900ff..6c984b7996 100644 --- a/src/TestFixture.cc +++ b/src/TestFixture.cc @@ -29,7 +29,8 @@ class HelperSystem : public ISystemConfigure, public ISystemPreUpdate, public ISystemUpdate, - public ISystemPostUpdate + public ISystemPostUpdate, + public ISystemReset { // Documentation inherited public: void Configure( @@ -50,6 +51,10 @@ class HelperSystem : public: void PostUpdate(const UpdateInfo &_info, const EntityComponentManager &_ecm) override; + // Documentation inherited + public: void Reset(const UpdateInfo &_info, + EntityComponentManager &_ecm) override; + /// \brief Function to call every time we configure a world public: std::function &_sdf, @@ -68,6 +73,10 @@ class HelperSystem : /// \brief Function to call every post-update public: std::function postUpdateCallback; + + /// \brief Reset callback + public: std::function resetCallback; }; ///////////////////////////////////////////////// @@ -105,6 +114,14 @@ void HelperSystem::PostUpdate(const UpdateInfo &_info, this->postUpdateCallback(_info, _ecm); } +///////////////////////////////////////////////// +void HelperSystem::Reset(const UpdateInfo &_info, + EntityComponentManager &_ecm) +{ + if (this->resetCallback) + this->resetCallback(_info, _ecm); +} + ////////////////////////////////////////////////// class gz::sim::TestFixture::Implementation { @@ -200,6 +217,16 @@ TestFixture &TestFixture::OnPostUpdate(std::function _cb) +{ + if (nullptr != this->dataPtr->helperSystem) + this->dataPtr->helperSystem->resetCallback = std::move(_cb); + return *this; +} + + ////////////////////////////////////////////////// std::shared_ptr TestFixture::Server() const {