Skip to content

Commit

Permalink
Added method for using only noise-derived forces.
Browse files Browse the repository at this point in the history
stagedjslicegauss_map makes a augmenting map similar to previous
methods, but only uses coord information from the input trajectory.

Modified Trajectory definitions along the way: added coord-only
trajectory, and changed coord/force trajectory names.
  • Loading branch information
alekepd committed Mar 21, 2024
1 parent 0408f94 commit e59b9dc
Show file tree
Hide file tree
Showing 11 changed files with 480 additions and 68 deletions.
2 changes: 1 addition & 1 deletion src/aggforce/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@

# in case jax is not installed
try:
from .qp import joptgauss_map, stagedjoptgauss_map
from .qp import joptgauss_map, stagedjoptgauss_map, stagedjslicegauss_map
except ImportError:
pass
10 changes: 9 additions & 1 deletion src/aggforce/map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
# __init__ doesn't use the imported objects
# ruff: noqa: F401
from .core import LinearMap, CLAMap, trjdot
from .tmap import TMap, SeperableTMap, CLAFTMap, AugmentedTMap, ComposedTMap, RATMap
from .tmap import (
TMap,
SeperableTMap,
CLAFTMap,
AugmentedTMap,
ComposedTMap,
NullForcesTMap,
RATMap,
)
from .tools import lmap_augvariables, smear_map

