Skip to content

Commit

Permalink
Made it easier to extend classes by improving constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
fishbotics committed Mar 2, 2024
1 parent fd06889 commit 2088ee8
Showing 1 changed file with 35 additions and 35 deletions.
70 changes: 35 additions & 35 deletions urchin/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def copy(self, prefix="", scale=None):
"""
if scale is None:
scale = 1.0
b = Box(
b = self.__class__(
size=self.size.copy() * scale,
)
return b
Expand Down Expand Up @@ -517,12 +517,12 @@ def copy(self, prefix="", scale=None):
raise ValueError(
"Cannot rescale cylinder geometry with asymmetry in x/y"
)
c = Cylinder(
c = self.__class__(
radius=self.radius * scale[0],
length=self.length * scale[2],
)
else:
c = Cylinder(
c = self.__class__(
radius=self.radius * scale,
length=self.length * scale,
)
Expand Down Expand Up @@ -588,7 +588,7 @@ def copy(self, prefix="", scale=None):
if scale[0] != scale[1] or scale[0] != scale[2]:
raise ValueError("Spheres do not support non-uniform scaling!")
scale = scale[0]
s = Sphere(
s = self.__class__(
radius=self.radius * scale,
)
return s
Expand Down Expand Up @@ -706,7 +706,7 @@ def _from_xml(cls, node, path, lazy_load_meshes):
kwargs["meshes"] = meshes
kwargs["combine"] = combine

return Mesh(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
# Get the filename
Expand Down Expand Up @@ -735,7 +735,7 @@ def copy(self, prefix="", scale=None):
Returns
-------
:class:`.Sphere`
:class:`.Mesh`
A deep copy.
"""
meshes = [m.copy() for m in self.meshes]
Expand All @@ -749,7 +749,7 @@ def copy(self, prefix="", scale=None):
meshes[i] = m.apply_transform(sm)
base, fn = os.path.split(self.filename)
fn = "{}{}".format(prefix, self.filename)
m = Mesh(
m = self.__class__(
filename=os.path.join(base, fn),
combine=self.combine,
scale=(self.scale.copy() if self.scale is not None else None),
Expand Down Expand Up @@ -872,7 +872,7 @@ def copy(self, prefix="", scale=None):
:class:`.Geometry`
A deep copy.
"""
v = Geometry(
v = self.__class__(
box=(self.box.copy(prefix=prefix, scale=scale) if self.box else None),
cylinder=(
self.cylinder.copy(prefix=prefix, scale=scale)
Expand Down Expand Up @@ -941,7 +941,7 @@ def _from_xml(cls, node, path):
fn = get_filename(path, kwargs["filename"])
kwargs["image"] = PIL.Image.open(fn)

return Texture(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
# Save the image
Expand All @@ -963,7 +963,7 @@ def copy(self, prefix="", scale=None):
:class:`.Texture`
A deep copy.
"""
v = Texture(filename=self.filename, image=self.image.copy())
v = self.__class__(filename=self.filename, image=self.image.copy())
return v


Expand Down Expand Up @@ -1041,7 +1041,7 @@ def _from_xml(cls, node, path):
color = np.fromstring(color.attrib["rgba"], sep=" ", dtype=np.float64)
kwargs["color"] = color

return Material(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
# Simplify materials by collecting them at the top level.
Expand Down Expand Up @@ -1073,7 +1073,7 @@ def copy(self, prefix="", scale=None):
:class:`.Material`
A deep copy of the material.
"""
return Material(
return self.__class__(
name="{}{}".format(prefix, self.name),
color=self.color,
texture=self.texture,
Expand Down Expand Up @@ -1140,7 +1140,7 @@ def origin(self, value):
def _from_xml(cls, node, path, lazy_load_meshes):
kwargs = cls._parse(node, path, lazy_load_meshes)
kwargs["origin"] = parse_origin(node)
return Collision(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
node = self._unparse(path)
Expand All @@ -1165,7 +1165,7 @@ def copy(self, prefix="", scale=None):
if not isinstance(scale, (list, np.ndarray)):
scale = np.repeat(scale, 3)
origin[:3, 3] *= scale
return Collision(
return self.__class__(
name="{}{}".format(prefix, self.name),
origin=origin,
geometry=self.geometry.copy(prefix=prefix, scale=scale),
Expand Down Expand Up @@ -1248,7 +1248,7 @@ def material(self, value):
def _from_xml(cls, node, path, lazy_load_meshes):
kwargs = cls._parse(node, path, lazy_load_meshes)
kwargs["origin"] = parse_origin(node)
return Visual(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
node = self._unparse(path)
Expand All @@ -1273,7 +1273,7 @@ def copy(self, prefix="", scale=None):
if not isinstance(scale, (list, np.ndarray)):
scale = np.repeat(scale, 3)
origin[:3, 3] *= scale
return Visual(
return self.__class__(
geometry=self.geometry.copy(prefix=prefix, scale=scale),
name="{}{}".format(prefix, self.name),
origin=origin,
Expand Down Expand Up @@ -1344,7 +1344,7 @@ def _from_xml(cls, node, path):
yz = float(n.attrib["iyz"])
zz = float(n.attrib["izz"])
inertia = np.array([[xx, xy, xz], [xy, yy, yz], [xz, yz, zz]], dtype=np.float64)
return Inertial(mass=mass, inertia=inertia, origin=origin)
return cls(mass=mass, inertia=inertia, origin=origin)

def _to_xml(self, parent, path):
node = ET.Element("inertial")
Expand Down Expand Up @@ -1381,7 +1381,7 @@ def copy(self, prefix="", mass=None, origin=None, inertia=None):
origin = self.origin.copy()
if inertia is None:
inertia = self.inertia.copy()
return Inertial(
return self.__class__(
mass=mass,
inertia=inertia,
origin=origin,
Expand Down Expand Up @@ -1448,7 +1448,7 @@ def copy(self, prefix="", scale=None):
:class:`.JointCalibration`
A deep copy of the visual.
"""
return JointCalibration(
return self.__class__(
rising=self.rising,
falling=self.falling,
)
Expand Down Expand Up @@ -1512,7 +1512,7 @@ def copy(self, prefix="", scale=None):
:class:`.JointDynamics`
A deep copy of the visual.
"""
return JointDynamics(
return self.__class__(
damping=self.damping,
friction=self.friction,
)
Expand Down Expand Up @@ -1601,7 +1601,7 @@ def copy(self, prefix="", scale=None):
:class:`.JointLimit`
A deep copy of the visual.
"""
return JointLimit(
return self.__class__(
effort=self.effort,
velocity=self.velocity,
lower=self.lower,
Expand Down Expand Up @@ -1686,7 +1686,7 @@ def copy(self, prefix="", scale=None):
:class:`.JointMimic`
A deep copy of the joint mimic.
"""
return JointMimic(
return self.__class__(
joint="{}{}".format(prefix, self.joint),
multiplier=self.multiplier,
offset=self.offset,
Expand Down Expand Up @@ -1789,7 +1789,7 @@ def copy(self, prefix="", scale=None):
:class:`.SafetyController`
A deep copy of the visual.
"""
return SafetyController(
return self.__class__(
k_velocity=self.k_velocity,
k_position=self.k_position,
soft_lower_limit=self.soft_lower_limit,
Expand Down Expand Up @@ -1872,7 +1872,7 @@ def _from_xml(cls, node, path):
if len(hi) > 0:
hi = [h.text for h in hi]
kwargs["hardwareInterfaces"] = hi
return Actuator(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
node = self._unparse(path)
Expand Down Expand Up @@ -1900,7 +1900,7 @@ def copy(self, prefix="", scale=None):
:class:`.Actuator`
A deep copy of the visual.
"""
return Actuator(
return self.__class__(
name="{}{}".format(prefix, self.name),
mechanicalReduction=self.mechanicalReduction,
hardwareInterfaces=self.hardwareInterfaces.copy(),
Expand Down Expand Up @@ -1958,7 +1958,7 @@ def _from_xml(cls, node, path):
if len(hi) > 0:
hi = [h.text for h in hi]
kwargs["hardwareInterfaces"] = hi
return TransmissionJoint(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
node = self._unparse(path)
Expand All @@ -1982,7 +1982,7 @@ def copy(self, prefix="", scale=None):
:class:`.TransmissionJoint`
A deep copy.
"""
return TransmissionJoint(
return self.__class__(
name="{}{}".format(prefix, self.name),
hardwareInterfaces=self.hardwareInterfaces.copy(),
)
Expand Down Expand Up @@ -2083,7 +2083,7 @@ def _from_xml(cls, node, path):
if trans_type is None:
trans_type = node.find("type").text
kwargs["trans_type"] = trans_type
return Transmission(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
node = self._unparse(path)
Expand All @@ -2105,7 +2105,7 @@ def copy(self, prefix="", scale=None):
:class:`.Transmission`
A deep copy.
"""
return Transmission(
return self.__class__(
name="{}{}".format(prefix, self.name),
trans_type=self.trans_type,
joints=[j.copy(prefix) for j in self.joints],
Expand Down Expand Up @@ -2476,7 +2476,7 @@ def _from_xml(cls, node, path):
axis = np.fromstring(axis.attrib["xyz"], sep=" ")
kwargs["axis"] = axis
kwargs["origin"] = parse_origin(node)
return Joint(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
node = self._unparse(path)
Expand Down Expand Up @@ -2553,7 +2553,7 @@ def copy(self, prefix="", scale=None):
if not isinstance(scale, (list, np.ndarray)):
scale = np.repeat(scale, 3)
origin[:3, 3] *= scale
cpy = Joint(
cpy = self.__class__(
name="{}{}".format(prefix, self.name),
joint_type=self.joint_type,
parent="{}{}".format(prefix, self.parent),
Expand Down Expand Up @@ -2720,7 +2720,7 @@ def copy(self, prefix="", scale=None, collision_only=False):
if not collision_only:
visuals = [v.copy(prefix=prefix, scale=scale) for v in self.visuals]

cpy = Link(
cpy = self.__class__(
name="{}{}".format(prefix, self.name),
inertial=inertial,
visuals=visuals,
Expand Down Expand Up @@ -3728,7 +3728,7 @@ def copy(self, name=None, prefix="", scale=None, collision_only=False):
copy : :class:`.URDF`
The copied URDF.
"""
return URDF(
return self.__class__(
name=(name if name else self.name),
links=[v.copy(prefix, scale, collision_only) for v in self.links],
joints=[v.copy(prefix, scale) for v in self.joints],
Expand Down Expand Up @@ -3818,7 +3818,7 @@ def join(self, other, link, origin=None, name=None, prefix=""):
)
)

return URDF(
return self.__class__(
name=name,
links=links,
joints=joints,
Expand Down Expand Up @@ -4091,7 +4091,7 @@ def _from_xml(cls, node, path, lazy_load_meshes):

data = ET.tostring(extra_xml_node)
kwargs["other_xml"] = data
return URDF(**kwargs)
return cls(**kwargs)

def _to_xml(self, parent, path):
node = self._unparse(path)
Expand Down

0 comments on commit 2088ee8

Please sign in to comment.