Skip to content

Commit

Permalink
Merge pull request mosdef-hub#1203 from chrisjonesBSU/remove-rigid
Browse files Browse the repository at this point in the history
Remove rigid body data structure from `Compound`
  • Loading branch information
chrisjonesBSU authored Oct 30, 2024
2 parents aabbab6 + a4c41c5 commit ba89d6c
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 945 deletions.
272 changes: 3 additions & 269 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,10 @@ class Compound(object):
compound is the root of the containment hierarchy.
referrers : set
Other compounds that reference this part with labels.
rigid_id : int, default=None
The ID of the rigid body that this Compound belongs to. Only Particles
(the bottom of the containment hierarchy) can have integer values for
`rigid_id`. Compounds containing rigid particles will always have
`rigid_id == None`. See also `contains_rigid`.
boundingbox : mb.Box
The bounds (xmin, xmax, ymin, ymax, zmin, zmax) of particles in Compound
center
contains_rigid
mass
max_rigid_id
n_particles
n_bonds
root
Expand Down Expand Up @@ -183,10 +176,6 @@ def __init__(

self.port_particle = port_particle

self._rigid_id = None
self._contains_rigid = False
self._check_if_contains_rigid_bodies = False

self.element = element
if mass and float(mass) < 0.0:
raise ValueError("Cannot set a Compound mass value less than zero")
Expand Down Expand Up @@ -583,230 +572,6 @@ def charge(self, value):
"not at the bottom of the containment hierarchy."
)

@property
def rigid_id(self):
"""Get the rigid_id of the Compound."""
return self._rigid_id

@rigid_id.setter
def rigid_id(self, value):
if self._contains_only_ports():
self._rigid_id = value
for ancestor in self.ancestors():
ancestor._check_if_contains_rigid_bodies = True
else:
raise AttributeError(
"rigid_id is immutable for Compounds that are "
"not at the bottom of the containment hierarchy."
)

@property
def contains_rigid(self):
"""Return True if the Compound contains rigid bodies.
If the Compound contains any particle with a rigid_id != None
then contains_rigid will return True. If the Compound has no
children (i.e. the Compound resides at the bottom of the containment
hierarchy) then contains_rigid will return False.
Returns
-------
bool,
True if the Compound contains any particle with a rigid_id != None
Notes
-----
The private variable '_check_if_contains_rigid_bodies' is used to help
cache the status of 'contains_rigid'.
If '_check_if_contains_rigid_bodies' is False, then the rigid body
containment of the Compound has not changed, and the particle tree is
not traversed, boosting performance.
"""
if self._check_if_contains_rigid_bodies:
self._check_if_contains_rigid_bodies = False
if any(p.rigid_id is not None for p in self._particles()):
self._contains_rigid = True
else:
self._contains_rigid = False
return self._contains_rigid

@property
def max_rigid_id(self):
"""Return the maximum rigid body ID contained in the Compound.
This is usually used by compound.root to determine the maximum
rigid_id in the containment hierarchy.
Returns
-------
int or None
The maximum rigid body ID contained in the Compound. If no
rigid body IDs are found, None is returned
"""
try:
return max(
[p.rigid_id for p in self.particles() if p.rigid_id is not None]
)
except ValueError:
return

def rigid_particles(self, rigid_id=None):
"""Generate all particles in rigid bodies.
If a rigid_id is specified, then this function will only yield particles
with a matching rigid_id.
Parameters
----------
rigid_id : int, optional
Include only particles with this rigid body ID
Yields
------
mb.Compound
The next particle with a rigid_id that is not None, or the next
particle with a matching rigid_id if specified
"""
for particle in self.particles():
if rigid_id is not None:
if particle.rigid_id == rigid_id:
yield particle
else:
if particle.rigid_id is not None:
yield particle

