diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 5df60f65..ce1f403c 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -43,14 +43,11 @@ jobs: df -h ulimit -a - # More info on options: https://github.com/goanpeca/setup-miniconda - - uses: goanpeca/setup-miniconda@v1 + - name: Configure conda + uses: conda-incubator/setup-miniconda@v2 with: python-version: ${{ matrix.python-version }} environment-file: devtools/conda-envs/test_env.yaml - - channels: conda-forge,defaults,omnia - activate-environment: test auto-update-conda: true auto-activate-base: false diff --git a/README.md b/README.md index 276d8ced..8f691162 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Tools and infrastructure for automated compound discovery using Folding@home. Run transformation and compound free energy analysis, producing `results/analysis.json`: ``` sh -fah-xchem run-analysis +fah-xchem run-analysis \ --compound-series-file compound-series.json \ --config-file config.json \ --fah-projects-dir /path/to/projects/ \ diff --git a/fah_xchem/analysis/__init__.py b/fah_xchem/analysis/__init__.py index 09cade30..350ffb47 100644 --- a/fah_xchem/analysis/__init__.py +++ b/fah_xchem/analysis/__init__.py @@ -321,6 +321,7 @@ def generate_artifacts( plots: bool = True, report: bool = True, website: bool = True, + overwrite: bool = False, ) -> None: complex_project_dir = os.path.join( @@ -348,6 +349,7 @@ def generate_artifacts( max_binding_free_energy=config.max_binding_free_energy, cache_dir=cache_dir, num_procs=num_procs, + overwrite=overwrite, ) if plots: @@ -357,6 +359,7 @@ def generate_artifacts( timestamp=timestamp, output_dir=output_dir, num_procs=num_procs, + overwrite=overwrite, ) if snapshots and report: diff --git a/fah_xchem/analysis/plots.py b/fah_xchem/analysis/plots.py index c73a445a..c079d34b 100644 --- a/fah_xchem/analysis/plots.py +++ b/fah_xchem/analysis/plots.py @@ -138,7 +138,7 @@ def _filter_inclusive( def plot_relative_distribution( relative_delta_fs: List[float], min_delta_f: float = -30, max_delta_f: float = 30 -) -> None: +) -> plt.Figure: """ Plot the distribution of relative free energies @@ -156,18 +156,20 @@ def plot_relative_distribution( ) valid_relative_delta_fs_kcal = valid_relative_delta_fs * KT_KCALMOL - sns.displot( - valid_relative_delta_fs_kcal, - kind="kde", - rug=True, - color="hotpink", - fill=True, - rug_kws=dict(alpha=0.5), - label=f"$N={len(relative_delta_fs)}$", - ) + fgrid = sns.displot( + valid_relative_delta_fs_kcal, + kind="kde", + rug=True, + color="hotpink", + fill=True, + rug_kws=dict(alpha=0.5), + label=f"$N={len(relative_delta_fs)}$", + ) plt.xlabel(r"Relative free energy to reference fragment / kcal mol$^{-1}$") plt.legend() + return fgrid.fig + def plot_convergence( complex_gens: List[int], @@ -354,7 +356,7 @@ def plot_cumulative_distribution( cmap: str = "PiYG", n_bins: int = 100, markers_kcal: List[float] = [-6, -5, -4, -3, -2, -1, 0, 1, 2], -) -> None: +) -> plt.Figure: """ Plot cumulative distribution of ligand affinities @@ -389,7 +391,8 @@ def plot_cumulative_distribution( x_span = X.max() - X.min() C = [cm(((X.max() - x) / x_span)) for x in X] - plt.bar(X[:-1], Y, color=C, width=X[1] - X[0], edgecolor="k") + fig, ax = plt.subplots() + ax.bar(X[:-1], Y, color=C, width=X[1] - X[0], edgecolor="k") for marker_kcal in markers_kcal: n_below = (relative_delta_fs_kcal < marker_kcal).astype(int).sum() @@ -405,6 +408,8 @@ def plot_cumulative_distribution( plt.xlabel(r"Relative free energy to reference fragment / kcal mol$^{-1}$") plt.ylabel("Cumulative $N$ ligands") + return fig + def _bootstrap( gens: List[GenAnalysis], @@ -482,8 +487,9 @@ def plot_bootstrapped_clones( return fig -def _plot_updated_timestamp(timestamp: dt.datetime) -> None: - fig = plt.gcf() +def _plot_updated_timestamp(timestamp: dt.datetime, fig: plt.Figure = None) -> None: + if fig is None: + fig = plt.gcf() fig.text( 0.5, 0.03, @@ -522,10 +528,10 @@ def _save_table_pdf(path: str, name: str): logging.warning("Failed to save pdf table") -@contextmanager def save_plot( path: str, name: str, + fig: plt.Figure, file_formats: Iterable[str] = ("png", "pdf"), timestamp: Optional[dt.datetime] = None, ) -> Generator: @@ -548,42 +554,63 @@ def save_plot( Examples -------- - >>> with save_plot('example/plots', 'test_plot', 'png'): - >>> plt.plot(np.cos(np.linspace(-np.pi, np.pi))) - >>> plt.title("My cool plot") + >>> fig = plt.plot(np.cos(np.linspace(-np.pi, np.pi))) + >>> fig.title("My cool plot") + >>> save_plot('example/plots', 'test_plot', fig, 'png'): """ + outfiles = [os.path.join(path, os.extsep.join([name, file_format])) + for file_format in file_formats] - try: - yield + if timestamp is not None: + fig.tight_layout(rect=(0, 0.05, 1, 1)) # leave space for timestamp + _plot_updated_timestamp(timestamp, fig) + else: + fig.tight_layout() - if timestamp is not None: - plt.tight_layout(rect=(0, 0.05, 1, 1)) # leave space for timestamp - _plot_updated_timestamp(timestamp) - else: - plt.tight_layout() + # Make sure the directory exists + os.makedirs(path, exist_ok=True) + + for outfile in outfiles: + fig.savefig( + outfile, + transparent=True, + ) - # Make sure the directory exists - os.makedirs(path, exist_ok=True) + plt.close(fig=fig) - for file_format in file_formats: - plt.savefig( - os.path.join(path, os.extsep.join([name, file_format])), - transparent=True, - ) - finally: - plt.close() + +def _plot_to_file_mapping( + path: str, + name: str, + file_formats: Iterable[str] = ("png", "pdf"), +) -> List: + return [os.path.join(path, os.extsep.join([name, file_format])) + for file_format in file_formats] def generate_transformation_plots( - transformation: TransformationAnalysis, output_dir: str + transformation: TransformationAnalysis, + output_dir: str, + overwrite: bool = False, ): run_id = transformation.transformation.run_id + + plot_output_dir = os.path.join(output_dir, "transformations", f"RUN{run_id}") save_transformation_plot = partial( - save_plot, path=os.path.join(output_dir, "transformations", f"RUN{run_id}") + save_plot, path=plot_output_dir ) - with save_transformation_plot(name="works"): + name = "works" + # check if output files all exist; if so, skip unless we are told not to + skip = False + if not overwrite: + outfiles = _plot_to_file_mapping(path=plot_output_dir, name=name) + if all(map(os.path.exists, outfiles)): + skip = True + + + if not skip: fig = plot_work_distributions( complex_forward_works=[ work.forward @@ -609,8 +636,17 @@ def generate_transformation_plots( solvent_delta_f=transformation.solvent_phase.free_energy.delta_f.point, ) fig.suptitle(f"RUN{run_id}") + save_transformation_plot(name=name, fig=fig) - with save_transformation_plot(name="convergence"): + name = "convergence" + # check if output files all exist; if so, skip unless we are told not to + skip = False + if not overwrite: + outfiles = _plot_to_file_mapping(path=plot_output_dir, name=name) + if all(map(os.path.exists, outfiles)): + skip = True + + if not skip: # Filter to GENs for which free energy calculation is available complex_gens = [ (gen.gen, gen.free_energy) @@ -634,9 +670,17 @@ def generate_transformation_plots( binding_delta_f_err=transformation.binding_free_energy.stderr, ) fig.suptitle(f"RUN{run_id}") + save_transformation_plot(name=name, fig=fig) - with save_transformation_plot(name="bootstrapped-CLONEs"): + name = "bootstrapped-CLONEs" + # check if output files all exist; if so, skip unless we are told not to + skip = False + if not overwrite: + outfiles = _plot_to_file_mapping(path=plot_output_dir, name=name) + if all(map(os.path.exists, outfiles)): + skip = True + if not skip: # Gather CLONES per GEN for run clones_per_gen = min( [ @@ -658,6 +702,7 @@ def generate_transformation_plots( n_gens=n_gens, ) fig.suptitle(f"RUN{run_id}") + save_transformation_plot(name=name, fig=fig) def generate_plots( @@ -665,6 +710,7 @@ def generate_plots( timestamp: dt.datetime, output_dir: str, num_procs: Optional[int] = None, + overwrite: bool = False, ) -> None: """ Generate analysis plots in `output_dir`. @@ -699,6 +745,11 @@ def generate_plots( "As of" timestamp to render on plots output_dir : str Where to write plot files + overwrite : bool + If `True`, write over existing output files if present. + Otherwise, skip writing output files for a given transformation when already present. + Assumes that for a given `run_id` the output files do not ever change; + does *no* checking that files wouldn't be different if inputs for a given `run_id` have changed. """ from rich.progress import track @@ -717,25 +768,24 @@ def generate_plots( # Summary plots - with save_summary_plot( - name="relative_fe_dist", - ): - plot_relative_distribution(binding_delta_fs) - plt.title("Relative free energy") - - with save_summary_plot( - name="cumulative_fe_dist", - ): - plot_cumulative_distribution(binding_delta_fs) - plt.title("Cumulative distribution") + # we always regenerate these, since they concern all data + fig = plot_relative_distribution(binding_delta_fs) + plt.title("Relative free energy") + save_summary_plot(name="relative_fe_dist", fig=fig) + fig = plot_cumulative_distribution(binding_delta_fs) + plt.title("Cumulative distribution") + save_summary_plot(name="cumulative_fe_dist", fig=fig) + with _save_table_pdf(path=output_dir, name="poor_complex_convergence_fe_table"): plot_poor_convergence_fe_table(series.transformations) # Transformation-level plots generate_transformation_plots_partial = partial( - generate_transformation_plots, output_dir=output_dir + generate_transformation_plots, + output_dir=output_dir, + overwrite=overwrite, ) with multiprocessing.Pool(num_procs) as pool: diff --git a/fah_xchem/analysis/structures.py b/fah_xchem/analysis/structures.py index af297712..cb75cca6 100644 --- a/fah_xchem/analysis/structures.py +++ b/fah_xchem/analysis/structures.py @@ -23,6 +23,16 @@ from ..schema import TransformationAnalysis +def _transformation_to_file_mapping(output_dir, run_id, ligand): + fnames = [f"{ligand}_protein.pdb", + f"{ligand}_complex.pdb", + f"{ligand}_ligand.sdf"] + + outfiles = [os.path.join(output_dir, f"RUN{run_id}", f"{fname}") for fname in fnames] + + return outfiles + + def load_trajectory( project_dir: str, project_data_dir: str, run: int, clone: int, gen: int ) -> md.Trajectory: @@ -303,6 +313,7 @@ def generate_representative_snapshot( output_dir: str, max_binding_free_energy: Optional[float], cache_dir: Optional[str] = None, + overwrite: bool = False, ) -> None: r""" @@ -316,6 +327,8 @@ def generate_representative_snapshot( Parameters ---------- + transformation: TransformationAnalysis + The transformation record to operate on. project_dir : str Path to project directory (e.g. '/home/server/server2/projects/13422') project_data_dir : str @@ -330,11 +343,20 @@ def generate_representative_snapshot( Path where snapshots will be written cache_dir : str or None, optional If specified, cache relevant parts of "htf.npz" file in a local directory of this name + overwrite : bool + If `True`, write over existing output files if present. + Otherwise, skip writing output files for a given transformation when already present. + Assumes that for a given `run_id` the output files do not ever change; + does *no* checking that files wouldn't be different if inputs for a given `run_id` have changed. + Returns ------- None """ + # create output directory if not present + run_id = transformation.transformation.run_id + os.makedirs(os.path.join(output_dir, f"RUN{run_id}"), exist_ok=True) # TODO: Cache results and only update RUNs for which we have received new data @@ -355,6 +377,13 @@ def generate_representative_snapshot( ] for ligand in ["old", "new"]: + + # check if output files all exist; if so, skip unless we are told not to + if not overwrite: + outfiles = _transformation_to_file_mapping(output_dir, run_id, ligand) + if all(map(os.path.exists, outfiles)): + continue + if ligand == "old": gen_work = min(gen_works, key=lambda gen_work: gen_work[1].reverse) frame = 3 # TODO: Magic numbers @@ -362,7 +391,6 @@ def generate_representative_snapshot( gen_work = min(gen_works, key=lambda gen_work: gen_work[1].forward) frame = 1 # TODO: Magic numbers - run_id = transformation.transformation.run_id gen_analysis, workpair = gen_work @@ -381,7 +409,6 @@ def generate_representative_snapshot( # Write protein PDB name = f"{ligand}_protein" - os.makedirs(os.path.join(output_dir, f"RUN{run_id}"), exist_ok=True) sliced_snapshots["protein"].save( os.path.join(output_dir, f"RUN{run_id}", f"{name}.pdb") @@ -405,6 +432,7 @@ def generate_representative_snapshot( print(f'\nException occurred extracting snapshot from {project_dir} data {project_data_dir} run {run_id} clone {gen_work[1].clone} gen {gen_work[0].gen}') print(e) + def generate_representative_snapshots( transformations: List[TransformationAnalysis], project_dir: str, @@ -413,6 +441,7 @@ def generate_representative_snapshots( max_binding_free_energy: Optional[float], cache_dir: Optional[str], num_procs: Optional[int], + overwrite: bool = False, ) -> None: from rich.progress import track @@ -425,6 +454,7 @@ def generate_representative_snapshots( output_dir=output_dir, cache_dir=cache_dir, max_binding_free_energy=max_binding_free_energy, + overwrite=overwrite ), transformations, ) diff --git a/fah_xchem/app.py b/fah_xchem/app.py index 1b7442c5..254afd3d 100644 --- a/fah_xchem/app.py +++ b/fah_xchem/app.py @@ -144,6 +144,7 @@ def generate_artifacts( website: bool = True, log: str = "WARN", fragalysis_config: Optional[str] = None, + overwrite: bool = False, ) -> None: """ Given results of free energy analysis as JSON, generate analysis @@ -187,6 +188,11 @@ def generate_artifacts( Logging level fragalysis_config : str, optional File containing information for Fragalysis upload as JSON-encoded :class: ~`fah_xchem.schema.FragalysisConfig` + overwrite : bool + If `True`, write over existing output files if present. + Otherwise, skip writing output files for a given transformation when already present. + Assumes that for a given `run_id` the output files do not ever change; + does *no* checking that files wouldn't be different if inputs for a given `run_id` have changed. """ logging.basicConfig(level=getattr(logging, log.upper())) @@ -213,6 +219,7 @@ def generate_artifacts( report=report, website=website, fragalysis_config=fragalysis_config, + overwrite=overwrite, )