Skip to content

Commit

Permalink
Merge branch 'master' into retro-tab
Browse files Browse the repository at this point in the history
  • Loading branch information
dotsdl authored Jun 15, 2021
2 parents cea6944 + 7933e26 commit 8c0d75f
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 59 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,11 @@ jobs:
df -h
ulimit -a
# More info on options: https://github.com/conda-incubator/setup-miniconda
- uses: goanpeca/setup-miniconda@v1
- name: Configure conda
uses: conda-incubator/setup-miniconda@v2
with:
python-version: ${{ matrix.python-version }}
environment-file: devtools/conda-envs/test_env.yaml

channels: conda-forge,defaults,omnia

activate-environment: test
auto-update-conda: true
auto-activate-base: false
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Tools and infrastructure for automated compound discovery using Folding@home.
Run transformation and compound free energy analysis, producing `results/analysis.json`:

``` sh
fah-xchem run-analysis
fah-xchem run-analysis \
--compound-series-file compound-series.json \
--config-file config.json \
--fah-projects-dir /path/to/projects/ \
Expand Down
3 changes: 3 additions & 0 deletions fah_xchem/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def generate_artifacts(
plots: bool = True,
report: bool = True,
website: bool = True,
overwrite: bool = False,
) -> None:

complex_project_dir = os.path.join(
Expand Down Expand Up @@ -348,6 +349,7 @@ def generate_artifacts(
max_binding_free_energy=config.max_binding_free_energy,
cache_dir=cache_dir,
num_procs=num_procs,
overwrite=overwrite,
)

if plots:
Expand All @@ -357,6 +359,7 @@ def generate_artifacts(
timestamp=timestamp,
output_dir=output_dir,
num_procs=num_procs,
overwrite=overwrite,
)

if snapshots and report:
Expand Down
152 changes: 101 additions & 51 deletions fah_xchem/analysis/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _filter_inclusive(

def plot_relative_distribution(
relative_delta_fs: List[float], min_delta_f: float = -30, max_delta_f: float = 30
) -> None:
) -> plt.Figure:
"""
Plot the distribution of relative free energies
Expand All @@ -156,18 +156,20 @@ def plot_relative_distribution(
)
valid_relative_delta_fs_kcal = valid_relative_delta_fs * KT_KCALMOL

sns.displot(
valid_relative_delta_fs_kcal,
kind="kde",
rug=True,
color="hotpink",
fill=True,
rug_kws=dict(alpha=0.5),
label=f"$N={len(relative_delta_fs)}$",
)
fgrid = sns.displot(
valid_relative_delta_fs_kcal,
kind="kde",
rug=True,
color="hotpink",
fill=True,
rug_kws=dict(alpha=0.5),
label=f"$N={len(relative_delta_fs)}$",
)
plt.xlabel(r"Relative free energy to reference fragment / kcal mol$^{-1}$")
plt.legend()

return fgrid.fig


def plot_convergence(
complex_gens: List[int],
Expand Down Expand Up @@ -354,7 +356,7 @@ def plot_cumulative_distribution(
cmap: str = "PiYG",
n_bins: int = 100,
markers_kcal: List[float] = [-6, -5, -4, -3, -2, -1, 0, 1, 2],
) -> None:
) -> plt.Figure:
"""
Plot cumulative distribution of ligand affinities
Expand Down Expand Up @@ -389,7 +391,8 @@ def plot_cumulative_distribution(
x_span = X.max() - X.min()
C = [cm(((X.max() - x) / x_span)) for x in X]

plt.bar(X[:-1], Y, color=C, width=X[1] - X[0], edgecolor="k")
fig, ax = plt.subplots()
ax.bar(X[:-1], Y, color=C, width=X[1] - X[0], edgecolor="k")

for marker_kcal in markers_kcal:
n_below = (relative_delta_fs_kcal < marker_kcal).astype(int).sum()
Expand All @@ -405,6 +408,8 @@ def plot_cumulative_distribution(
plt.xlabel(r"Relative free energy to reference fragment / kcal mol$^{-1}$")
plt.ylabel("Cumulative $N$ ligands")

return fig


def _bootstrap(
gens: List[GenAnalysis],
Expand Down Expand Up @@ -482,8 +487,9 @@ def plot_bootstrapped_clones(
return fig


def _plot_updated_timestamp(timestamp: dt.datetime) -> None:
fig = plt.gcf()
def _plot_updated_timestamp(timestamp: dt.datetime, fig: plt.Figure = None) -> None:
if fig is None:
fig = plt.gcf()
fig.text(
0.5,
0.03,
Expand Down Expand Up @@ -522,10 +528,10 @@ def _save_table_pdf(path: str, name: str):
logging.warning("Failed to save pdf table")


@contextmanager
def save_plot(
path: str,
name: str,
fig: plt.Figure,
file_formats: Iterable[str] = ("png", "pdf"),
timestamp: Optional[dt.datetime] = None,
) -> Generator:
Expand All @@ -548,42 +554,63 @@ def save_plot(
Examples
--------
>>> with save_plot('example/plots', 'test_plot', 'png'):
>>> plt.plot(np.cos(np.linspace(-np.pi, np.pi)))
>>> plt.title("My cool plot")
>>> fig, ax = plt.subplots()
>>> ax.plot(np.cos(np.linspace(-np.pi, np.pi)))
>>> ax.set_title("My cool plot")
>>> save_plot('example/plots', 'test_plot', fig, file_formats=('png',))
"""
outfiles = [os.path.join(path, os.extsep.join([name, file_format]))
for file_format in file_formats]

try:
yield
if timestamp is not None:
fig.tight_layout(rect=(0, 0.05, 1, 1)) # leave space for timestamp
_plot_updated_timestamp(timestamp, fig)
else:
fig.tight_layout()

if timestamp is not None:
plt.tight_layout(rect=(0, 0.05, 1, 1)) # leave space for timestamp
_plot_updated_timestamp(timestamp)
else:
plt.tight_layout()
# Make sure the directory exists
os.makedirs(path, exist_ok=True)

for outfile in outfiles:
fig.savefig(
outfile,
transparent=True,
)

# Make sure the directory exists
os.makedirs(path, exist_ok=True)
plt.close(fig=fig)

for file_format in file_formats:
plt.savefig(
os.path.join(path, os.extsep.join([name, file_format])),
transparent=True,
)
finally:
plt.close()

def _plot_to_file_mapping(
    path: str,
    name: str,
    file_formats: Iterable[str] = ("png", "pdf"),
) -> List[str]:
    """
    Build the output file paths for a plot, one per file format.

    Parameters
    ----------
    path : str
        Directory the plot files are located in
    name : str
        Base file name, without extension
    file_formats : iterable of str
        File format extensions; one output path is produced per format

    Returns
    -------
    List[str]
        Paths of the form `<path>/<name><os.extsep><file_format>`,
        in the same order as `file_formats`
    """
    return [
        os.path.join(path, os.extsep.join([name, file_format]))
        for file_format in file_formats
    ]


def generate_transformation_plots(
transformation: TransformationAnalysis, output_dir: str
transformation: TransformationAnalysis,
output_dir: str,
overwrite: bool = False,
):

run_id = transformation.transformation.run_id

plot_output_dir = os.path.join(output_dir, "transformations", f"RUN{run_id}")
save_transformation_plot = partial(
save_plot, path=os.path.join(output_dir, "transformations", f"RUN{run_id}")
save_plot, path=plot_output_dir
)

with save_transformation_plot(name="works"):
name = "works"
# if all output files for this plot already exist, skip regenerating it unless `overwrite` is set
skip = False
if not overwrite:
outfiles = _plot_to_file_mapping(path=plot_output_dir, name=name)
if all(map(os.path.exists, outfiles)):
skip = True


if not skip:
fig = plot_work_distributions(
complex_forward_works=[
work.forward
Expand All @@ -609,8 +636,17 @@ def generate_transformation_plots(
solvent_delta_f=transformation.solvent_phase.free_energy.delta_f.point,
)
fig.suptitle(f"RUN{run_id}")
save_transformation_plot(name=name, fig=fig)

with save_transformation_plot(name="convergence"):
name = "convergence"
# if all output files for this plot already exist, skip regenerating it unless `overwrite` is set
skip = False
if not overwrite:
outfiles = _plot_to_file_mapping(path=plot_output_dir, name=name)
if all(map(os.path.exists, outfiles)):
skip = True

if not skip:
# Filter to GENs for which free energy calculation is available
complex_gens = [
(gen.gen, gen.free_energy)
Expand All @@ -634,9 +670,17 @@ def generate_transformation_plots(
binding_delta_f_err=transformation.binding_free_energy.stderr,
)
fig.suptitle(f"RUN{run_id}")
save_transformation_plot(name=name, fig=fig)

with save_transformation_plot(name="bootstrapped-CLONEs"):
name = "bootstrapped-CLONEs"
# if all output files for this plot already exist, skip regenerating it unless `overwrite` is set
skip = False
if not overwrite:
outfiles = _plot_to_file_mapping(path=plot_output_dir, name=name)
if all(map(os.path.exists, outfiles)):
skip = True

if not skip:
# Gather CLONES per GEN for run
clones_per_gen = min(
[
Expand All @@ -658,13 +702,15 @@ def generate_transformation_plots(
n_gens=n_gens,
)
fig.suptitle(f"RUN{run_id}")
save_transformation_plot(name=name, fig=fig)


def generate_plots(
series: CompoundSeriesAnalysis,
timestamp: dt.datetime,
output_dir: str,
num_procs: Optional[int] = None,
overwrite: bool = False,
) -> None:
"""
Generate analysis plots in `output_dir`.
Expand Down Expand Up @@ -699,6 +745,11 @@ def generate_plots(
"As of" timestamp to render on plots
output_dir : str
Where to write plot files
overwrite : bool
If `True`, write over existing output files if present.
Otherwise, skip writing output files for a given transformation when already present.
Assumes that for a given `run_id` the output files do not ever change;
does *no* checking that files wouldn't be different if inputs for a given `run_id` have changed.
"""
from rich.progress import track

Expand All @@ -717,25 +768,24 @@ def generate_plots(

# Summary plots

with save_summary_plot(
name="relative_fe_dist",
):
plot_relative_distribution(binding_delta_fs)
plt.title("Relative free energy")

with save_summary_plot(
name="cumulative_fe_dist",
):
plot_cumulative_distribution(binding_delta_fs)
plt.title("Cumulative distribution")
# we always regenerate these, since they concern all data
fig = plot_relative_distribution(binding_delta_fs)
plt.title("Relative free energy")
save_summary_plot(name="relative_fe_dist", fig=fig)

fig = plot_cumulative_distribution(binding_delta_fs)
plt.title("Cumulative distribution")
save_summary_plot(name="cumulative_fe_dist", fig=fig)

with _save_table_pdf(path=output_dir, name="poor_complex_convergence_fe_table"):
plot_poor_convergence_fe_table(series.transformations)

# Transformation-level plots

generate_transformation_plots_partial = partial(
generate_transformation_plots, output_dir=output_dir
generate_transformation_plots,
output_dir=output_dir,
overwrite=overwrite,
)

with multiprocessing.Pool(num_procs) as pool:
Expand Down
Loading

0 comments on commit 8c0d75f

Please sign in to comment.