From e59b9dc1aa04b463319c3f025d0a173fd3a9f751 Mon Sep 17 00:00:00 2001 From: Aleksander Durumeric Date: Thu, 21 Mar 2024 10:18:01 -0500 Subject: [PATCH] Added method for using only noise-derived forces. stagedjslicegauss_map makes a augmenting map similar to previous methods, but only uses coord information from the input trajectory. Modified Trajectory definitions along the way: added coord-only trajectory, and changed coord/force trajectory names. --- src/aggforce/__init__.py | 2 +- src/aggforce/map/__init__.py | 10 +- src/aggforce/map/core.py | 84 +++++++++++++-- src/aggforce/map/jaxlinearmap.py | 115 +++++++++++++++----- src/aggforce/map/tmap.py | 70 +++++++++++- src/aggforce/qp/__init__.py | 2 +- src/aggforce/qp/basicagg.py | 4 +- src/aggforce/qp/jgauss.py | 159 ++++++++++++++++++++++++++-- src/aggforce/qp/qplinear.py | 4 +- src/aggforce/trajectory/__init__.py | 7 +- src/aggforce/trajectory/core.py | 91 +++++++++++++--- 11 files changed, 480 insertions(+), 68 deletions(-) diff --git a/src/aggforce/__init__.py b/src/aggforce/__init__.py index 5247e70..2d05050 100644 --- a/src/aggforce/__init__.py +++ b/src/aggforce/__init__.py @@ -22,6 +22,6 @@ # in case jax is not installed try: - from .qp import joptgauss_map, stagedjoptgauss_map + from .qp import joptgauss_map, stagedjoptgauss_map, stagedjslicegauss_map except ImportError: pass diff --git a/src/aggforce/map/__init__.py b/src/aggforce/map/__init__.py index 03bde85..c042994 100644 --- a/src/aggforce/map/__init__.py +++ b/src/aggforce/map/__init__.py @@ -2,7 +2,15 @@ # __init__ doesn't use the imported objects # ruff: noqa: F401 from .core import LinearMap, CLAMap, trjdot -from .tmap import TMap, SeperableTMap, CLAFTMap, AugmentedTMap, ComposedTMap, RATMap +from .tmap import ( + TMap, + SeperableTMap, + CLAFTMap, + AugmentedTMap, + ComposedTMap, + NullForcesTMap, + RATMap, +) from .tools import lmap_augvariables, smear_map # in case jax is not installed diff --git a/src/aggforce/map/core.py b/src/aggforce/map/core.py index 3e34b85..084c213 100644 --- a/src/aggforce/map/core.py +++ b/src/aggforce/map/core.py @@ -4,11 +4,18 @@ positions or forces. """ -from typing import Union, List, Callable, Final, Dict, Optional +from typing import Union, List, Callable, Final, Dict, Optional, Literal import numpy as np from ..util import trjdot +# this should be very fast +def _has_nans(x: np.ndarray) -> bool: + flat = x.ravel(order="K") + # this will be a 1 element array + return bool(np.isnan(np.dot(flat, flat))) + + # _Taggable is only used by CLAMap right now, but its a separate class # to keep things clear. class _Taggable: @@ -60,6 +67,8 @@ def __init__( self, mapping: Union[List[List[int]], np.ndarray], n_fg_sites: Union[int, None] = None, + handle_nans: Union[bool, Literal["safe"]] = True, + nan_check_threshold: float = 1e-6, ) -> None: r"""Initialize LinearMapping from something describing a map. @@ -78,8 +87,19 @@ def __init__( Certain mapping descriptions make it ambiguous how many total fine-grained sites there are. This variable allows this ambiguity to be resolved. - tags (dictionary or None): - Passed to Map init. + handle_nans: + If true, np.nans in the matrices given as input to class calls are treated + in a special way: + 1) They are converted to np.inf + 2) If any inf values exist in the output, an exception is raised + 3) All nans in the output are set to 0 + 0*inf = nan, but all other numbers satisfy c*inf=+-inf; as a result, this + procedure allows input arrays that have nans to be operated on such that + if 0 is multiplied with that entry, 0 is returned. + If safe, this is also done, but we make sure that no temporary modifications + are performed in the input matrix (else, we may temporarily in-place set + Nan to Inf). If False, simple matrix multiplication is performed without + NaN specific logic. Example: ------- @@ -122,6 +142,15 @@ def __init__( else: raise ValueError("Cannot understanding mapping f{mapping}.") + self.handle_nans = handle_nans + if self.handle_nans: + if not np.all(np.isfinite(self.standard_matrix)): + raise ValueError( + "Nan checking can only be performed in " + "standard_matrix is itself finite." + ) + self.nan_check_threshold = nan_check_threshold + @property def standard_matrix(self) -> np.ndarray: r"""The mapping in standard matrix format.""" @@ -161,14 +190,37 @@ def __call__( Arguments: --------- points (np.ndarray): - Assumed to be 3 dimensional of shape (n_steps,n_sites,n_dims). + Assumed to be 3 dimensional of shape (n_steps,n_sites,n_dims). Note that + if self.handle_nans is True, this array may be temporarily altered if + it contains NaN values. Returns: ------- Combines points along the n_sites dimension according to the internal - map. + map. Note that the handling of NaNs depends on initialization options. """ - return trjdot(points, self.standard_matrix) + nan_handling = self.handle_nans and _has_nans(points) + if nan_handling: + input_mask = np.isnan(points) + if self.handle_nans == "safe": + input_matrix = points.copy() + else: + input_matrix = points + input_matrix[input_mask] = 0.0 + raw_result = trjdot(input_matrix, self.standard_matrix) + input_matrix[input_mask] = -1.0 + pushed_result = trjdot(input_matrix, self.standard_matrix) + if not np.allclose( + raw_result, pushed_result, atol=self.nan_check_threshold + ): + raise ValueError( + "NaN handling is on and results seem to depend on NaN " + "positions in input array. Check input and standard_matrix." + ) + input_matrix[input_mask] = np.nan + return raw_result + else: + return trjdot(points, self.standard_matrix) def flat_call(self, flattened: np.ndarray) -> np.ndarray: """Apply map to pre-flattened array. @@ -204,19 +256,27 @@ def flat_call(self, flattened: np.ndarray) -> np.ndarray: @property def T(self) -> "LinearMap": """LinearMap defined by transpose of its standard matrix.""" - return LinearMap(mapping=self.standard_matrix.T) + return LinearMap(mapping=self.standard_matrix.T, + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) def __matmul__(self, lm: "LinearMap", /) -> "LinearMap": """LinearMap defined by multiplying the standard_matrix's of arguments.""" - return LinearMap(mapping=self.standard_matrix @ lm.standard_matrix) + return LinearMap(mapping=self.standard_matrix @ lm.standard_matrix, + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) def __rmul__(self, c: float, /) -> "LinearMap": """LinearMap defined by multiplying the standard_matrix's with a coefficient.""" - return LinearMap(mapping=c * self.standard_matrix) + return LinearMap(mapping=c * self.standard_matrix, + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) def __add__(self, lm: "LinearMap", /) -> "LinearMap": """LinearMap defined by adding standard_matrices.""" - return LinearMap(mapping=self.standard_matrix + lm.standard_matrix) + return LinearMap(mapping=self.standard_matrix + lm.standard_matrix, + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) def astype(self, *args, **kwargs) -> "LinearMap": """Convert to a given precision as determined by arguments. @@ -225,7 +285,9 @@ def astype(self, *args, **kwargs) -> "LinearMap": instance. Arguments are passed to np astype. Setting copy to False may reduce copies, but may return instances with shared references. """ - return self.__class__(mapping=self.standard_matrix.astype(*args, **kwargs)) + return self.__class__(mapping=self.standard_matrix.astype(*args, **kwargs), + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) class CLAMap(_Taggable): diff --git a/src/aggforce/map/jaxlinearmap.py b/src/aggforce/map/jaxlinearmap.py index 69ecb99..09d70c7 100644 --- a/src/aggforce/map/jaxlinearmap.py +++ b/src/aggforce/map/jaxlinearmap.py @@ -1,6 +1,7 @@ """Extends LinearMaps for Jax operations.""" -from typing import overload, TypeVar -from jax import Array +from typing import overload, TypeVar, Tuple, Union +from functools import partial +from jax import Array, jit import jax.numpy as jnp from numpy.typing import NDArray import numpy as np @@ -10,20 +11,60 @@ ArrT = TypeVar("ArrT", NDArray, Array) +@partial(jit, static_argnames="nan_handling") +def _trjdot_worker( + factor: Array, points: Array, nan_handling: bool +) -> Tuple[Array, Array]: + """Help apply internal trjdot transforms. + + If nan_handling is false, applies trjdot and returns + a tuple with both entries the same result. If true, the first + entry in the tuple is the result of setting nans to 0, and the second + result is setting nans to 1. + """ + if nan_handling: + input_matrix_0 = jnp.nan_to_num( + points, + nan=0.0, + ) + input_matrix_1 = jnp.nan_to_num( + points, + nan=1.0, + ) + result_0 = jtrjdot(input_matrix_0, factor) + result_1 = jtrjdot(input_matrix_1, factor) + return (result_0, result_1) + else: + result = jtrjdot(points, factor) + return (result, result) + + class JLinearMap(LinearMap): """Extends LinearMaps to map Jax arrays.""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, bypass_nan_check: bool = False, **kwargs) -> None: """Initialize. - All argments are passed via super(). + Arguments: + --------- + *args: + Passed via super(). + bypass_nan_check: + If true, we check to see if infs were generated when mapping matrices + with a nan check (similar to LinearMap behavior). If not, we do not; + this often must be set to false to be wrapped in a jit call. + **kwargs: + Passed via super(). + """ super().__init__(*args, **kwargs) + self.bypass_nan_check = bypass_nan_check + self._jax_standard_matrix = jnp.asarray(self.standard_matrix) @property def jax_standard_matrix(self) -> Array: """Return standard_matrix as a Jax array.""" - return jnp.asarray(self.standard_matrix) + return self._jax_standard_matrix @overload def __call__(self, points: NDArray) -> NDArray: @@ -33,7 +74,7 @@ def __call__(self, points: NDArray) -> NDArray: def __call__(self, points: Array) -> Array: ... - def __call__(self, points: ArrT) -> ArrT: + def __call__(self, points: Union[NDArray, Array]) -> Union[NDArray, Array]: r"""Apply map to a particular form of 3-dim array. Arguments: @@ -51,26 +92,32 @@ def __call__(self, points: ArrT) -> ArrT: Notes: ----- This implementation is effectively identical to that in the parent class, - but uses Jax operations. + but will behave differently if an invalid map is applied to """ if isinstance(points, np.ndarray): + numpy_input = True jpoints = jnp.asarray(points) else: + numpy_input = False jpoints = points - transformed = jtrjdot(jpoints, self.jax_standard_matrix) - if isinstance(points, np.ndarray): - return np.asarray(transformed) - else: - return transformed - @overload - def flat_call(self, flattened: NDArray) -> NDArray: - ... - - @overload - def flat_call(self, flattened: Array) -> Array: - ... + result, sec_result = _trjdot_worker( + factor=self.jax_standard_matrix, + points=jpoints, + nan_handling=self.handle_nans, + ) + if (not self.bypass_nan_check) and self.handle_nans: + if not jnp.allclose(result, sec_result, atol=self.nan_check_threshold): + raise ValueError( + "NaN handling is on and multiplication tried to use " + "a NaN value. Check the input array and " + "standard_matrix." + ) + if numpy_input: + return np.asarray(result) + else: + return result def flat_call(self, flattened: ArrT) -> ArrT: """Apply map to pre-flattened array. @@ -111,21 +158,39 @@ def flat_call(self, flattened: ArrT) -> ArrT: @property def T(self) -> "JLinearMap": """LinearMap defined by transpose of its standard matrix.""" - return JLinearMap(mapping=self.standard_matrix.T) + return JLinearMap(mapping=self.standard_matrix.T, + bypass_nan_check=self.bypass_nan_check, + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) def __matmul__(self, lm: "LinearMap", /) -> "JLinearMap": """LinearMap defined by multiplying the standard_matrix's of arguments.""" - return JLinearMap(mapping=self.standard_matrix @ lm.standard_matrix) + return JLinearMap(mapping=self.standard_matrix @ lm.standard_matrix, + bypass_nan_check=self.bypass_nan_check, + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) def __rmul__(self, c: float, /) -> "JLinearMap": """LinearMap defined by multiplying the standard_matrix's with a coefficient.""" - return JLinearMap(mapping=c * self.standard_matrix) + return JLinearMap(mapping=c * self.standard_matrix, + bypass_nan_check=self.bypass_nan_check, + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) def __add__(self, lm: "LinearMap", /) -> "JLinearMap": """LinearMap defined by adding standard_matrices.""" - return JLinearMap(mapping=self.standard_matrix + lm.standard_matrix) + return JLinearMap(mapping=self.standard_matrix + lm.standard_matrix, + bypass_nan_check=self.bypass_nan_check, + handle_nans=self.handle_nans, + nan_check_threshold=self.nan_check_threshold) @classmethod - def from_linearmap(cls, lm: LinearMap, /) -> "JLinearMap": + def from_linearmap( + cls, lm: LinearMap, /, bypass_nan_check: bool = False + ) -> "JLinearMap": """Create JLinearMap from LinearMap.""" - return JLinearMap(mapping=lm.standard_matrix) + return JLinearMap( + mapping=lm.standard_matrix, + bypass_nan_check=bypass_nan_check, + handle_nans=lm.handle_nans, + ) diff --git a/src/aggforce/map/tmap.py b/src/aggforce/map/tmap.py index 71a765b..e95111d 100644 --- a/src/aggforce/map/tmap.py +++ b/src/aggforce/map/tmap.py @@ -3,10 +3,25 @@ These objects effectively map coordinates and forces together. """ -from typing import Tuple, Callable, Final, Iterable, TypeVar +from typing import ( + Tuple, + Callable, + Final, + Iterable, + TypeVar, + Optional, + Any, +) from abc import ABC, abstractmethod +from warnings import warn import numpy as np -from ..trajectory import Trajectory, AugmentedTrajectory, Augmenter +from ..trajectory import ( + CoordsTrajectory, + ForcesTrajectory, + Trajectory, + AugmentedTrajectory, + Augmenter, +) from .core import CLAMap ArrayTransform = Callable[[np.ndarray], np.ndarray] @@ -283,6 +298,7 @@ def __call__(self, t: Trajectory) -> Trajectory: result = t for mapping in reversed(self.submaps): result = mapping(result) + print(result.coords,result.forces) return result def __getitem__(self, idx: int, /) -> TMap: @@ -300,6 +316,56 @@ def astype(self, *args, **kwargs) -> "ComposedTMap": return self.__class__(submaps=new_maps) +_T_Coords = TypeVar("_T_Coords", bound=CoordsTrajectory) + + +class NullForcesTMap(TMap): + def __init__(self, warn_input_forces: bool = True, fill_value: Any = np.nan): + self.warn_input_forces = warn_input_forces + self.fill_value = fill_value + + def __call__( + self, + t: CoordsTrajectory, + ) -> Trajectory: + """Map Trajectory to new instance.""" + if isinstance(t, ForcesTrajectory): + if self.warn_input_forces: + warn("Discarding forces on input trajectory.", stacklevel=0) + + return Trajectory(coords=t.coords, forces=self.fill_value * t.coords) + + def map_arrays( + self, + coords: np.ndarray, + forces: Optional[np.ndarray] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + """Map arrays using coord_map. + + This method mirrors + + Arguments: + --------- + coords: + forces: + + Returns: + ------- + mapped arrays + + """ + if forces is None: + t = CoordsTrajectory(coords=coords) + else: + t = Trajectory(coords=coords, forces=forces) + derived = self(t) + return (derived.coords, derived.forces) + + def astype(self,*args,**kwargs) -> "NullForcesTMap": + return self.__class__(warn_input_forces=self.warn_input_forces, + fill_value=self.fill_value) + + class RATMap: """Maps the real portions of an AugmentedTrajectory. diff --git a/src/aggforce/qp/__init__.py b/src/aggforce/qp/__init__.py index 73444f3..2ef2f31 100644 --- a/src/aggforce/qp/__init__.py +++ b/src/aggforce/qp/__init__.py @@ -18,6 +18,6 @@ try: from .jaxfeat import gb_feat - from .jgauss import joptgauss_map, stagedjoptgauss_map + from .jgauss import joptgauss_map, stagedjoptgauss_map, stagedjslicegauss_map except ImportError: pass diff --git a/src/aggforce/qp/basicagg.py b/src/aggforce/qp/basicagg.py index 2050b6e..0ae3251 100644 --- a/src/aggforce/qp/basicagg.py +++ b/src/aggforce/qp/basicagg.py @@ -3,13 +3,13 @@ from typing import Union from itertools import product import numpy as np -from ..trajectory import ForcesOnlyTrajectory +from ..trajectory import ForcesTrajectory from ..map import LinearMap, SeperableTMap from ..constraints import Constraints, reduce_constraint_sets def constraint_aware_uni_map( - traj: ForcesOnlyTrajectory, # noqa: ARG001 + traj: ForcesTrajectory, # noqa: ARG001 coord_map: LinearMap, constraints: Union[None, Constraints] = None, ) -> SeperableTMap: diff --git a/src/aggforce/qp/jgauss.py b/src/aggforce/qp/jgauss.py index fa96b73..4f356dc 100644 --- a/src/aggforce/qp/jgauss.py +++ b/src/aggforce/qp/jgauss.py @@ -1,18 +1,26 @@ """Provides jax methods for making optimized stochastic coordinate-force maps.""" from typing import Optional +import numpy as np from ..map import ( LinearMap, JLinearMap, AugmentedTMap, SeperableTMap, + NullForcesTMap, lmap_augvariables, ComposedTMap, RATMap, ) -from ..trajectory import Trajectory, AugmentedTrajectory, JCondNormal +from ..trajectory import ( + Trajectory, + CoordsTrajectory, + AugmentedTrajectory, + JCondNormal, +) from ..constraints import Constraints from .qplinear import qp_linear_map, DEFAULT_SOLVER_OPTIONS, SolverOptions +from .basicagg import constraint_aware_uni_map def joptgauss_map( @@ -102,7 +110,9 @@ def joptgauss_map( # however, it needs to in the form of jax function which acts on flattened # vectors and allows for single-frame operation. This is accessible via an # attributed of a jaxxed LinearMap (JLinearMap) attribute. - flattened_cmap = JLinearMap.from_linearmap(coord_map).flat_call + flattened_cmap = JLinearMap.from_linearmap( + coord_map, bypass_nan_check=True + ).flat_call # create the object that will do the noising augmenter = JCondNormal(cov=var, premap=flattened_cmap, seed=seed) # create extended trajectory using the derived noiser @@ -154,8 +164,9 @@ def stagedjoptgauss_map( 1. Generate an optimized force map without any noise. - Note: if force_map is specified, this is used in lieu of 1. 2. Create a augmented trajectory without any premap. - 3. Map the augmented trajectory using the non-noise optimized map. - 4. Create an optimized map on the mapped augmented trajectory. + 3. (Partially) map the real particles augmented trajectory using the non-noise + optimized map. + 4. Create an optimized map on the partially mapped augmented trajectory. 5. Compose the maps from 1 and 4 to create a new map. To access the premap, index the returned TMap with [1]. To obtain the noise @@ -218,8 +229,8 @@ def stagedjoptgauss_map( # We then extract the noise and coord maps and jaxify them. # # we know based on external knowledge that these entrees are LinearMaps - j_coord_map = JLinearMap.from_linearmap(pre_tmap.coord_map) # type: ignore [arg-type] - j_force_map = JLinearMap.from_linearmap(pre_tmap.force_map) # type: ignore [arg-type] + j_coord_map = JLinearMap.from_linearmap(pre_tmap.coord_map, bypass_nan_check=True) # type: ignore [arg-type] + j_force_map = JLinearMap.from_linearmap(pre_tmap.force_map, bypass_nan_check=True) # type: ignore [arg-type] # We then create the augmenter. This will be used with the full trajectory. augmenter = JCondNormal(cov=var, premap=j_coord_map.flat_call, seed=seed) @@ -298,3 +309,139 @@ def stagedjoptgauss_map( comb_tmap = ComposedTMap(submaps=[post_tmap, pre_tmap]) return comb_tmap + + +def stagedjslicegauss_map( + traj: CoordsTrajectory, + coord_map: LinearMap, + var: float, + kbt: float, + seed: Optional[int] = None, + constraints: Optional[Constraints] = None, # noqa: ARG001 + warn_input_forces: bool = True, +) -> ComposedTMap: + """Create Gaussian map which only uses information from noising in reported forces. + + This routine is written to mirror the procedure in stagedjoptgauss_map, and + similarly outputs a ComposedTMap; however, this ComposedTMap has 3 parts. + maps[2] adds null forces to the input data if needed (allowing the derived + tmap to operate when no forces are present), maps[1] maps the coordinates to the + coarse-grained resolution, and maps[0] noises the data and extracts noise-derived + forces. + + At the cost of increasing complexity, we keep the internal procedure close to + that in stagedjoptgauss_map. The following steps are performed: + 1. Set forces in input to nans to make sure they are not used. + 2. Create a augmented trajectory without any premap. + 3. Partially map the augmented trajectory to the resolution of mapped real + sites with augmented sites. + 4. Create a slice force map on the partially mapped trajectory. + 5. Compose the maps from 1, 3, and 4 to create a new map. + + Arguments: + --------- + traj: + Trajectory instance that will be used to create noised positions then + subsequently mapped. + coord_map: + Coordinate map representing the coarse-grained description of the system. The + output dimension (n_cg_sites) determines the number of auxiliary particles to + the Gaussian noise augmenter will add to the system. + + Note that this map does not enter the produced TMap in a straightforward way. + var: + The noise added is drawn from a Gaussian with a diagonal covariance matrix; this + positive scalar is the diagonal value. A small value means the level of noise + added is small, and larger values perturb the system more. + kbt: + Boltzmann constant times temperature for the samples in traj. This is needed to + turn the log density gradients of the added noise variates into forces. + seed: + Random seed that will be passed to the Gaussian noiser (JCondNormal instance). + constraints: + Not used. Retained for compatibility. + warn_input_forces: + If True, we warn if forces were provided in the input data, as we will ignore + them. + + Returns: + ------- + An ComposedTMap which characterizes the Gaussian map. This map has three + submaps; the first map adds dummy forces if the input data lacks forces. + The second map reduces the dimension of the data via coord_map, and the + third map noises the system and isolates the noise-derived forces. + + """ + # to be sure that we do not actually use force information present in the input + # trajectory, we replace it with NaNs. This also allows the input to + # not have force information without changing subsequent calls. + naforce_traj = NullForcesTMap(warn_input_forces=warn_input_forces)(traj) + + # Create augmenter that adds gaussian noise to the system. + # bypass_nan_check is needed for internal derivative calculations. + augmenter = JCondNormal( + cov=var, + premap=JLinearMap.from_linearmap( + coord_map, bypass_nan_check=True + ).flat_call, + seed=seed, + ) + # create augmented trajectory + aug_traj = AugmentedTrajectory.from_trajectory( + t=naforce_traj, augmenter=augmenter, kbt=kbt + ) + + # we now create the partially mapped augmented trajectory, but unlike in + # other methods we must create a dummy force map, and then use that to + # create a preprocessing tmap with coord_map. + null_fmap = LinearMap( + mapping=np.ones_like(coord_map.standard_matrix), handle_nans=False + ) + pre_tmap = SeperableTMap(coord_map=coord_map, force_map=null_fmap) + + # this contains the noise particles and mapped real particles. + pmapped_traj = RATMap(tmap=pre_tmap)(aug_traj) + + # create the map that isolates the noise sites on the partially mapped traj. + preserved_sites = [] + for index in range( + pmapped_traj.n_sites - aug_traj.n_aug_sites, pmapped_traj.n_sites + ): + preserved_sites.append([index]) + pmapped_coord_map = LinearMap( + mapping=preserved_sites, n_fg_sites=pmapped_traj.n_sites + ) + + # we then move to creating the force map that acts on the partially mapped + # traj. We no longer know what the constraints are (they have probably + # been mapped away). For a reasonable pre-coord map, there shouldn't be any + # left, and we assume this is true. + pmapped_tmap = constraint_aware_uni_map( + traj=pmapped_traj, + coord_map=pmapped_coord_map, + constraints=set(), + ) + + # this is the augmenter that acts on the already coarse-grained traj. + # As we do not use any forces on the real particles, we do not bother + # to create a force-modifier as is done in other methods. + pmapped_augmenter = JCondNormal( + cov=var, + seed=seed, + ) + + # we wrapped the derived force map with the augmentation operation + post_tmap = AugmentedTMap( + aug_tmap=pmapped_tmap, + augmenter=pmapped_augmenter, + kbt=kbt, + ) + + # and finally compose maps to create the returned callable. + # NullForcesTMap allows the resulting TMap to be applied to trajectories + # which do not have force information, or coordinate arrays. + comb_tmap = ComposedTMap( + submaps=[post_tmap, pre_tmap, NullForcesTMap(warn_input_forces=False)] + ) + + return comb_tmap diff --git a/src/aggforce/qp/qplinear.py b/src/aggforce/qp/qplinear.py index be396fe..de6809e 100644 --- a/src/aggforce/qp/qplinear.py +++ b/src/aggforce/qp/qplinear.py @@ -5,7 +5,7 @@ import numpy as np from qpsolvers import solve_qp # type: ignore [import-untyped] from ..map import LinearMap, SeperableTMap -from ..trajectory import ForcesOnlyTrajectory +from ..trajectory import ForcesTrajectory from ..constraints import Constraints, reduce_constraint_sets, constraint_lookup_dict SolverOptions = TypedDict( @@ -28,7 +28,7 @@ def qp_linear_map( - traj: ForcesOnlyTrajectory, + traj: ForcesTrajectory, coord_map: LinearMap, constraints: Union[None, Constraints] = None, l2_regularization: float = 0.0, diff --git a/src/aggforce/trajectory/__init__.py b/src/aggforce/trajectory/__init__.py index eaa6c67..16aedfa 100644 --- a/src/aggforce/trajectory/__init__.py +++ b/src/aggforce/trajectory/__init__.py @@ -1,7 +1,12 @@ """Provides tools and definitions of Trajectory instances.""" # __init__ doesn't use the imported objects # ruff: noqa: F401 -from .core import ForcesOnlyTrajectory, Trajectory, AugmentedTrajectory +from .core import ( + ForcesTrajectory, + CoordsTrajectory, + Trajectory, + AugmentedTrajectory, +) from .augment import Augmenter try: diff --git a/src/aggforce/trajectory/core.py b/src/aggforce/trajectory/core.py index 35f47af..a77b535 100644 --- a/src/aggforce/trajectory/core.py +++ b/src/aggforce/trajectory/core.py @@ -15,19 +15,14 @@ from .augment import Augmenter -_T_FTraj = TypeVar("_T_FTraj", bound="ForcesOnlyTrajectory") -_T_Traj = TypeVar("_T_Traj", bound="Trajectory") -_T_ATraj = TypeVar("_T_ATraj", bound="AugmentedTrajectory") - - -class ForcesOnlyTrajectory: +class ForcesTrajectory: r"""Trajectory with forces but without positions. - This is similar to Trajectory, but without forces. See Trajectory class for + This is similar to Trajectory, but without coordinates. See Trajectory class for more information. """ - def __init__(self, forces: np.ndarray) -> None: + def __init__(self, *, forces: np.ndarray) -> None: """Initialize. Arguments: @@ -57,10 +52,10 @@ def __len__(self) -> int: """Return the number of frames in the system.""" return len(self.forces) - def __getitem__(self, index: slice) -> "ForcesOnlyTrajectory": + def __getitem__(self, index: slice) -> "ForcesTrajectory": """Index trajectory. - Only slices are allowed. Returns a ForcesOnlyTrajectory instance. + Only slices are allowed. Returns a ForcesTrajectory instance. """ if not isinstance(index, slice): raise ValueError("Only slices are allowed for indexing.") @@ -69,12 +64,12 @@ def __getitem__(self, index: slice) -> "ForcesOnlyTrajectory": # we do not generically type this because unless it is overridden in a subclass, it # will indeed always return a - def copy(self) -> "ForcesOnlyTrajectory": + def copy(self) -> "ForcesTrajectory": """Copy a trajectory object.""" new_forces = self.forces.copy() return self.__class__(forces=new_forces) - def astype(self, *args, **kwargs) -> "ForcesOnlyTrajectory": + def astype(self, *args, **kwargs) -> "ForcesTrajectory": """Convert to a given dtype. Arguments are passed to np astype. Setting copy to False may reduce copies, but @@ -83,7 +78,70 @@ def astype(self, *args, **kwargs) -> "ForcesOnlyTrajectory": return self.__class__(forces=self.forces.astype(*args, **kwargs)) -class Trajectory(ForcesOnlyTrajectory): +class CoordsTrajectory: + r"""Trajectory with positions but without forces. + + This is similar to Trajectory, but without forces. See Trajectory class for + more information. + """ + + def __init__(self, *, coords: np.ndarray) -> None: + """Initialize. + + Arguments: + --------- + coords: + coordinates (positions) for multiple timesteps. + """ + if len(coords.shape) != 3: + raise ValueError("forces must have 3 dimensions.") + self.coords = coords + return + + @property + def n_sites(self) -> int: + """Number of particles in the system.""" + return self.coords.shape[1] + + @property + def n_dim(self) -> int: + """Dimension of the individual particles in the system. + + This is 3 in typical molecular dynamics applications. + """ + return self.coords.shape[2] + + def __len__(self) -> int: + """Return the number of frames in the system.""" + return len(self.coords) + + def __getitem__(self, index: slice) -> "CoordsTrajectory": + """Index trajectory. + + Only slices are allowed. Returns a CoordsTrajectory instance. + """ + if not isinstance(index, slice): + raise ValueError("Only slices are allowed for indexing.") + new_coords = self.coords[index] + return self.__class__(coords=new_coords) + + # we do not generically type this because unless it is overridden in a subclass, it + # will indeed always return a + def copy(self) -> "CoordsTrajectory": + """Copy a trajectory object.""" + new_coords = self.coords.copy() + return self.__class__(coords=new_coords) + + def astype(self, *args, **kwargs) -> "CoordsTrajectory": + """Convert to a given dtype. + + Arguments are passed to np astype. Setting copy to False may reduce copies, but + may return instances with shared references. + """ + return self.__class__(coords=self.coords.astype(*args, **kwargs)) + + +class Trajectory(CoordsTrajectory, ForcesTrajectory): r"""Collection of coordinates and forces from a molecular trajectory. A molecular dynamics simulation saves coordinates and forces at various @@ -116,7 +174,7 @@ class Trajectory(ForcesOnlyTrajectory): """ - def __init__(self, coords: np.ndarray, forces: np.ndarray) -> None: + def __init__(self, *, coords: np.ndarray, forces: np.ndarray) -> None: """Initialize. Arguments: @@ -130,8 +188,8 @@ def __init__(self, coords: np.ndarray, forces: np.ndarray) -> None: raise ValueError("coords and forces must be of same shape.") if len(coords.shape) != 3: raise ValueError("coords and forces must be of same shape.") - self.coords = coords - super().__init__(forces=forces) + CoordsTrajectory.__init__(self, coords=coords) + ForcesTrajectory.__init__(self, forces=forces) return def __getitem__(self, index: slice) -> "Trajectory": @@ -246,6 +304,7 @@ class AugmentedTrajectory(Trajectory): def __init__( self, + *, coords: np.ndarray, forces: np.ndarray, augmenter: Augmenter,