def label_rigid_bodies(self, discrete_bodies=None, rigid_particles=None):
"""Designate which Compounds should be treated as rigid bodies.
If no arguments are provided, this function will treat the compound
as a single rigid body by providing all particles in `self` with the
same rigid_id. If `discrete_bodies` is not None, each instance of
a Compound with a name found in `discrete_bodies` will be treated as a
unique rigid body. If `rigid_particles` is not None, only Particles
(Compounds at the bottom of the containment hierarchy) matching this
name will be considered part of the rigid body.
Parameters
----------
discrete_bodies : str or list of str, optional, default=None
Name(s) of Compound instances to be treated as unique rigid bodies.
Compound instances matching this (these) name(s) will be provided
with unique rigid_ids
rigid_particles : str or list of str, optional, default=None
Name(s) of Compound instances at the bottom of the containment
hierarchy (Particles) to be included in rigid bodies. Only Particles
matching this (these) name(s) will have their rigid_ids altered to
match the rigid body number.
Examples
--------
Creating a rigid benzene
>>> import mbuild as mb
>>> from mbuild.utils.io import get_fn
>>> benzene = mb.load(get_fn('benzene.mol2'))
>>> benzene.label_rigid_bodies()
Creating a semi-rigid benzene, where only the carbons are treated as
a rigid body
>>> import mbuild as mb
>>> from mbuild.utils.io import get_fn
>>> benzene = mb.load(get_fn('benzene.mol2'))
>>> benzene.label_rigid_bodies(rigid_particles='C')
Create a box of rigid benzenes, where each benzene has a unique rigid
body ID.
>>> import mbuild as mb
>>> from mbuild.utils.io import get_fn
>>> benzene = mb.load(get_fn('benzene.mol2'))
>>> benzene.name = 'Benzene'
>>> filled = mb.fill_box(benzene,
... n_compounds=10,
... box=[0, 0, 0, 4, 4, 4])
>>> filled.label_rigid_bodies(distinct_bodies='Benzene')
Create a box of semi-rigid benzenes, where each benzene has a unique
rigid body ID and only the carbon portion is treated as rigid.
>>> import mbuild as mb
>>> from mbuild.utils.io import get_fn
>>> benzene = mb.load(get_fn('benzene.mol2'))
>>> benzene.name = 'Benzene'
>>> filled = mb.fill_box(benzene,
... n_compounds=10,
... box=[0, 0, 0, 4, 4, 4])
>>> filled.label_rigid_bodies(distinct_bodies='Benzene',
... rigid_particles='C')
"""
if discrete_bodies is not None:
if isinstance(discrete_bodies, str):
discrete_bodies = [discrete_bodies]
if rigid_particles is not None:
if isinstance(rigid_particles, str):
rigid_particles = [rigid_particles]

if self.root.max_rigid_id is not None:
rigid_id = self.root.max_rigid_id + 1
warn(
f"{rigid_id} rigid bodies already exist. Incrementing 'rigid_id'"
f"starting from {rigid_id}."
)
else:
rigid_id = 0

for successor in self.successors():
if discrete_bodies and successor.name not in discrete_bodies:
continue
for particle in successor.particles():
if rigid_particles and particle.name not in rigid_particles:
continue
particle.rigid_id = rigid_id
if discrete_bodies:
rigid_id += 1

def unlabel_rigid_bodies(self):
"""Remove all rigid body labels from the Compound."""
self._check_if_contains_rigid_bodies = True
for child in self.children:
child._check_if_contains_rigid_bodies = True
for particle in self.particles():
particle.rigid_id = None

def _increment_rigid_ids(self, increment):
"""Increment the rigid_id of all rigid Particles in a Compound.
Adds `increment` to the rigid_id of all Particles in `self` that
already have an integer rigid_id.
"""
for particle in self.particles():
if particle.rigid_id is not None:
particle.rigid_id += increment