# in case jax is not installed
Expand Down
84 changes: 73 additions & 11 deletions src/aggforce/map/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
positions or forces.
"""

from typing import Union, List, Callable, Final, Dict, Optional
from typing import Union, List, Callable, Final, Dict, Optional, Literal
import numpy as np
from ..util import trjdot


# this should be very fast
def _has_nans(x: np.ndarray) -> bool:
flat = x.ravel(order="K")
# this will be a 1 element array
return bool(np.isnan(np.dot(flat, flat)))


# _Taggable is only used by CLAMap right now, but its a separate class
# to keep things clear.
class _Taggable:
Expand Down Expand Up @@ -60,6 +67,8 @@ def __init__(
self,
mapping: Union[List[List[int]], np.ndarray],
n_fg_sites: Union[int, None] = None,
handle_nans: Union[bool, Literal["safe"]] = True,
nan_check_threshold: float = 1e-6,
) -> None:
r"""Initialize LinearMapping from something describing a map.
Expand All @@ -78,8 +87,19 @@ def __init__(
Certain mapping descriptions make it ambiguous how many total
fine-grained sites there are. This variable allows this ambiguity to
be resolved.
tags (dictionary or None):
Passed to Map init.
handle_nans:
If true, np.nans in the matrices given as input to class calls are treated
in a special way:
1) They are converted to np.inf
2) If any inf values exist in the output, an exception is raised
3) All nans in the output are set to 0
0*inf = nan, but all other numbers satisfy c*inf=+-inf; as a result, this
procedure allows input arrays that have nans to be operated on such that
if 0 is multiplied with that entry, 0 is returned.
If safe, this is also done, but we make sure that no temporary modifications
are performed in the input matrix (else, we may temporarily in-place set
Nan to Inf). If False, simple matrix multiplication is performed without
NaN specific logic.
Example:
-------
Expand Down Expand Up @@ -122,6 +142,15 @@ def __init__(
else:
raise ValueError("Cannot understanding mapping f{mapping}.")

self.handle_nans = handle_nans
if self.handle_nans:
if not np.all(np.isfinite(self.standard_matrix)):
raise ValueError(
"Nan checking can only be performed in "
"standard_matrix is itself finite."
)
self.nan_check_threshold = nan_check_threshold

@property
def standard_matrix(self) -> np.ndarray:
r"""The mapping in standard matrix format."""
Expand Down Expand Up @@ -161,14 +190,37 @@ def __call__(
Arguments:
---------
points (np.ndarray):
Assumed to be 3 dimensional of shape (n_steps,n_sites,n_dims).
Assumed to be 3 dimensional of shape (n_steps,n_sites,n_dims). Note that
if self.handle_nans is True, this array may be temporarily altered if
it contains NaN values.
Returns:
-------
Combines points along the n_sites dimension according to the internal
map.
map. Note that the handling of NaNs depends on initialization options.
"""
return trjdot(points, self.standard_matrix)
nan_handling = self.handle_nans and _has_nans(points)
if nan_handling:
input_mask = np.isnan(points)
if self.handle_nans == "safe":
input_matrix = points.copy()
else:
input_matrix = points
input_matrix[input_mask] = 0.0
raw_result = trjdot(input_matrix, self.standard_matrix)
input_matrix[input_mask] = -1.0
pushed_result = trjdot(input_matrix, self.standard_matrix)
if not np.allclose(
raw_result, pushed_result, atol=self.nan_check_threshold
):
raise ValueError(
"NaN handling is on and results seem to depend on NaN "
"positions in input array. Check input and standard_matrix."
)
input_matrix[input_mask] = np.nan
return raw_result
else:
return trjdot(points, self.standard_matrix)

def flat_call(self, flattened: np.ndarray) -> np.ndarray:
"""Apply map to pre-flattened array.
Expand Down Expand Up @@ -204,19 +256,27 @@ def flat_call(self, flattened: np.ndarray) -> np.ndarray:
@property
def T(self) -> "LinearMap":
"""LinearMap defined by transpose of its standard matrix."""
return LinearMap(mapping=self.standard_matrix.T)
return LinearMap(mapping=self.standard_matrix.T,
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)

def __matmul__(self, lm: "LinearMap", /) -> "LinearMap":
"""LinearMap defined by multiplying the standard_matrix's of arguments."""
return LinearMap(mapping=self.standard_matrix @ lm.standard_matrix)
return LinearMap(mapping=self.standard_matrix @ lm.standard_matrix,
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)

def __rmul__(self, c: float, /) -> "LinearMap":
"""LinearMap defined by multiplying the standard_matrix's with a coefficient."""
return LinearMap(mapping=c * self.standard_matrix)
return LinearMap(mapping=c * self.standard_matrix,
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)

def __add__(self, lm: "LinearMap", /) -> "LinearMap":
"""LinearMap defined by adding standard_matrices."""
return LinearMap(mapping=self.standard_matrix + lm.standard_matrix)
return LinearMap(mapping=self.standard_matrix + lm.standard_matrix,
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)

def astype(self, *args, **kwargs) -> "LinearMap":
"""Convert to a given precision as determined by arguments.
Expand All @@ -225,7 +285,9 @@ def astype(self, *args, **kwargs) -> "LinearMap":
instance. Arguments are passed to np astype. Setting copy to False may
reduce copies, but may return instances with shared references.
"""
return self.__class__(mapping=self.standard_matrix.astype(*args, **kwargs))
return self.__class__(mapping=self.standard_matrix.astype(*args, **kwargs),
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)


class CLAMap(_Taggable):
Expand Down
115 changes: 90 additions & 25 deletions src/aggforce/map/jaxlinearmap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Extends LinearMaps for Jax operations."""
from typing import overload, TypeVar
from jax import Array
from typing import overload, TypeVar, Tuple, Union
from functools import partial
from jax import Array, jit
import jax.numpy as jnp
from numpy.typing import NDArray
import numpy as np
Expand All @@ -10,20 +11,60 @@
ArrT = TypeVar("ArrT", NDArray, Array)


@partial(jit, static_argnames="nan_handling")
def _trjdot_worker(
factor: Array, points: Array, nan_handling: bool
) -> Tuple[Array, Array]:
"""Help apply internal trjdot transforms.
If nan_handling is false, applies trjdot and returns
a tuple with both entries the same result. If true, the first
entry in the tuple is the result of setting nans to 0, and the second
result is setting nans to 1.
"""
if nan_handling:
input_matrix_0 = jnp.nan_to_num(
points,
nan=0.0,
)
input_matrix_1 = jnp.nan_to_num(
points,
nan=1.0,
)
result_0 = jtrjdot(input_matrix_0, factor)
result_1 = jtrjdot(input_matrix_1, factor)
return (result_0, result_1)
else:
result = jtrjdot(points, factor)
return (result, result)


class JLinearMap(LinearMap):
"""Extends LinearMaps to map Jax arrays."""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args, bypass_nan_check: bool = False, **kwargs) -> None:
"""Initialize.
All argments are passed via super().
Arguments:
---------
*args:
Passed via super().
bypass_nan_check:
If true, we check to see if infs were generated when mapping matrices
with a nan check (similar to LinearMap behavior). If not, we do not;
this often must be set to false to be wrapped in a jit call.
**kwargs:
Passed via super().
"""
super().__init__(*args, **kwargs)
self.bypass_nan_check = bypass_nan_check
self._jax_standard_matrix = jnp.asarray(self.standard_matrix)

@property
def jax_standard_matrix(self) -> Array:
"""Return standard_matrix as a Jax array."""
return jnp.asarray(self.standard_matrix)
return self._jax_standard_matrix

@overload
def __call__(self, points: NDArray) -> NDArray:
Expand All @@ -33,7 +74,7 @@ def __call__(self, points: NDArray) -> NDArray:
def __call__(self, points: Array) -> Array:
...

def __call__(self, points: ArrT) -> ArrT:
def __call__(self, points: Union[NDArray, Array]) -> Union[NDArray, Array]:
r"""Apply map to a particular form of 3-dim array.
Arguments:
Expand All @@ -51,26 +92,32 @@ def __call__(self, points: ArrT) -> ArrT:
Notes:
-----
This implementation is effectively identical to that in the parent class,
but uses Jax operations.
but will behave differently if an invalid map is applied to
"""
if isinstance(points, np.ndarray):
numpy_input = True
jpoints = jnp.asarray(points)
else:
numpy_input = False
jpoints = points
transformed = jtrjdot(jpoints, self.jax_standard_matrix)
if isinstance(points, np.ndarray):
return np.asarray(transformed)
else:
return transformed

@overload
def flat_call(self, flattened: NDArray) -> NDArray:
...

@overload
def flat_call(self, flattened: Array) -> Array:
...
result, sec_result = _trjdot_worker(
factor=self.jax_standard_matrix,
points=jpoints,
nan_handling=self.handle_nans,
)
if (not self.bypass_nan_check) and self.handle_nans:
if not jnp.allclose(result, sec_result, atol=self.nan_check_threshold):
raise ValueError(
"NaN handling is on and multiplication tried to use "
"a NaN value. Check the input array and "
"standard_matrix."
)
if numpy_input:
return np.asarray(result)
else:
return result

def flat_call(self, flattened: ArrT) -> ArrT:
"""Apply map to pre-flattened array.
Expand Down Expand Up @@ -111,21 +158,39 @@ def flat_call(self, flattened: ArrT) -> ArrT:
@property
def T(self) -> "JLinearMap":
"""LinearMap defined by transpose of its standard matrix."""
return JLinearMap(mapping=self.standard_matrix.T)
return JLinearMap(mapping=self.standard_matrix.T,
bypass_nan_check=self.bypass_nan_check,
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)

def __matmul__(self, lm: "LinearMap", /) -> "JLinearMap":
"""LinearMap defined by multiplying the standard_matrix's of arguments."""
return JLinearMap(mapping=self.standard_matrix @ lm.standard_matrix)
return JLinearMap(mapping=self.standard_matrix @ lm.standard_matrix,
bypass_nan_check=self.bypass_nan_check,
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)

def __rmul__(self, c: float, /) -> "JLinearMap":
"""LinearMap defined by multiplying the standard_matrix's with a coefficient."""
return JLinearMap(mapping=c * self.standard_matrix)
return JLinearMap(mapping=c * self.standard_matrix,
bypass_nan_check=self.bypass_nan_check,
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)

def __add__(self, lm: "LinearMap", /) -> "JLinearMap":
"""LinearMap defined by adding standard_matrices."""
return JLinearMap(mapping=self.standard_matrix + lm.standard_matrix)
return JLinearMap(mapping=self.standard_matrix + lm.standard_matrix,
bypass_nan_check=self.bypass_nan_check,
handle_nans=self.handle_nans,
nan_check_threshold=self.nan_check_threshold)

@classmethod
def from_linearmap(cls, lm: LinearMap, /) -> "JLinearMap":
def from_linearmap(
cls, lm: LinearMap, /, bypass_nan_check: bool = False
) -> "JLinearMap":
"""Create JLinearMap from LinearMap."""
return JLinearMap(mapping=lm.standard_matrix)
return JLinearMap(
mapping=lm.standard_matrix,
bypass_nan_check=bypass_nan_check,
handle_nans=lm.handle_nans,
)
Loading

0 comments on commit e59b9dc

Please sign in to comment.