diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 96c0d9f9..77a4f992 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -19,13 +19,16 @@ from project.datamodules.image_classification.fashion_mnist import FashionMNISTDataModule from project.datamodules.image_classification.mnist import MNISTDataModule -from project.utils.env_vars import NETWORK_DIR +from project.utils.env_vars import NETWORK_DIR, SLURM_JOB_ID from project.utils.hydra_utils import get_outer_class logger = get_logger(__name__) IN_GITHUB_CI = "GITHUB_ACTIONS" in os.environ -IN_SELF_HOSTED_GITHUB_CI = IN_GITHUB_CI and "self-hosted" in os.environ.get("RUNNER_LABELS", "") +IN_SELF_HOSTED_GITHUB_CI = IN_GITHUB_CI and ( + "self-hosted" in os.environ.get("RUNNER_LABELS", "") + or (torch.cuda.is_available() and SLURM_JOB_ID is None) +) IN_GITHUB_CLOUD_CI = IN_GITHUB_CI and not IN_SELF_HOSTED_GITHUB_CI PARAM_WHEN_USED_MARK_NAME = "parametrize_when_used"