Skip to content

Commit

Permalink
[Multiverse] create object from generic_description.
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdelrhmanBassiouny committed Oct 22, 2024
1 parent 749dc59 commit f45909d
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 11 deletions.
4 changes: 4 additions & 0 deletions src/pycram/datastructures/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,10 @@ def visual_geometry_type(self) -> Shape:
return Shape.PLANE


VisualShapeUnion = Union[BoxVisualShape, SphereVisualShape, CapsuleVisualShape, CylinderVisualShape, MeshVisualShape,
PlaneVisualShape]


@dataclass
class State(ABC):
"""
Expand Down
80 changes: 71 additions & 9 deletions src/pycram/object_descriptors/mjcf.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
import pathlib
from xml.etree import ElementTree as ET

import numpy as np
from dm_control import mjcf
from geometry_msgs.msg import Point
from typing_extensions import Union, List, Optional, Dict, Tuple
from xml.etree import ElementTree as ET

from ..datastructures.dataclasses import Color, VisualShape, BoxVisualShape, CylinderVisualShape, \
SphereVisualShape, MeshVisualShape
from ..datastructures.enums import JointType, MJCFGeomType, MJCFJointType
SphereVisualShape, MeshVisualShape, VisualShapeUnion
from ..datastructures.enums import JointType, MJCFGeomType, MJCFJointType, Shape
from ..datastructures.pose import Pose
from ..description import JointDescription as AbstractJointDescription, \
LinkDescription as AbstractLinkDescription, ObjectDescription as AbstractObjectDescription
Expand Down Expand Up @@ -74,7 +74,6 @@ def name(self) -> str:


class JointDescription(AbstractJointDescription):

mjcf_type_map = {
MJCFJointType.HINGE.value: JointType.REVOLUTE,
MJCFJointType.BALL.value: JointType.SPHERICAL,
Expand Down Expand Up @@ -176,15 +175,18 @@ class ObjectFactory(Factory):
"""
Create MJCF object descriptions from mesh files.
"""
def __init__(self, object_name: str, file_path: str, config: Configuration, texture_type: str = "png"):

def __init__(self, file_path: str, config: Configuration):
super().__init__(file_path, config)

def from_mesh_file(self, object_name: str, texture_type: str = "png"):

self._world_builder = WorldBuilder(usd_file_path=self.tmp_usd_file_path)

body_builder = self._world_builder.add_body(body_name=object_name)

tmp_usd_mesh_file_path, tmp_origin_mesh_file_path = self.import_mesh(
mesh_file_path=file_path, merge_mesh=True)
mesh_file_path=self.source_file_path, merge_mesh=True)
mesh_stage = Usd.Stage.Open(tmp_usd_mesh_file_path)
for idx, mesh_prim in enumerate([prim for prim in mesh_stage.Traverse() if prim.IsA(UsdGeom.Mesh)]):
mesh_name = mesh_prim.GetName()
Expand All @@ -200,10 +202,11 @@ def __init__(self, object_name: str, file_path: str, config: Configuration, text
geom_builder.add_mesh(mesh_name=mesh_name, mesh_property=mesh_property)

# Add texture if available
texture_file_path = file_path.replace(pathlib.Path(file_path).suffix, f".{texture_type}")
texture_file_path = self.source_file_path.replace(pathlib.Path(self.source_file_path).suffix,
f".{texture_type}")
if pathlib.Path(texture_file_path).exists():
self.add_material_with_texture(geom_builder=geom_builder, material_name=f"M_{object_name}_{idx}",
texture_file_path=texture_file_path)
texture_file_path=texture_file_path)

geom_builder.build()

Expand Down Expand Up @@ -236,6 +239,64 @@ def export_to_mjcf(self, output_file_path: str):
exporter.export(keep_usd=False)


class PrimitiveObjectFactory(ObjectFactory):

def __init__(self, object_name: str, shape_data: VisualShapeUnion, save_path: str,
orientation: Optional[List[float]] = None):
"""
Create an MJCF object description from a primitive shape.
:param object_name: The name of the object.
:param shape_data: The shape data of the object.
:param save_path: The path to save the MJCF file.
:param orientation: The orientation of the object.
"""
self.shape_data: VisualShapeUnion = shape_data
self.orientation: List[float] = [0, 0, 0, 1] if orientation is None else orientation
config = Configuration(model_name=object_name,
fixed_base=False,
with_visual=True,
with_collision=True,
default_rgba=np.array(shape_data.rgba_color.get_rgba()))
super().__init__(save_path, config)

def build_shape(self):

self._world_builder = WorldBuilder(usd_file_path=self.tmp_usd_file_path)

body_builder = self._world_builder.add_body(body_name=self.config.model_name)

geom_type_map = {
Shape.SPHERE: GeomType.SPHERE,
Shape.BOX: GeomType.CUBE,
Shape.CYLINDER: GeomType.CYLINDER,
Shape.PLANE: GeomType.PLANE,
Shape.CAPSULE: GeomType.CAPSULE,
}
geom_property = GeomProperty(geom_type=geom_type_map[self.shape_data.visual_geometry_type],
is_visible=self.config.with_visual,
is_collidable=self.config.with_collision,
rgba=self.config.default_rgba)

geom_builder = body_builder.add_geom(
geom_name=f"{self.config.model_name}_Shape",
geom_property=geom_property
)
geom_pos = self.shape_data.visual_frame_position
geom_quat = np.array(self.orientation)
if self.shape_data.visual_geometry_type == Shape.PLANE:
geom_builder.set_transform(pos=geom_pos, quat=geom_quat, scale=np.array([50, 50, 1]))
elif self.shape_data.visual_geometry_type == Shape.BOX:
geom_builder.set_transform(pos=geom_pos, quat=geom_quat, scale=np.array(self.shape_data.half_extents) * 2)
elif self.shape_data.visual_geometry_type == Shape.SPHERE:
geom_builder.set_transform(pos=geom_pos, quat=geom_quat)
geom_builder.set_attribute(radius=self.shape_data.radius)
elif self.shape_data.visual_geometry_type in [Shape.CYLINDER, Shape.CAPSULE]:
geom_builder.set_transform(pos=geom_pos, quat=geom_quat)
geom_builder.set_attribute(radius=self.shape_data.radius, height=self.shape_data.length)
geom_builder.build()


class ObjectDescription(AbstractObjectDescription):
"""
A class that represents an object description of an object.
Expand Down Expand Up @@ -373,10 +434,11 @@ def generate_from_mesh_file(self, path: str, name: str, color: Optional[Color] =
:param color: The color of the object.
:param save_path: The path to save the generated xml file.
"""
factory = ObjectFactory(object_name=name, file_path=path,
factory = ObjectFactory(file_path=path,
config=Configuration(model_name=name,
fixed_base=False,
default_rgba=np.array(color.get_rgba())))
factory.from_mesh_file(object_name=name)
factory.export_to_mjcf(output_file_path=save_path)

def generate_from_description_file(self, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> None:
Expand Down
12 changes: 11 additions & 1 deletion src/pycram/worlds/multiverse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from time import sleep

import numpy as np
Expand All @@ -15,7 +16,8 @@
from ..datastructures.pose import Pose
from ..datastructures.world import World
from ..description import Link, Joint
from ..object_descriptors.mjcf import ObjectDescription as MJCF
from ..object_descriptors.mjcf import ObjectDescription as MJCF, ObjectFactory, PrimitiveObjectFactory
from ..object_descriptors.generic import ObjectDescription as GenericObjectDescription
from ..robot_description import RobotDescription
from ..ros.logging import logwarn, logerr
from ..utils import RayTestUtils, wxyz_to_xyzw, xyzw_to_wxyz
Expand Down Expand Up @@ -154,6 +156,14 @@ def _spawn_floor(self):
self.floor = Object("floor", ObjectType.ENVIRONMENT, "plane.urdf",
world=self)

def load_generic_object_and_get_id(self, description: GenericObjectDescription,
pose: Optional[Pose] = None) -> int:
save_path = os.path.join(self.cache_manager.cache_dir, description.name + ".xml")
object_factory = PrimitiveObjectFactory(description.name, description.links[0].geometry, save_path)
object_factory.build_shape()
object_factory.export_to_mjcf(save_path)
return self.load_object_and_get_id(description.name, pose, ObjectType.GENERIC_OBJECT)

def get_images_for_target(self, target_pose: Pose,
cam_pose: Pose,
size: int = 256,
Expand Down
13 changes: 12 additions & 1 deletion test/test_multiverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from tf.transformations import quaternion_from_euler, quaternion_multiply
from typing_extensions import Optional, List

from pycram.datastructures.dataclasses import ContactPointsList, ContactPoint, AxisAlignedBoundingBox
from pycram.datastructures.dataclasses import ContactPointsList, ContactPoint, AxisAlignedBoundingBox, Color
from pycram.datastructures.enums import ObjectType, Arms, JointType
from pycram.datastructures.pose import Pose
from pycram.robot_description import RobotDescriptionManager
from pycram.world_concepts.world_object import Object
from pycram.validation.error_checkers import calculate_angle_between_quaternions
from pycram.helper import get_robot_mjcf_path, parse_mjcf_actuators
from pycram.object_descriptors.generic import ObjectDescription as GenericObjectDescription

multiverse_installed = True
try:
Expand Down Expand Up @@ -53,6 +54,16 @@ def tearDownClass(cls):
def tearDown(self):
self.multiverse.remove_all_objects()

def test_load_generic_object(self):
obj_desc = GenericObjectDescription('test_cube', [0, 0, 0], [0.1, 0.1, 0.1],
color=Color(1, 0, 0, 1))
obj = Object(obj_desc.name, ObjectType.GENERIC_OBJECT, description=obj_desc)
self.assertIsInstance(obj, Object)
self.assertTrue(obj in self.multiverse.objects)
obj.set_position([1, 1, 0.1])
pose = obj.get_pose()
self.assert_poses_are_equal(pose, Pose([1, 1, 0.1], [0, 0, 0, 1]))

def test_save_and_restore_state(self):
milk = self.spawn_milk([1, 1, 0.1])
robot = self.spawn_robot()
Expand Down

0 comments on commit f45909d

Please sign in to comment.