From 5d7c60ab799c6234537113a30907fafaf0085b42 Mon Sep 17 00:00:00 2001 From: dp-yuanyn Date: Wed, 6 Mar 2024 19:13:57 +0800 Subject: [PATCH] fix:tool app add default args --- .github/workflows/ci_test_tools.yml | 11 +++-- .gitignore | 3 +- .../src/unidock_tools/application/mcdock.py | 48 ++++++++++++++++--- .../application/unidock_pipeline.py | 35 ++++++++++++-- 4 files changed, 82 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci_test_tools.yml b/.github/workflows/ci_test_tools.yml index a0371134..6ea7e878 100644 --- a/.github/workflows/ci_test_tools.yml +++ b/.github/workflows/ci_test_tools.yml @@ -25,8 +25,9 @@ jobs: pip install flake8 - name: Run flake8 formating + working-directory: ./unidock_tools run: | - flake8 unidock_tools/unidock_tools --exit-zero + flake8 src --exit-zero pyright: continue-on-error: true @@ -48,7 +49,7 @@ jobs: id: pyright_check working-directory: ./unidock_tools run: | - pyright + pyright src tests: if: ${{ always() }} @@ -84,10 +85,12 @@ jobs: pip install . - name: run unit-test + working-directory: ./unidock_tools run: | pip install pytest pytest-cov - pytest unidock_tools/tests/ut -vv --cov --cov-report term + pytest tests/ut -vv --cov --cov-report term - name: run application e2e test + working-directory: ./unidock_tools run: | - pytest unidock_tools/tests/applications -vv --cov --cov-report term + pytest tests/applications -vv --cov --cov-report term diff --git a/.gitignore b/.gitignore index a5f84b5b..c2e59666 100644 --- a/.gitignore +++ b/.gitignore @@ -18,5 +18,4 @@ unidock/example/screening_test/unidock_root unidock/example/screening_test/test_dock unidock_tools/dist unidock_tools/dist/* -unidock_tools/unidock_tools.egg-info -unidock_tools/unidock_tools.egg-info/* +*.egg-info diff --git a/unidock_tools/src/unidock_tools/application/mcdock.py b/unidock_tools/src/unidock_tools/application/mcdock.py index b6625bfc..0b7f500a 100644 --- a/unidock_tools/src/unidock_tools/application/mcdock.py +++ b/unidock_tools/src/unidock_tools/application/mcdock.py @@ -11,6 +11,41 @@ from .unidock_pipeline import UniDock +DEFAULT_ARGS = { + "receptor": None, + "ligands": None, + "ligand_index": None, + "gen_conf": False, + "max_num_confs_per_ligand": 200, + "min_rmsd": 0.3, + "center_x": None, + "center_y": None, + "center_z": None, + "size_x": 22.5, + "size_y": 22.5, + "size_z": 22.5, + "workdir": "mcdock_workdir", + "savedir": "mcdock_results", + "batch_size": 1200, + "scoring_function_rigid_docking": "vina", + "search_mode_rigid_docking": "", + "exhaustiveness_rigid_docking": 128, + "max_step_rigid_docking": 20, + "num_modes_rigid_docking": 3, + "refine_step_rigid_docking": 3, + "topn_rigid_docking": 100, + "scoring_function_local_refine": "vina", + "search_mode_local_refine": "", + "exhaustiveness_local_refine": 512, + "max_step_local_refine": 40, + "num_modes_local_refine": 1, + "refine_step_local_refine": 3, + "topn_local_refine": 1, + "seed": 181129, + "debug": False, +} + + class MultiConfDock(UniDock): def __init__(self, receptor: Path, @@ -69,6 +104,7 @@ def generate_conformation(self, def main(args: dict): + args = {**DEFAULT_ARGS, **args} if args["debug"]: logging.getLogger().setLevel("DEBUG") @@ -154,7 +190,7 @@ def get_parser() -> argparse.ArgumentParser: help="Receptor file in pdbqt format.") parser.add_argument("-l", "--ligands", type=lambda s: s.split(','), default=None, help="Ligand file in sdf format. Specify multiple files separated by commas.") - parser.add_argument("-i", "--ligand_index", type=str, default="", + parser.add_argument("-i", "--ligand_index", type=str, default=None, help="A text file containing the path of ligand files in sdf format.") parser.add_argument("-g", "--gen_conf", action="store_true", @@ -181,8 +217,8 @@ def get_parser() -> argparse.ArgumentParser: help="Working directory. Default: 'MultiConfDock'.") parser.add_argument("-sd", "--savedir", type=str, default="mcdock_results", help="Save directory. Default: 'MultiConfDock-Result'.") - parser.add_argument("-bs", "--batch_size", type=int, default=20, - help="Batch size for docking. Default: 20.") + parser.add_argument("-bs", "--batch_size", type=int, default=1200, + help="Batch size for docking. Default: 1200.") parser.add_argument("-sf_rd", "--scoring_function_rigid_docking", type=str, default="vina", @@ -200,7 +236,7 @@ def get_parser() -> argparse.ArgumentParser: type=int, default=3, help="Number of modes used in rigid docking. Default: 3.") parser.add_argument("-rs_rd", "--refine_step_rigid_docking", - type=int, default=5, + type=int, default=3, help="Refine step used in rigid docking. Default: 3.") parser.add_argument("-topn_rd", "--topn_rigid_docking", type=int, default=100, @@ -222,8 +258,8 @@ def get_parser() -> argparse.ArgumentParser: type=int, default=1, help="Number of modes used in local refine. Default: 1.") parser.add_argument("-rs_lr", "--refine_step_local_refine", - type=int, default=5, - help="Refine step used in local refine. Default: 5.") + type=int, default=3, + help="Refine step used in local refine. Default: 3.") parser.add_argument("-topn_lr", "--topn_local_refine", type=int, default=1, help="Top N results used in local refine. Default: 1.") diff --git a/unidock_tools/src/unidock_tools/application/unidock_pipeline.py b/unidock_tools/src/unidock_tools/application/unidock_pipeline.py index c24c8a3a..6e9394ad 100644 --- a/unidock_tools/src/unidock_tools/application/unidock_pipeline.py +++ b/unidock_tools/src/unidock_tools/application/unidock_pipeline.py @@ -16,6 +16,34 @@ from .base import Base +DEFAULT_ARGS = { + "receptor": None, + "ligands": None, + "ligand_index": None, + "center_x": None, + "center_y": None, + "center_z": None, + "size_x": 22.5, + "size_y": 22.5, + "size_z": 22.5, + "workdir": "docking_workdir", + "savedir": "docking_results", + "batch_size": 1200, + "scoring_function": "vina", + "search_mode": "", + "exhaustiveness": 128, + "max_step": 20, + "num_modes": 3, + "refine_step": 3, + "energy_range": 3.0, + "topn": 100, + "score_only": False, + "local_only": False, + "seed": 181129, + "debug": False, +} + + class UniDock(Base): def __init__(self, receptor: Path, @@ -177,17 +205,18 @@ def save_results(self, save_dir: Union[str, os.PathLike] = ""): def main(args: dict): + args = {**DEFAULT_ARGS, **args} workdir = Path(args["workdir"]).resolve() savedir = Path(args["savedir"]).resolve() ligands = [] - if args["ligands"]: + if args.get("ligands"): for lig in args["ligands"]: if not Path(lig).exists(): logging.error(f"Cannot find {lig}") continue ligands.append(Path(lig).resolve()) - if args["ligand_index"]: + if args.get("ligand_index"): with open(args["ligand_index"], "r") as f: index_content = f.read() index_lines1 = [Path(line.strip()).resolve() for line in index_content.split("\n") @@ -243,7 +272,7 @@ def get_parser() -> argparse.ArgumentParser: help="Receptor file in pdbqt format.") parser.add_argument("-l", "--ligands", type=lambda s: s.split(','), default=None, help="Ligand file in sdf format. Specify multiple files separated by commas.") - parser.add_argument("-i", "--ligand_index", type=str, default="", + parser.add_argument("-i", "--ligand_index", type=str, default=None, help="A text file containing the path of ligand files in sdf format.") parser.add_argument("-cx", "--center_x", type=float, required=True,