Skip to content

Commit

Permalink
[SaveState] save and restore joint states manually when using physics…
Browse files Browse the repository at this point in the history
… simulator restore state as the simulator does not take into account the joint values.
  • Loading branch information
AbdelrhmanBassiouny committed Oct 11, 2024
1 parent 1b70ee4 commit f90dcc9
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 26 deletions.
16 changes: 6 additions & 10 deletions src/pycram/datastructures/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,17 +404,12 @@ class WorldState(State):
"""
Dataclass for storing the state of the world.
"""
object_states: Dict[str, ObjectState]
simulator_state_id: Optional[int] = None
object_states: Optional[Dict[str, ObjectState]] = None

def __eq__(self, other: 'WorldState'):
if self.object_states is None and other.object_states is None:
return self.simulator_state_is_equal(other)
elif self.object_states is None or other.object_states is None:
return False
else:
return (self.simulator_state_is_equal(other) and self.all_objects_exist(other)
and self.all_objects_states_are_equal(other))
return (self.simulator_state_is_equal(other) and self.all_objects_exist(other)
and self.all_objects_states_are_equal(other))

def simulator_state_is_equal(self, other: 'WorldState') -> bool:
"""
Expand Down Expand Up @@ -447,8 +442,9 @@ def all_objects_states_are_equal(self, other: 'WorldState') -> bool:
other.object_states.values())])

def __copy__(self):
return WorldState(simulator_state_id=self.simulator_state_id,
object_states=deepcopy(self.object_states))
return WorldState(object_states=deepcopy(self.object_states),
simulator_state_id=self.simulator_state_id
)


@dataclass
Expand Down
40 changes: 24 additions & 16 deletions src/pycram/datastructures/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,21 +962,16 @@ def save_state(self, state_id: Optional[int] = None, use_same_id: bool = False)

sim_state_id = self.save_physics_simulator_state(state_id=state_id, use_same_id=use_same_id)

if self.conf.use_physics_simulator_state:
object_states = None
else:
if state_id is None:
if self.latest_state_id is None:
self.latest_state_id = 0
else:
self.latest_state_id += 0 if use_same_id else 1
state_id = self.latest_state_id

self.save_objects_state(state_id)
if state_id is None:
if self.latest_state_id is None:
self.latest_state_id = 0
else:
self.latest_state_id += 0 if use_same_id else 1
state_id = self.latest_state_id

object_states = self.object_states
self.save_objects_state(state_id)

self._current_state = WorldState(sim_state_id, object_states)
self._current_state = WorldState(self.object_states, sim_state_id)

return super().save_state(state_id)

Expand All @@ -985,18 +980,31 @@ def current_state(self) -> WorldState:
if self._current_state is None:
simulator_state_id = None if not self.conf.use_physics_simulator_state else (
self.save_physics_simulator_state(use_same_id=True))
self._current_state = WorldState(simulator_state_id, self.object_states)
return WorldState(self._current_state.simulator_state_id, self.object_states)
self._current_state = WorldState(self.object_states, simulator_state_id)
return WorldState(self.object_states, self._current_state.simulator_state_id)

@current_state.setter
def current_state(self, state: WorldState) -> None:
if self.current_state != state:
if self.conf.use_physics_simulator_state:
self.restore_physics_simulator_state(state.simulator_state_id)
self.set_object_states_without_poses(state.object_states)
else:
for obj in self.objects:
self.get_object_by_name(obj.name).current_state = state.object_states[obj.name]

def set_object_states_without_poses(self, states: Dict[str, ObjectState]) -> None:
"""
Set the states of all objects in the World except the poses.
:param states: A dictionary with the object id as key and the object state as value.
"""
for obj_name, obj_state in states.items():
obj = self.get_object_by_name(obj_name)
obj.set_attachments(obj_state.attachments)
obj.link_states = obj_state.link_states
obj.joint_states = obj_state.joint_states

@property
def object_states(self) -> Dict[str, ObjectState]:
"""
Expand Down Expand Up @@ -1172,7 +1180,7 @@ def reset_world(self, remove_saved_states=False) -> None:
self.restore_state(self.original_state_id)
if remove_saved_states:
self.remove_saved_states()
self.original_state_id = self.save_state()
self.save_state(use_same_id=True)

def remove_saved_states(self) -> None:
"""
Expand Down
27 changes: 27 additions & 0 deletions test/test_multiverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,33 @@ def tearDownClass(cls):
def tearDown(self):
self.multiverse.remove_all_objects()

def test_save_and_restore_state(self):
milk = self.spawn_milk([1, 1, 0.1])
robot = self.spawn_robot()
cup = self.spawn_cup([1, 2, 0.1])
apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment.urdf")
apartment.set_joint_position("cabinet10_drawer1_joint", 0.1)
robot.attach(milk)
milk.attach(cup)
all_object_attachments = {obj: obj.attachments.copy() for obj in self.multiverse.objects}
state_id = self.multiverse.save_state()
milk.detach(cup)
robot_link = robot.root_link
milk_link = milk.root_link
cid = robot_link.constraint_ids[milk_link]
self.assertTrue(cid == robot.attachments[milk].id)
self.multiverse.remove_constraint(cid)
apartment.set_joint_position("cabinet10_drawer1_joint", 0.0)
self.multiverse.restore_state(state_id)
cid = robot_link.constraint_ids[milk_link]
self.assertTrue(milk_link in robot_link.constraint_ids)
self.assertTrue(cid == robot.attachments[milk].id)
for obj in self.multiverse.objects:
self.assertTrue(len(obj.attachments) == len(all_object_attachments[obj]))
for att in obj.attachments:
self.assertTrue(att in all_object_attachments[obj])
self.assertTrue(apartment.get_joint_position("cabinet10_drawer1_joint") == 0.1)

def test_spawn_xml_object(self):
bread = Object("bread_1", ObjectType.GENERIC_OBJECT, "bread_1.xml", pose=Pose([1, 1, 0.1]))
self.assert_poses_are_equal(bread.get_pose(), Pose([1, 1, 0.1]))
Expand Down

0 comments on commit f90dcc9

Please sign in to comment.