diff --git a/src/pycram/datastructures/dataclasses.py b/src/pycram/datastructures/dataclasses.py index 79d998f12..61ff1581b 100644 --- a/src/pycram/datastructures/dataclasses.py +++ b/src/pycram/datastructures/dataclasses.py @@ -534,8 +534,8 @@ class WorldState(State): """ Dataclass for storing the state of the world. """ - simulator_state_id: Optional[int] object_states: Dict[str, ObjectState] + simulator_state_id: Optional[int] = None def __eq__(self, other: 'WorldState'): return (self.simulator_state_is_equal(other) and self.all_objects_exist(other) @@ -572,8 +572,9 @@ def all_objects_states_are_equal(self, other: 'WorldState') -> bool: other.object_states.values())]) def __copy__(self): - return WorldState(simulator_state_id=self.simulator_state_id, - object_states=deepcopy(self.object_states)) + return WorldState(object_states=deepcopy(self.object_states), + simulator_state_id=self.simulator_state_id + ) @dataclass diff --git a/src/pycram/datastructures/world.py b/src/pycram/datastructures/world.py index 537a92625..ecc3738b1 100644 --- a/src/pycram/datastructures/world.py +++ b/src/pycram/datastructures/world.py @@ -85,6 +85,7 @@ def __init__(self, mode: WorldMode, is_prospection_world: bool = False, clear_ca """ StateEntity.__init__(self) + self.latest_state_id: Optional[int] = None if clear_cache or (self.conf.clear_cache_at_start and not self.cache_manager.cache_cleared): self.cache_manager.clear_cache() @@ -1000,28 +1001,52 @@ def save_state(self, state_id: Optional[int] = None, use_same_id: bool = False) :param use_same_id: Whether to use the same current state id for the new saved state. :return: A unique id of the state """ - state_id = self.save_physics_simulator_state(state_id=state_id, use_same_id=use_same_id) + + sim_state_id = self.save_physics_simulator_state(state_id=state_id, use_same_id=use_same_id) + + if state_id is None: + if self.latest_state_id is None: + self.latest_state_id = 0 + else: + self.latest_state_id += 0 if use_same_id else 1 + state_id = self.latest_state_id + self.save_objects_state(state_id) - self._current_state = WorldState(state_id, self.object_states) + + self._current_state = WorldState(self.object_states, sim_state_id) + return super().save_state(state_id) @property def current_state(self) -> WorldState: if self._current_state is None: - simulator_state = None if self.conf.use_physics_simulator_state else ( + simulator_state_id = None if not self.conf.use_physics_simulator_state else ( self.save_physics_simulator_state(use_same_id=True)) - self._current_state = WorldState(simulator_state, self.object_states) - return WorldState(self._current_state.simulator_state_id, self.object_states) + self._current_state = WorldState(self.object_states, simulator_state_id) + return WorldState(self.object_states, self._current_state.simulator_state_id) @current_state.setter def current_state(self, state: WorldState) -> None: if self.current_state != state: if self.conf.use_physics_simulator_state: self.restore_physics_simulator_state(state.simulator_state_id) + self.set_object_states_without_poses(state.object_states) else: for obj in self.objects: self.get_object_by_name(obj.name).current_state = state.object_states[obj.name] + def set_object_states_without_poses(self, states: Dict[str, ObjectState]) -> None: + """ + Set the states of all objects in the World except the poses. + + :param states: A dictionary with the object id as key and the object state as value. + """ + for obj_name, obj_state in states.items(): + obj = self.get_object_by_name(obj_name) + obj.set_attachments(obj_state.attachments) + obj.link_states = obj_state.link_states + obj.joint_states = obj_state.joint_states + @property def object_states(self) -> Dict[str, ObjectState]: """ @@ -1197,7 +1222,7 @@ def reset_world(self, remove_saved_states=False) -> None: self.restore_state(self.original_state_id) if remove_saved_states: self.remove_saved_states() - self.original_state_id = self.save_state() + self.save_state(use_same_id=True) def remove_saved_states(self) -> None: """ @@ -1205,7 +1230,8 @@ def remove_saved_states(self) -> None: """ if self.conf.use_physics_simulator_state: for state_id in self.saved_states: - self.remove_physics_simulator_state(state_id) + if state_id is not None: + self.remove_physics_simulator_state(state_id) else: self.remove_objects_saved_states() super().remove_saved_states() diff --git a/test/test_multiverse.py b/test/test_multiverse.py index 693643221..4026bbb07 100644 --- a/test/test_multiverse.py +++ b/test/test_multiverse.py @@ -53,6 +53,33 @@ def tearDownClass(cls): def tearDown(self): self.multiverse.remove_all_objects() + def test_save_and_restore_state(self): + milk = self.spawn_milk([1, 1, 0.1]) + robot = self.spawn_robot() + cup = self.spawn_cup([1, 2, 0.1]) + apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment.urdf") + apartment.set_joint_position("cabinet10_drawer1_joint", 0.1) + robot.attach(milk) + milk.attach(cup) + all_object_attachments = {obj: obj.attachments.copy() for obj in self.multiverse.objects} + state_id = self.multiverse.save_state() + milk.detach(cup) + robot_link = robot.root_link + milk_link = milk.root_link + cid = robot_link.constraint_ids[milk_link] + self.assertTrue(cid == robot.attachments[milk].id) + self.multiverse.remove_constraint(cid) + apartment.set_joint_position("cabinet10_drawer1_joint", 0.0) + self.multiverse.restore_state(state_id) + cid = robot_link.constraint_ids[milk_link] + self.assertTrue(milk_link in robot_link.constraint_ids) + self.assertTrue(cid == robot.attachments[milk].id) + for obj in self.multiverse.objects: + self.assertTrue(len(obj.attachments) == len(all_object_attachments[obj])) + for att in obj.attachments: + self.assertTrue(att in all_object_attachments[obj]) + self.assertTrue(apartment.get_joint_position("cabinet10_drawer1_joint") == 0.1) + def test_spawn_xml_object(self): bread = Object("bread_1", ObjectType.GENERIC_OBJECT, "bread_1.xml", pose=Pose([1, 1, 0.1])) self.assert_poses_are_equal(bread.get_pose(), Pose([1, 1, 0.1]))