diff --git a/benchmarks/stargan/stargan/main.py b/benchmarks/stargan/stargan/main.py index 4d5e0e158..39fdfca19 100644 --- a/benchmarks/stargan/stargan/main.py +++ b/benchmarks/stargan/stargan/main.py @@ -1,4 +1,3 @@ -import json import os import argparse from solver import Solver @@ -197,29 +196,21 @@ def main(config): parser.add_argument("--mode", type=str, default="train", choices=["train", "test"]) parser.add_argument("--use_tensorboard", type=str2bool, default=False) - # try: - # mbconfig = json.loads(os.environ["MILABENCH_CONFIG"]) - # datadir = mbconfig["dirs"]["extra"] - # except: - # pass - - datadir = "/tmp/milabench/cuda/results/data" - # Directories. parser.add_argument("--celeba_image_dir", type=str, default="data/celeba/images") parser.add_argument( "--attr_path", type=str, default="data/celeba/list_attr_celeba.txt" ) parser.add_argument("--rafd_image_dir", type=str, default="data/RaFD/train") - parser.add_argument("--log_dir", type=str, default=os.path.join(datadir, "logs")) + parser.add_argument("--log_dir", type=str, default="/data/logs") parser.add_argument( - "--model_save_dir", type=str, default=os.path.join(datadir, "models") + "--model_save_dir", type=str, default="data/models" ) parser.add_argument( - "--sample_dir", type=str, default=os.path.join(datadir, "samples") + "--sample_dir", type=str, default="data/samples" ) parser.add_argument( - "--result_dir", type=str, default=os.path.join(datadir, "results") + "--result_dir", type=str, default="data/results" ) # Step size. diff --git a/config/base.yaml b/config/base.yaml index e41960dc7..4324f7ac9 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -496,8 +496,12 @@ stargan: --image_size: 512 --c_dim: 5 --batch_size: 16 - --dataset: "CelebA" + --dataset: "synth" --celeba_image_dir: "{milabench_data}" + --log_dir: "{milabench_extra}/logs" + --model_save_dir: "{milabench_extra}/models" + --sample_dir: "{milabench_extra}/samples" + --result_dir: "{milabench_extra}/results" super-slomo: inherits: _defaults diff --git a/milabench/_version.py b/milabench/_version.py index ceed029ad..f366a459d 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,5 @@ """This file is generated, do not modify""" -__tag__ = "v0.0.6-126-g45b00ad" -__commit__ = "45b00adf57a249547ed737d59a509cd3734eb55e" -__date__ = "2024-06-06 13:55:11 -0400" +__tag__ = "v0.0.6-129-g4334571" +__commit__ = "433457173c9e040dc6772f06ee99110adc77ff71" +__date__ = "2024-06-06 16:03:46 -0400" diff --git a/milabench/sizer.py b/milabench/sizer.py index 5e4c42f07..5c206b7a8 100644 --- a/milabench/sizer.py +++ b/milabench/sizer.py @@ -291,6 +291,7 @@ def resolve_argv(pack, argv): context["milabench_data"] = pack.config.get("dirs", {}).get("data", None) context["milabench_cache"] = pack.config.get("dirs", {}).get("cache", None) + context["milabench_extra"] = pack.config.get("dirs", {}).get("extra", None) max_worker = 16 context["n_worker"] = min(context["cpu_per_gpu"], max_worker)