Skip to content

Commit

Permalink
Merge branch 'save_world_state' into dev
Browse files Browse the repository at this point in the history
# Conflicts:
#	src/pycram/datastructures/world.py
  • Loading branch information
AbdelrhmanBassiouny committed Oct 11, 2024
2 parents 79142d8 + f90dcc9 commit 97169d0
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 10 deletions.
7 changes: 4 additions & 3 deletions src/pycram/datastructures/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,8 @@ class WorldState(State):
"""
Dataclass for storing the state of the world.
"""
simulator_state_id: Optional[int]
object_states: Dict[str, ObjectState]
simulator_state_id: Optional[int] = None

def __eq__(self, other: 'WorldState'):
return (self.simulator_state_is_equal(other) and self.all_objects_exist(other)
Expand Down Expand Up @@ -572,8 +572,9 @@ def all_objects_states_are_equal(self, other: 'WorldState') -> bool:
other.object_states.values())])

def __copy__(self):
return WorldState(simulator_state_id=self.simulator_state_id,
object_states=deepcopy(self.object_states))
return WorldState(object_states=deepcopy(self.object_states),
simulator_state_id=self.simulator_state_id
)


@dataclass
Expand Down
40 changes: 33 additions & 7 deletions src/pycram/datastructures/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, mode: WorldMode, is_prospection_world: bool = False, clear_ca
"""

StateEntity.__init__(self)
self.latest_state_id: Optional[int] = None

if clear_cache or (self.conf.clear_cache_at_start and not self.cache_manager.cache_cleared):
self.cache_manager.clear_cache()
Expand Down Expand Up @@ -1000,28 +1001,52 @@ def save_state(self, state_id: Optional[int] = None, use_same_id: bool = False)
:param use_same_id: Whether to use the same current state id for the new saved state.
:return: A unique id of the state
"""
state_id = self.save_physics_simulator_state(state_id=state_id, use_same_id=use_same_id)

sim_state_id = self.save_physics_simulator_state(state_id=state_id, use_same_id=use_same_id)

if state_id is None:
if self.latest_state_id is None:
self.latest_state_id = 0
else:
self.latest_state_id += 0 if use_same_id else 1
state_id = self.latest_state_id

self.save_objects_state(state_id)
self._current_state = WorldState(state_id, self.object_states)

self._current_state = WorldState(self.object_states, sim_state_id)

return super().save_state(state_id)

@property
def current_state(self) -> WorldState:
if self._current_state is None:
simulator_state = None if self.conf.use_physics_simulator_state else (
simulator_state_id = None if not self.conf.use_physics_simulator_state else (
self.save_physics_simulator_state(use_same_id=True))
self._current_state = WorldState(simulator_state, self.object_states)
return WorldState(self._current_state.simulator_state_id, self.object_states)
self._current_state = WorldState(self.object_states, simulator_state_id)
return WorldState(self.object_states, self._current_state.simulator_state_id)

@current_state.setter
def current_state(self, state: WorldState) -> None:
if self.current_state != state:
if self.conf.use_physics_simulator_state:
self.restore_physics_simulator_state(state.simulator_state_id)
self.set_object_states_without_poses(state.object_states)
else:
for obj in self.objects:
self.get_object_by_name(obj.name).current_state = state.object_states[obj.name]

def set_object_states_without_poses(self, states: Dict[str, ObjectState]) -> None:
"""
Set the states of all objects in the World except the poses.
:param states: A dictionary with the object id as key and the object state as value.
"""
for obj_name, obj_state in states.items():
obj = self.get_object_by_name(obj_name)
obj.set_attachments(obj_state.attachments)
obj.link_states = obj_state.link_states
obj.joint_states = obj_state.joint_states

@property
def object_states(self) -> Dict[str, ObjectState]:
"""
Expand Down Expand Up @@ -1197,15 +1222,16 @@ def reset_world(self, remove_saved_states=False) -> None:
self.restore_state(self.original_state_id)
if remove_saved_states:
self.remove_saved_states()
self.original_state_id = self.save_state()
self.save_state(use_same_id=True)

def remove_saved_states(self) -> None:
"""
Remove all saved states of the World.
"""
if self.conf.use_physics_simulator_state:
for state_id in self.saved_states:
self.remove_physics_simulator_state(state_id)
if state_id is not None:
self.remove_physics_simulator_state(state_id)
else:
self.remove_objects_saved_states()
super().remove_saved_states()
Expand Down
27 changes: 27 additions & 0 deletions test/test_multiverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,33 @@ def tearDownClass(cls):
def tearDown(self):
self.multiverse.remove_all_objects()

def test_save_and_restore_state(self):
milk = self.spawn_milk([1, 1, 0.1])
robot = self.spawn_robot()
cup = self.spawn_cup([1, 2, 0.1])
apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment.urdf")
apartment.set_joint_position("cabinet10_drawer1_joint", 0.1)
robot.attach(milk)
milk.attach(cup)
all_object_attachments = {obj: obj.attachments.copy() for obj in self.multiverse.objects}
state_id = self.multiverse.save_state()
milk.detach(cup)
robot_link = robot.root_link
milk_link = milk.root_link
cid = robot_link.constraint_ids[milk_link]
self.assertTrue(cid == robot.attachments[milk].id)
self.multiverse.remove_constraint(cid)
apartment.set_joint_position("cabinet10_drawer1_joint", 0.0)
self.multiverse.restore_state(state_id)
cid = robot_link.constraint_ids[milk_link]
self.assertTrue(milk_link in robot_link.constraint_ids)
self.assertTrue(cid == robot.attachments[milk].id)
for obj in self.multiverse.objects:
self.assertTrue(len(obj.attachments) == len(all_object_attachments[obj]))
for att in obj.attachments:
self.assertTrue(att in all_object_attachments[obj])
self.assertTrue(apartment.get_joint_position("cabinet10_drawer1_joint") == 0.1)

def test_spawn_xml_object(self):
bread = Object("bread_1", ObjectType.GENERIC_OBJECT, "bread_1.xml", pose=Pose([1, 1, 0.1]))
self.assert_poses_are_equal(bread.get_pose(), Pose([1, 1, 0.1]))
Expand Down

0 comments on commit 97169d0

Please sign in to comment.