def _reorder_rigid_ids(self):
"""Reorder rigid body IDs ensuring consecutiveness.
Primarily used internally to ensure consecutive rigid_ids following
removal of a Compound.
"""
max_rigid = self.max_rigid_id
unique_rigid_ids = sorted(
set([p.rigid_id for p in self.rigid_particles()])
)
n_unique_rigid = len(unique_rigid_ids)
if max_rigid and n_unique_rigid != max_rigid + 1:
missing_rigid_id = (
unique_rigid_ids[-1] * (unique_rigid_ids[-1] + 1)
) / 2 - sum(unique_rigid_ids)
for successor in self.successors():
if successor.rigid_id is not None:
if successor.rigid_id > missing_rigid_id:
successor.rigid_id -= 1
if self.rigid_id:
if self.rigid_id > missing_rigid_id:
self.rigid_id -= 1

def add(
self,
new_child,
Expand All @@ -815,7 +580,6 @@ def add(
replace=False,
inherit_periodicity=None,
inherit_box=False,
reset_rigid_ids=True,
check_box_size=True,
):
"""Add a part to the Compound.
Expand All @@ -840,11 +604,6 @@ def add(
Compound being added
inherit_box: bool, optional, default=False
Replace the box of self with the box of the Compound being added
reset_rigid_ids : bool, optional, default=True
If the Compound to be added contains rigid bodies, reset the
rigid_ids such that values remain distinct from rigid_ids
already present in `self`. Can be set to False if attempting
to add Compounds to an existing rigid body.
check_box_size : bool, optional, default=True
Checks and warns if compound box is smaller than its bounding box after adding new_child.
"""
Expand Down Expand Up @@ -894,15 +653,10 @@ def add(
self.add(
child,
label=label_list[i],
reset_rigid_ids=reset_rigid_ids,
check_box_size=False,
)
else:
self.add(
child,
reset_rigid_ids=reset_rigid_ids,
check_box_size=False,
)
self.add(child, check_box_size=False)

return

Expand All @@ -919,13 +673,6 @@ def add(
)
self._mass = 0

if new_child.contains_rigid or new_child.rigid_id is not None:
if self.contains_rigid and reset_rigid_ids:
new_child._increment_rigid_ids(increment=self.max_rigid_id + 1)
self._check_if_contains_rigid_bodies = True
if self.rigid_id is not None:
self.rigid_id = None

# Create children and labels on the first add operation
if self.children is None:
self.children = list()
Expand Down Expand Up @@ -1067,7 +814,7 @@ def _check_if_empty(child):
for particle in particles_to_remove:
_check_if_empty(particle)

# Fix rigid_ids and remove obj from bondgraph
# Remove obj from bondgraph
for removed_part in to_remove:
self._remove(removed_part)

Expand All @@ -1077,11 +824,6 @@ def _check_if_empty(child):
removed_part.parent.children.remove(removed_part)
self._remove_references(removed_part)

# Check and reorder rigid id
for _ in particles_to_remove:
if self.contains_rigid:
self.root._reorder_rigid_ids()

# Remove ghost ports
self._prune_ghost_ports()

Expand Down Expand Up @@ -1148,10 +890,7 @@ def _prune_ghost_ports(self):
self._remove_references(port)

def _remove(self, removed_part):
"""Worker for remove(). Fixes rigid IDs and removes bonds."""
if removed_part.rigid_id is not None:
for ancestor in removed_part.ancestors():
ancestor._check_if_contains_rigid_bodies = True
"""Worker for remove(). Removes bonds."""
if self.root.bond_graph.has_node(removed_part):
for neighbor in nx.neighbors(
self.root.bond_graph.copy(), removed_part
Expand Down Expand Up @@ -3600,12 +3339,7 @@ def _clone(self, clone_of=None, root_container=None):
newone._pos = deepcopy(self._pos)
newone.port_particle = deepcopy(self.port_particle)
newone._box = deepcopy(self._box)
newone._check_if_contains_rigid_bodies = deepcopy(
self._check_if_contains_rigid_bodies
)
newone._periodicity = deepcopy(self._periodicity)
newone._contains_rigid = deepcopy(self._contains_rigid)
newone._rigid_id = deepcopy(self._rigid_id)
newone._charge = deepcopy(self._charge)
newone._mass = deepcopy(self._mass)
if hasattr(self, "index"):
Expand Down
Loading

0 comments on commit ba89d6c

Please sign in to comment.