Skip to content

Commit

Permalink
Merge pull request #86 from dwhswenson/tis-compiling
Browse files Browse the repository at this point in the history
TIS compiling
  • Loading branch information
dwhswenson authored Aug 24, 2024
2 parents e996e00 + e3bd443 commit de53902
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 115 deletions.
108 changes: 60 additions & 48 deletions paths_cli/compiling/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,55 +38,58 @@
description="final state for this transition",
)

def mistis_trans_info_param_builder(dcts):
default = 'volume-interface-set' # TODO: make this flexible?
trans_info = []
volume_compiler = compiler_for("volume")
interface_set_compiler = compiler_for('interface_set')
for dct in dcts:
dct = dct.copy()
dct['type'] = dct.get('type', default)
initial_state = volume_compiler(dct.pop('initial_state'))
final_state = volume_compiler(dct.pop('final_state'))
interface_set = interface_set_compiler(dct)
trans_info.append((initial_state, interface_set, final_state))

return trans_info


MISTIS_INTERFACE_SETS_PARAM = Parameter(
'interface_sets', mistis_trans_info_param_builder,
json_type=json_type_list(json_type_ref('interface-set')),
description='interface sets for MISTIS'
)

build_interface_set = InterfaceSetPlugin(
# this is reused in the simple single TIS setup
VOLUME_INTERFACE_SET_PARAMS = [
Parameter('cv', compiler_for('cv'), json_type=json_type_ref('cv'),
description=("the collective variable for this interface "
"set")),
Parameter('minvals', custom_eval,
json_type=json_type_list(json_type_eval("Float")),
description=("minimum value(s) for interfaces in this"
"interface set")),
Parameter('maxvals', custom_eval,
json_type=json_type_list(json_type_eval("Float")),
description=("maximum value(s) for interfaces in this"
"interface set")),
]


VOLUME_INTERFACE_SET_PLUGIN = InterfaceSetPlugin(
builder=Builder('openpathsampling.VolumeInterfaceSet'),
parameters=[
Parameter('cv', compiler_for('cv'), json_type=json_type_ref('cv'),
description=("the collective variable for this interface "
"set")),
Parameter('minvals', custom_eval,
json_type=json_type_list(json_type_eval("Float")),
description=("minimum value(s) for interfaces in this"
"interface set")),
Parameter('maxvals', custom_eval,
json_type=json_type_list(json_type_eval("Float")),
description=("maximum value(s) for interfaces in this"
"interface set")),
],
name='interface-set',
parameters=VOLUME_INTERFACE_SET_PARAMS,
name='volume-interface-set',
description="Interface set used in transition interface sampling.",
)


def mistis_trans_info(dct):
dct = dct.copy()
transitions = dct.pop('transitions')
volume_compiler = compiler_for('volume')
trans_info = [
(
volume_compiler(trans['initial_state']),
build_interface_set(trans['interfaces']),
volume_compiler(trans['final_state'])
)
for trans in transitions
]
dct['trans_info'] = trans_info
dct['trans_info'] = dct.pop('interface_sets')
return dct


def tis_trans_info(dct):
# remap TIS into MISTIS format
dct = dct.copy()
initial_state = dct.pop('initial_state')
final_state = dct.pop('final_state')
interface_set = dct.pop('interfaces')
dct['transitions'] = [{'initial_state': initial_state,
'final_state': final_state,
'interfaces': interface_set}]
return mistis_trans_info(dct)


TPS_NETWORK_PLUGIN = NetworkCompilerPlugin(
builder=Builder('openpathsampling.TPSNetwork'),
parameters=[INITIAL_STATES_PARAM, FINAL_STATES_PARAM],
Expand All @@ -96,18 +99,27 @@ def tis_trans_info(dct):
)


# MISTIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
# parameters=[Parameter('trans_info', mistis_trans_info)],
# builder=Builder('openpathsampling.MISTISNetwork'),
# name='mistis'
# )
MISTIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
parameters=[MISTIS_INTERFACE_SETS_PARAM],
builder=Builder('openpathsampling.MISTISNetwork',
remapper=mistis_trans_info),
name='mistis'
)

def single_tis_builder(initial_state, final_state, cv, minvals, maxvals):
import openpathsampling as paths
interface_set = paths.VolumeInterfaceSet(cv, minvals, maxvals)
return paths.MISTISNetwork([
(initial_state, interface_set, final_state)
])

TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
builder=single_tis_builder,
parameters=([INITIAL_STATE_PARAM, FINAL_STATE_PARAM]
+ VOLUME_INTERFACE_SET_PARAMS),
name='tis'
)

# TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
# builder=Builder('openpathsampling.MISTISNetwork'),
# parameters=[Parameter('trans_info', tis_trans_info)],
# name='tis'
# )

# old names not yet replaced in testing THESE ARE WHY WE'RE DOUBLING! GET
# RID OF THEM! (also, use an is-check)
Expand Down
7 changes: 7 additions & 0 deletions paths_cli/compiling/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,11 @@ def __call__(self, **dct):
"that type (i.e., ``OrganizeByMoveGroupStrategy``)"),
)

