From f4ecec82ad397e33e6964cb126706e60bc311ea3 Mon Sep 17 00:00:00 2001 From: Yuval Tassa Date: Sat, 29 Jun 2024 03:06:07 -0700 Subject: [PATCH] Don't normalize `mjData->qpos` quaternions in-place. PiperOrigin-RevId: 647927542 Change-Id: If689bea15a5f076f486563e019031b5b51624e7d --- dm_control/composer/entity.py | 3 ++- dm_control/locomotion/tasks/reference_pose/tracking_test.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dm_control/composer/entity.py b/dm_control/composer/entity.py index c475a89e..0a78818f 100644 --- a/dm_control/composer/entity.py +++ b/dm_control/composer/entity.py @@ -480,7 +480,8 @@ def set_pose(self, physics, position=None, quaternion=None): if position is not None: physics.bind(root_joint).qpos[:3] = position if quaternion is not None: - physics.bind(root_joint).qpos[3:] = quaternion + normalised_quaternion = quaternion / np.linalg.norm(quaternion) + physics.bind(root_joint).qpos[3:] = normalised_quaternion else: attachment_frame = mjcf.get_attachment_frame(self.mjcf_model) if attachment_frame is None: diff --git a/dm_control/locomotion/tasks/reference_pose/tracking_test.py b/dm_control/locomotion/tasks/reference_pose/tracking_test.py index 8ba42e35..83a08828 100644 --- a/dm_control/locomotion/tasks/reference_pose/tracking_test.py +++ b/dm_control/locomotion/tasks/reference_pose/tracking_test.py @@ -215,7 +215,9 @@ def test_prop_factory(self): # Test that props go to the expected location on reset. for ref_key, obs_key in zip(REFERENCE_PROP_KEYS, PROP_OBSERVATION_KEYS): - np.testing.assert_array_equal(observation[ref_key], observation[obs_key]) + np.testing.assert_array_almost_equal( + observation[ref_key], observation[obs_key] + ) def test_ghost_prop(self): task = tracking.MultiClipMocapTracking( @@ -242,7 +244,7 @@ def test_ghost_prop(self): np.squeeze(observation[key]) for key in REFERENCE_PROP_KEYS) np.testing.assert_array_equal(np.array(ghost_pos), goal_pos + GHOST_OFFSET) - np.testing.assert_array_equal(ghost_quat, goal_quat) + np.testing.assert_array_almost_equal(ghost_quat, goal_quat) def test_disable_props(self): task = tracking.MultiClipMocapTracking(