Skip to content

Commit

Permalink
Added no-noise gauss map method.
Browse files Browse the repository at this point in the history
Also added test based on previous gauss map tests.
  • Loading branch information
alekepd committed Apr 16, 2024
1 parent 5f9fd4e commit 576df95
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 2 deletions.
7 changes: 6 additions & 1 deletion src/aggforce/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@

# in case jax is not installed
try:
from .qp import joptgauss_map, stagedjoptgauss_map, stagedjslicegauss_map
from .qp import (
joptgauss_map,
stagedjoptgauss_map,
stagedjslicegauss_map,
stagedjforcegauss_map,
)
except ImportError:
pass
7 changes: 6 additions & 1 deletion src/aggforce/qp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

try:
from .jaxfeat import gb_feat
from .jgauss import joptgauss_map, stagedjoptgauss_map, stagedjslicegauss_map
from .jgauss import (
joptgauss_map,
stagedjoptgauss_map,
stagedjslicegauss_map,
stagedjforcegauss_map,
)
except ImportError:
pass
205 changes: 205 additions & 0 deletions src/aggforce/qp/jgauss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Provides jax methods for making optimized stochastic coordinate-force maps."""

from typing import Optional
import warnings
import numpy as np
from ..map import (
LinearMap,
Expand Down Expand Up @@ -443,3 +444,207 @@ def stagedjslicegauss_map(
)

return comb_tmap


def stagedjforcegauss_map(
traj: Trajectory,
coord_map: LinearMap,
var: float,
kbt: float,
force_map: Optional[LinearMap] = None,
constraints: Optional[Constraints] = None,
seed: Optional[int] = None,
premap_l2_regularization: float = 0.0,
premap_solver_args: SolverOptions = DEFAULT_SOLVER_OPTIONS,
contribution_tolerance: float = 1e-6,
**kwargs,
) -> ComposedTMap:
"""Create source-force-only Gaussian map with linear premap and subsequent noising.
This routine creates a trajectory map that noises coordinates, and then introduces
a minimal amount of noise-derived force information into the induced force signal.
For many conditions the level of noise-derived force may be brought to zero.
This routine performs the following steps:
1. Generate an optimized force map without any noise.
- Note: if force_map is specified, this is used in lieu of 1.
2. Create a augmented trajectory without any premap.
3. (Partially) map the real particles augmented trajectory using the non-noise
optimized map.
4. Create an optimized map on the partially mapped augmented trajectory.
- Critically, this map is optimized not to reduce all force noise but reduce
noise-only force noise.
5. Compose the maps from 1 and 4 to create a new map.
To access the premap, index the returned TMap with [1]. To obtain the noise
map, index it with [0]. Note that due to the nature of the extended ensemble,
the force-map will still mix forces from the virtual and real particles; however,
the resulting force should be close to the real-particle force present before
creating the extended ensemble.
This method is structured to mirror stagedjoptgauss_map.
Arguments:
---------
traj:
Trajectory instance that will be used to create the optimized force map and
then subsequently mapped.
coord_map:
Coordinate map representing the coarse-grained description of the system. The
output dimension (n_cg_sites) determines the number of auxiliary particles to
the Gaussian noise augmenter will add to the system.
Note that this map does not enter the produced TMap in a straightforward way.
var:
The noise added is drawn from a Gaussian with a diagonal covariance matrix; this
positive scalar is the diagonal value. A small value means the level of noise
added is small, and larger values perturb the system more.
kbt:
Boltzmann constant times temperature for the samples in traj. This is needed to
turn the log density gradients of the added noise variates into forces. If all
force-related contributions are removed, this should not affect mapped forces.
force_map:
If not None, this is used instead of performing the initial no-noise force map
optimization.
constraints:
Molecular constraints present in traj's simulator. Used in force map design.
seed:
Random seed that will be passed to the Gaussian noiser (JCondNormal instance).
premap_l2_regularization:
l2_regularization passed to initial non-noised force map optimization
(qp_linear_map).
premap_solver_args:
Arguments passed to initial non-noised force map optimization.
contribution_tolerance:
We check the mean l2-norm of the noise-derived force contribution over the given
trajectory and compare it to this value; if it is larger, we warn.
**kwargs:
Passed to underlying qp_linear_map optimization on the derived
AugmentedTrajectory.
Returns:
-------
An ComposedTMap which characterizes the Gaussian map. This map has two submaps; the
first is a deterministic map that coarse-grains the coordinates and forces,
and the second map applies noising operations. Data may be mapped with the first
map, saved, loaded, and then mapped with the second map.
"""
# first create non-noised optimized force map
if force_map is None:
pre_tmap = qp_linear_map(
traj=traj,
coord_map=coord_map,
constraints=constraints,
l2_regularization=premap_l2_regularization,
solver_args=premap_solver_args,
)
else:
pre_tmap = SeperableTMap(coord_map=coord_map, force_map=force_map)

# We then extract the noise and coord maps and jaxify them.
#
# we know based on external knowledge that these entrees are LinearMaps
j_coord_map = JLinearMap.from_linearmap(pre_tmap.coord_map, bypass_nan_check=True) # type: ignore [arg-type]
j_force_map = JLinearMap.from_linearmap(pre_tmap.force_map, bypass_nan_check=True) # type: ignore [arg-type]

# We then create the augmenter. This will be used with the full trajectory.
augmenter = JCondNormal(cov=var, premap=j_coord_map.flat_call, seed=seed)

# When we optimize the second-resolution map, we want to miniminize only
# the noise contributions. In order to do that we remove all real-force
# contributions and then then optimize the remaining forces as normal.
# zeroforce_traj zeros out the trajectory force contributions.
zeroforce_traj = Trajectory(coords=traj.coords, forces=np.zeros_like(traj.forces))

# This trajectory contains all the source atoms and noise sites, and only
# has noise-force contributions.
aug_traj = AugmentedTrajectory.from_trajectory(
t=zeroforce_traj, augmenter=augmenter, kbt=kbt
)

# We map this trajectory to only have the noise sites AND the coarse-grained version
# of the real sites. This allows us to map the noise forces on the real sites using
# the pre-derived force map.
pmapped_traj = RATMap(tmap=pre_tmap)(aug_traj)

# we now work towards optimizing a force map on pmapped_traj that minimizes
# noise-force contributions.

# pmapped* is a Trajectory, not an AugmentedTrajectory. So we must manually
# create the coordinate map for the second optimization. This coordinate
# map isolates the noise particles, similar to that created by lmap_augvariables.
preserved_sites = []
for index in range(
pmapped_traj.n_sites - aug_traj.n_aug_sites, pmapped_traj.n_sites
):
preserved_sites.append([index])
pmapped_coord_map = LinearMap(
mapping=preserved_sites, n_fg_sites=pmapped_traj.n_sites
)

# we then move to creating the force map. we no longer know what the
# constraints are (they have probably been mapped away). For a reasonable
# pre-coord map, there shouldn't be any left, and we assume this is true.
pmapped_tmap = qp_linear_map(
traj=pmapped_traj, coord_map=pmapped_coord_map, constraints=set(), **kwargs
)
# this derived force map is treated as a general force map that may have noise
# contributions.

# we check how big the noise contributions are
remaining_force_residual = np.mean(pmapped_tmap(pmapped_traj).forces**2)
if remaining_force_residual > contribution_tolerance:
warnings.warn(
"Unable to remove all noise contributions in forces. Remaining "
f"contribution: {remaining_force_residual}.",
stacklevel=0,
)

# we now create the composed map.
#
# (j_force_map @ j_coord_map.T) is an important object. It is the
# JLinearMap given given by the standard_matrix in both instances being
# appropriately multiplied. This can be see by first noting that \grad_x
# f(A x) = A^T [\grad f] (A x); in our application, j_coord_map corresponds
# to A, and [\grad f] is the force calculated w.r.t. the CG particles. A^T
# [\grad f] (h) therefore is an expression for the atomistic forces given
# by the noise when only the coarse-grained positions are provided (h).
# This expression is then mapped via j_force_map.
#
# Collectively, (j_force_map @ j_coord_map.T) transforms the CG real
# coordinate forces given by the Augmenter into mapped atomistic forces. As
# a result, the following Augmenter now corrects the forces in a
# _already mapped_ atomistic trajectory.

pmapped_augmenter = JCondNormal(
cov=var,
source_postmap=(j_force_map @ j_coord_map.T),
seed=seed,
)

post_tmap = AugmentedTMap(
aug_tmap=pmapped_tmap,
augmenter=pmapped_augmenter,
kbt=kbt,
)

# comb_tmap, when applied to a source trajectory will perform the following steps:
# 1. Apply pre_tmap
# The incoming data is not yet coarse-grained. pre_tmap will use the derived
# linear force and coordinate maps to map the data to the CG resolution.
# X = j_coord_map(x) #noqa
# Y = j_force_map(y) #noqa
# ^ y is the atomistic force, x the atomistic coords.
# 2. Apply post_tmap
# X_final = X+[0-mean gaussian noise]
# Y_final = Y+[(j_force_map @ j_coord_map.T)([noise-force])] #noqa
# note that by substitution, linearity, and @, we get
# Y_final = j_force_map(y + j_coord_map.T*[noise-force]) #noqa
# ^this is the "backmapped" noise force
# ^which is then combined with the atomistic
# force and mapped

comb_tmap = ComposedTMap(submaps=[post_tmap, pre_tmap])

return comb_tmap
78 changes: 78 additions & 0 deletions tests/test_gaussmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
joptgauss_map,
stagedjoptgauss_map,
stagedjslicegauss_map,
stagedjforcegauss_map,
)
from aggforce.agg import TMAP_KNAME
from aggforce import jaxmapval as mv
Expand Down Expand Up @@ -404,3 +405,80 @@ def test_negative_cln025_sepgauss_mscg_ip(seed: int = rseed) -> None:
)
# check that we do not match
assert not np.allclose(KNOWN_PROJS, np.array(gauss_projs), atol=2e-1)


@pytest.mark.jax
def test_cln025_sepforcegauss_mscg_ip(seed: int = rseed) -> None:
r"""Check if CLN025 seperable force-only gauss maps produce known results.
This checks for consistency against previous results, but not correctness.
This test is stochastic. It should rarely fail, but it can. We keep it stochastic
as small design changes may alter seed dependence.
See tests in test_forces for more information.
"""
from aggforce import jaxmapval as mv

coords, forces, pdb, kbt = get_data()
# cmap is the configurational coarse-grained map
cmap = gen_config_map(pdb, "CA$")
# guess molecular constraints
constraints = guess_pairwise_constraints(coords[0:10], threshold=1e-3)

train_coords = coords[:500]
test_coords = coords[500:]

train_forces = forces[:500]
test_forces = forces[500:]

# we do NOT set the rng here.
gauss_results = project_forces(
coords=train_coords,
forces=train_forces,
coord_map=cmap,
constrained_inds=constraints,
method=stagedjforcegauss_map,
var=0.002,
kbt=kbt,
)

# map multiple times with gauss map to make big generated dataset
mapped_coords = []
mapped_forces = []
for _ in range(300):
gauss_coords, gauss_forces = gauss_results[TMAP_KNAME].map_arrays(
test_coords, test_forces
)
mapped_coords.append(gauss_coords)
mapped_forces.append(gauss_coords)

all_mapped_coords = np.concatenate(mapped_coords, axis=0)
all_mapped_forces = np.concatenate(mapped_coords, axis=0)

# project onto random bases. We here give an rng so that we get the same
# projections.
gauss_projs = mv.random_force_proj( # type: ignore
coords=all_mapped_coords,
forces=all_mapped_forces,
randg=r.default_rng(seed=seed),
n_samples=5,
inner=6.0,
outer=12.0,
width=6.0,
average=False,
)
# these were taking from the non seperable gauss test, but also
# be matched here.
KNOWN_PROJS: Final = np.array(
[
86.73444366455078,
-87.3666763305664,
70.80025482177734,
64.30303955078125,
-25.622215270996094,
]
)
assert np.allclose(KNOWN_PROJS, np.array(gauss_projs), atol=2e-1)


0 comments on commit 576df95

Please sign in to comment.