DEFAULT_TIS_SCHEME_PLUGIN = SchemeCompilerPlugin(
builder=Builder('openpathsampling.DefaultScheme'),
parameters=[NETWORK_PARAMETER, ENGINE_PARAMETER],
name='default-tis',
description="",
)

SCHEME_COMPILER = CategoryPlugin(SchemeCompilerPlugin, aliases=['schemes'])
125 changes: 58 additions & 67 deletions paths_cli/tests/compiling/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,74 +10,22 @@

_COMPILERS_LOC = 'paths_cli.compiling.root_compiler._COMPILERS'


def check_unidirectional_tis(results, state_A, state_B, cv):
assert len(results) == 1
trans_info = results['trans_info']
assert len(trans_info) == 1
assert len(trans_info[0]) == 3
trans = trans_info[0]
assert isinstance(trans, tuple)
assert trans[0] == state_A
assert trans[2] == state_B
assert isinstance(trans[1], paths.VolumeInterfaceSet)
ifaces = trans[1]
assert ifaces.cv == cv
assert ifaces.minvals == float("-inf")
np.testing.assert_allclose(ifaces.maxvals,
[0, np.pi / 10.0, np.pi / 5.0])


def test_mistis_trans_info(cv_and_states):
cv, state_A, state_B = cv_and_states
dct = {
'transitions': [{
'initial_state': "A",
'final_state': "B",
'interfaces': {
'cv': 'cv',
'minvals': 'float("-inf")',
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi"
}
}]
}
patch_base = 'paths_cli.compiling.networks'
compiler = {
'cv': mock_compiler('cv', named_objs={'cv': cv}),
'volume': mock_compiler('volume', named_objs={
"A": state_A, "B": state_B
}),
}
with mock.patch.dict(_COMPILERS_LOC, compiler):
results = mistis_trans_info(dct)

check_unidirectional_tis(results, state_A, state_B, cv)
@pytest.fixture
def unidirectional_tis_compiler(cv_and_states):
paths.InterfaceSet._reset()


def test_tis_trans_info(cv_and_states):
cv, state_A, state_B = cv_and_states
dct = {
'initial_state': "A",
'final_state': "B",
'interfaces': {
'cv': 'cv',
'minvals': 'float("-inf")',
'maxvals': 'np.array([0, 0.1, 0.2]) * np.pi',
}
}

compiler = {
return {
'cv': mock_compiler('cv', named_objs={'cv': cv}),
'volume': mock_compiler('volume', named_objs={
"A": state_A, "B": state_B
}),
'interface_set': mock_compiler(
'interface_set',
type_dispatch={
'volume-interface-set': VOLUME_INTERFACE_SET_PLUGIN
}
),
}
with mock.patch.dict(_COMPILERS_LOC, compiler):
results = tis_trans_info(dct)

check_unidirectional_tis(results, state_A, state_B, cv)
paths.InterfaceSet._reset()


def test_build_tps_network(cv_and_states):
Expand All @@ -86,17 +34,60 @@ def test_build_tps_network(cv_and_states):
dct = yaml.load(yml, yaml.FullLoader)
compiler = {
'volume': mock_compiler('volume', named_objs={"A": state_A,
"B": state_B}),
"B": state_B}),
}
with mock.patch.dict(_COMPILERS_LOC, compiler):
network = build_tps_network(dct)
network = TPS_NETWORK_PLUGIN(dct)
assert isinstance(network, paths.TPSNetwork)
assert len(network.initial_states) == len(network.final_states) == 1
assert network.initial_states[0] == state_A
assert network.final_states[0] == state_B

def test_build_mistis_network():
pytest.skip()

def test_build_tis_network():
pytest.skip()
def test_build_mistis_network(cv_and_states, unidirectional_tis_compiler):
cv, state_A, state_B = cv_and_states
mistis_dict = {
'interface_sets': [
{
'initial_state': "A",
'final_state': "B",
'cv': 'cv',
'minvals': 'float("-inf")',
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi"
},
{
'initial_state': "B",
'final_state': "A",
'cv': 'cv',
'minvals': "np.array([1.0, 0.9, 0.8])",
'maxvals': "float('inf')",
}
]
}

with mock.patch.dict(_COMPILERS_LOC, unidirectional_tis_compiler):
network = MISTIS_NETWORK_PLUGIN(mistis_dict)

assert isinstance(network, paths.MISTISNetwork)
assert len(network.sampling_transitions) == 2
assert len(network.transitions) == 2
assert list(network.transitions) == [(state_A, state_B),
(state_B, state_A)]

def test_build_tis_network(cv_and_states, unidirectional_tis_compiler):
cv, state_A, state_B = cv_and_states
tis_dict = {
'initial_state': "A",
'final_state': "B",
'cv': "cv",
'minvals': 'float("inf")',
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi",
}

with mock.patch.dict(_COMPILERS_LOC, unidirectional_tis_compiler):
network = TIS_NETWORK_PLUGIN(tis_dict)

assert isinstance(network, paths.MISTISNetwork)
assert len(network.sampling_transitions) == 1
assert len(network.transitions) == 1
assert list(network.transitions) == [(state_A, state_B)]

0 comments on commit de53902

Please sign in to comment.