Skip to content

Commit

Permalink
Use of MBAR estim. errors in gather for single protocol repeats and corrected raw report
Browse files Browse the repository at this point in the history
  • Loading branch information
frannerin committed Jul 5, 2024
1 parent 13e8acb commit 3323df5
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ def _generate_bad_legs_error_message(set_vals, ligpair):
def _parse_raw_units(results: dict) -> list[tuple]:
    """Grab the individual unit results from a master results dict.

    Every value stored under ``results['protocol_result']['data']`` is
    treated as a sequence of unit results, and only the first element of
    each sequence is read (presumably one entry per repeat -- confirm
    against the serialized results format).

    Parameters
    ----------
    results : dict
        Loaded results; must contain ``['protocol_result']['data']``, a
        mapping whose values are indexable and whose first element has an
        ``'outputs'`` dict with ``'unit_estimate'`` and
        ``'unit_estimate_error'`` keys.

    Returns
    -------
    list[tuple]
        One ``(estimate, uncertainty)`` tuple per entry of the data
        mapping, in the mapping's iteration order.

    Raises
    ------
    KeyError
        If any of the expected keys is missing.
    IndexError
        If one of the data entries is an empty sequence.
    """
    # could add to each tuple pu[0]["source_key"] for ID
    return [
        (pu[0]['outputs']['unit_estimate'],
         pu[0]['outputs']['unit_estimate_error'])
        for pu in results['protocol_result']['data'].values()
    ]


Expand Down Expand Up @@ -178,10 +179,10 @@ def _get_ddgs(legs, error_on_missing=True):
return DDGs


def _write_ddg(legs, writer, allow_partial):
def _write_ddg(legs, writer, allow_partial): # unc
DDGs = _get_ddgs(legs, error_on_missing=not allow_partial)
writer.writerow(["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)",
"uncertainty (kcal/mol)"])
"uncertainty (kcal/mol)"]) # unc])
for ligA, ligB, DDGbind, bind_unc, DDGhyd, hyd_unc in DDGs:
if DDGbind is not None:
DDGbind, bind_unc = format_estimate_uncertainty(DDGbind, bind_unc)
Expand All @@ -191,19 +192,19 @@ def _write_ddg(legs, writer, allow_partial):
writer.writerow([ligA, ligB, DDGhyd, hyd_unc])


def _write_raw(legs, writer, allow_partial=True):  # *args?
    """Write one row per leg and repeat with its raw estimate.

    A header row is written first. The rows follow, ordered by ligand
    pair and then by leg name, with repeats numbered from 1 in the order
    they are stored.

    Parameters
    ----------
    legs : dict
        Mapping ``{ligpair: {simtype: repeats}}`` where ``ligpair`` is an
        iterable that is unpacked into the ligand columns and ``repeats``
        is an iterable of ``(estimate, uncertainty)`` pairs. A ``None``
        estimate is reported as ``'NaN'`` for both values; otherwise the
        ``.m`` attribute of each value is passed to
        ``format_estimate_uncertainty`` (presumably the magnitude of a
        unit-bearing quantity -- confirm).
    writer
        Object exposing ``writerow(list)``, e.g. a ``csv.writer``.
    allow_partial : bool
        Not used by this function; accepted so that it can be called
        with the same arguments as the other report writers.
    """
    writer.writerow(["leg", "repeat", "ligand_i", "ligand_j",
                     "DG(i->j) (kcal/mol)", "MBAR uncertainty (kcal/mol)"])

    for ligpair, vals in sorted(legs.items()):
        for simtype, repeats in sorted(vals.items()):
            for rep, (m, u) in enumerate(repeats, 1):
                if m is None:
                    # missing result: keep the row but flag both values
                    m, u = 'NaN', 'NaN'
                else:
                    m, u = format_estimate_uncertainty(m.m, u.m)

                writer.writerow([simtype, rep, *ligpair, m, u])


def _write_dg_raw(legs, writer, allow_partial): # pragma: no-cover
Expand All @@ -218,7 +219,7 @@ def _write_dg_raw(legs, writer, allow_partial): # pragma: no-cover
writer.writerow([simtype, *ligpair, m, u])


def _write_dg_mle(legs, writer, allow_partial):
def _write_dg_mle(legs, writer, allow_partial): # unc
import networkx as nx
import numpy as np
from cinnabar.stats import mle
Expand Down Expand Up @@ -264,7 +265,7 @@ def _write_dg_mle(legs, writer, allow_partial):
MLEs.append((ligname, f, df))

writer.writerow(["ligand", "DG(MLE) (kcal/mol)",
"uncertainty (kcal/mol)"])
"uncertainty (kcal/mol)"]) # unc])
for ligA, DG, unc_DG in MLEs:
DG, unc_DG = format_estimate_uncertainty(DG, unc_DG)
writer.writerow([ligA, DG, unc_DG])
Expand Down Expand Up @@ -336,6 +337,9 @@ def gather(rootdir, output, report, allow_partial):
# 3) pair legs of simulations together into dict of dicts
legs = defaultdict(dict)

######## CHECK IF ALL RESULTS HAVE SAME # OF PROTOCOLUNITS?
# MBAR_errors = True

for result_fn in result_fns:
result = load_results(result_fn)
if result is None:
Expand All @@ -344,6 +348,8 @@ def gather(rootdir, output, report, allow_partial):
click.echo(f"WARNING: Calculations for {result_fn} did not finish successfully!",
err=True)



try:
names = get_names(result)
except KeyError:
Expand All @@ -353,8 +359,15 @@ def gather(rootdir, output, report, allow_partial):
except KeyError:
simtype = legacy_get_type(result_fn)

raw_units = _parse_raw_units(result)
######## CHECK IF ALL RESULTS HAVE SAME # OF PROTOCOLUNITS?
# if MBAR_errors and len(raw_units) > 1:
# MBAR_errors = False

if report.lower() == 'raw':
legs[names][simtype] = _parse_raw_units(result)
legs[names][simtype] = raw_units
elif len(raw_units) == 1:
legs[names][simtype] = raw_units[0]
else:
legs[names][simtype] = result['estimate'], result['uncertainty']

Expand All @@ -364,6 +377,8 @@ def gather(rootdir, output, report, allow_partial):
lineterminator="\n", # to exactly reproduce previous, prefer "\r\n"
)

# unc = "MBAR uncertainty (kcal/mol)" if MBAR_errors else "uncertainty (kcal/mol)"

# 5a) write out MLE values
# 5b) write out DDG values
# 5c) write out each leg
Expand All @@ -373,7 +388,7 @@ def gather(rootdir, output, report, allow_partial):
# 'dg-raw': _write_dg_raw,
'raw': _write_raw,
}[report.lower()]
writing_func(legs, writer, allow_partial)
writing_func(legs, writer, allow_partial) # , unc)


PLUGIN = OFECommandPlugin(
Expand Down

0 comments on commit 3323df5

Please sign in to comment.