Skip to content

Commit

Permalink
Add stdout redirecting
Browse files Browse the repository at this point in the history
  • Loading branch information
martinkim0 committed Sep 11, 2023
1 parent 4ceeef7 commit 7f5e6c7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@

6. Execute the autotune script in the container
```
docker exec autotune python /census-scvi/bin/autotune_scvi_v2.py --adata_path path_to_adata --batch_key batch_key --num_cpus num_cpus --num_gpus num_gpus --experiment_name homo_sapiens_scvi --save_dir /data
docker exec -d autotune python /census-scvi/bin/autotune_scvi_v2.py --adata_path path_to_adata --batch_key batch_key --num_cpus num_cpus --num_gpus num_gpus --experiment_name experiment_name --save_dir /data
```
25 changes: 23 additions & 2 deletions bin/autotune_scvi_v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import os
import pathlib
import sys
from inspect import signature
from typing import Callable

Expand Down Expand Up @@ -30,6 +32,10 @@
"batch_size": 1024
}

SCHEDULER_KWARGS = {
"grace_period": 5,
}


def wrap_kwargs(fn: Callable) -> Callable:
"""Wrap a function to accept keyword arguments from the command line."""
Expand Down Expand Up @@ -72,6 +78,9 @@ def fit_tuner(
save_dir: str,
num_samples: int = 10,
max_epochs: int = 20,
scheduler: str = "asha",
searcher: str = "hyperopt",
scheduler_kwargs: dict = SCHEDULER_KWARGS,
):
return tuner.fit(
adata,
Expand All @@ -82,8 +91,9 @@ def fit_tuner(
use_defaults=False,
num_samples=num_samples,
max_epochs=max_epochs,
scheduler="asha",
searcher="hyperopt",
scheduler=scheduler,
searcher=searcher,
scheduler_kwargs=scheduler_kwargs,
resources={"cpu": num_cpus, "gpu": num_gpus},
seed=seed,
experiment_name=experiment_name,
Expand All @@ -101,6 +111,15 @@ def main(
experiment_name: str = "autotune_scvi_v2",
save_dir: str = "/data",
):
logging_dir = os.path.join(save_dir, experiment_name)
stdout_path = os.path.join(logging_dir, "stdout.log")
stderr_path = os.path.join(logging_dir, "stderr.log")
make_parents(stdout_path, stderr_path)
stdout_handle = open(stdout_path, "w")
stderr_handle = open(stderr_path, "w")
sys.stdout = stdout_handle
sys.stderr = stderr_handle

adata = load_anndata(adata_path)
setup_anndata(adata, batch_key)
tuner = setup_tuner()
Expand All @@ -117,6 +136,8 @@ def main(
save_dir,
)

stdout_handle.close()
stderr_handle.close()

if __name__ == "__main__":
main()

0 comments on commit 7f5e6c7

Please sign in to comment.