Skip to content

Commit

Permalink
Adds support for Reset in test fixture
Browse files Browse the repository at this point in the history
This PR adds support for the Reset API to the test fixture. As `TestFixture`
is one of the main ways one can get access to the ECM in python
when trying to write some scripts for Deep Reinforcement Learning I
realized that without `Reset` supported in the `TestFixture` API, end
users would have a very hard time using our python APIs (which are
actually quite nice). For reference I'm hacking a demo template here:

https://github.com/arjo129/gz_deep_rl_experiments/tree/ionic

Signed-off-by: Arjo Chakravarty <[email protected]>
  • Loading branch information
arjo129 committed Oct 11, 2024
1 parent bbe2cc6 commit 2a658f9
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
6 changes: 6 additions & 0 deletions include/gz/sim/TestFixture.hh
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ class GZ_SIM_VISIBLE TestFixture
public: TestFixture &OnPostUpdate(std::function<void(
const UpdateInfo &, const EntityComponentManager &)> _cb);

/// \brief Wrapper around a system's update callback
/// \param[in] _cb Function to be called every update
/// \return Reference to self.
public: TestFixture &OnReset(std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb);

/// \brief Finalize all the functions and add fixture to server.
/// Finalize must be called before running the server, otherwise none of the
/// `On*` functions will be called.
Expand Down
11 changes: 11 additions & 0 deletions python/src/gz/sim/TestFixture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ defineSimTestFixture(pybind11::object module)
),
pybind11::return_value_policy::reference,
"Wrapper around a system's post-update callback"
)
.def(
"on_reset", WrapCallbacks(
[](TestFixture* self, std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb)
{
self->OnReset(_cb);
}
),
pybind11::return_value_policy::reference,
"Wrapper around a system's post-update callback"
);
// TODO(ahcorde): This method is not compiling for the following reason:
// The EventManager class has an unordered_map which holds a unique_ptr
Expand Down
29 changes: 28 additions & 1 deletion src/TestFixture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class HelperSystem :
public ISystemConfigure,
public ISystemPreUpdate,
public ISystemUpdate,
public ISystemPostUpdate
public ISystemPostUpdate,
public ISystemReset
{
// Documentation inherited
public: void Configure(
Expand All @@ -50,6 +51,10 @@ class HelperSystem :
public: void PostUpdate(const UpdateInfo &_info,
const EntityComponentManager &_ecm) override;

// Documentation inherited
public: void Reset(const UpdateInfo &_info,
EntityComponentManager &_ecm) override;

/// \brief Function to call every time we configure a world
public: std::function<void(const Entity &_entity,
const std::shared_ptr<const sdf::Element> &_sdf,
Expand All @@ -68,6 +73,10 @@ class HelperSystem :
/// \brief Function to call every post-update
public: std::function<void(const UpdateInfo &,
const EntityComponentManager &)> postUpdateCallback;

/// \brief Reset callback
public: std::function<void(const UpdateInfo &,
EntityComponentManager &)> resetCallback;
};

/////////////////////////////////////////////////
Expand Down Expand Up @@ -105,6 +114,14 @@ void HelperSystem::PostUpdate(const UpdateInfo &_info,
this->postUpdateCallback(_info, _ecm);
}

/////////////////////////////////////////////////
void HelperSystem::Reset(const UpdateInfo &_info,
EntityComponentManager &_ecm)
{
if (this->resetCallback)
this->resetCallback(_info, _ecm);
}

//////////////////////////////////////////////////
class gz::sim::TestFixture::Implementation
{
Expand Down Expand Up @@ -200,6 +217,16 @@ TestFixture &TestFixture::OnPostUpdate(std::function<void(
return *this;
}

//////////////////////////////////////////////////
TestFixture &TestFixture::OnReset(std::function<void(
const UpdateInfo &, EntityComponentManager &)> _cb)
{
if (nullptr != this->dataPtr->helperSystem)
this->dataPtr->helperSystem->resetCallback = std::move(_cb);
return *this;
}


//////////////////////////////////////////////////
std::shared_ptr<Server> TestFixture::Server() const
{
Expand Down

0 comments on commit 2a658f9

Please sign in to comment.