diff --git a/scripts/smart-dispatch b/scripts/smart-dispatch index 6c1c67c..8a15b8f 100755 --- a/scripts/smart-dispatch +++ b/scripts/smart-dispatch @@ -209,7 +209,7 @@ def parse_arguments(): parser.add_argument('-p', '--pool', type=int, help="Number of workers that will be consuming commands. Default: Nb commands") parser.add_argument('--pbsFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of PBS flags. Ex:--pbsFlags="-lfeature=k80 -t0-4"') - parser.add_argument('--sbatchFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of SBATCH flags. Ex:--sbatchFlags="-qos=high --output=file.out"') + parser.add_argument('--sbatchFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of SBATCH flags. Ex:--sbatchFlags="--qos=high -ofile.out"') subparsers = parser.add_subparsers(dest="mode") launch_parser = subparsers.add_parser('launch', help="Launch jobs.") diff --git a/smartdispatch/job_generator.py b/smartdispatch/job_generator.py index 8444f2b..b219ea1 100644 --- a/smartdispatch/job_generator.py +++ b/smartdispatch/job_generator.py @@ -79,11 +79,13 @@ def add_sbatch_flags(self, flags): for flag in flags: split = flag.find('=') if flag.startswith('--'): - options[flag[2:split]] = flag[split+1:] - elif flag.startswith('-'): - options[flag[1:split]] = flag[split+1:] + if split == -1: + raise ValueError("Invalid SBATCH flag ({})".format(flag)) + options[flag[:split].lstrip("-")] = flag[split+1:] + elif flag.startswith('-') and split == -1: + options[flag[1:2]] = flag[2:] else: - raise ValueError("Invalid SBATCH flag ({})".format(flag)) + raise ValueError("Invalid SBATCH flag ({}, is it a PBS flag?)".format(flag)) for pbs in self.pbs_list: pbs.add_sbatch_options(**options) diff --git a/smartdispatch/pbs.py b/smartdispatch/pbs.py index ef60efc..df93028 100644 --- a/smartdispatch/pbs.py +++ b/smartdispatch/pbs.py @@ -178,7 +178,10 @@ def __str__(self): pbs += ["#PBS -l {0}={1}".format(resource_name, 
resource_value)] for option_name, option_value in self.sbatch_options.items(): - pbs += ["#SBATCH {0}={1}".format(option_name, option_value)] + if option_name.startswith('--'): + pbs += ["#SBATCH {0}={1}".format(option_name, option_value)] + else: + pbs += ["#SBATCH {0} {1}".format(option_name, option_value)] pbs += ["\n# Modules #"] for module in self.modules: diff --git a/smartdispatch/tests/test_job_generator.py b/smartdispatch/tests/test_job_generator.py index 22c4b0b..ce36d25 100644 --- a/smartdispatch/tests/test_job_generator.py +++ b/smartdispatch/tests/test_job_generator.py @@ -13,6 +13,7 @@ class TestJobGenerator(object): pbs_flags = ['-lfeature=k80', '-lwalltime=42:42', '-lnodes=6:gpus=66', '-m', '-A123-asd-11', '-t10,20,30'] + sbatch_flags = ['--qos=high', '--output=file.out', '-Cminmemory'] def setUp(self): self.testing_dir = tempfile.mkdtemp() @@ -129,6 +130,32 @@ def test_add_pbs_flags_invalid(self): def test_add_pbs_flags_invalid_resource(self): assert_raises(ValueError, self._test_add_pbs_flags, '-l weeee') + def _test_add_sbatch_flags(self, flags): + job_generator = JobGenerator(self.queue, self.commands) + job_generator.add_sbatch_flags(flags) + options = [] + + for flag in flags: + if flag.startswith('--'): + options += [flag] + elif flag.startswith('-'): + options += [(flag[:2] + ' ' + flag[2:]).strip()] + + for pbs in job_generator.pbs_list: + pbs_str = pbs.__str__() + for flag in options: + assert_equal(pbs_str.count(flag), 1) + + def test_add_sbatch_flags(self): + for flag in self.sbatch_flags: + yield self._test_add_sbatch_flags, [flag] + + yield self._test_add_sbatch_flags, self.sbatch_flags + + def test_add_sbatch_flag_invalid(self): + invalid_flags = ["--qos high", "gpu", "-lfeature=k80"] + for flag in invalid_flags: + assert_raises(ValueError, self._test_add_sbatch_flags, [flag]) class TestGuilliminQueue(object): @@ -244,6 +271,28 @@ def test_pbs_split_2_job_nb_commands(self): assert_true("ppn=6" in str(self.pbs8[0])) assert_true("ppn=2" in 
str(self.pbs8[1])) +class TestSlurmQueue(object): + + def setUp(self): + self.walltime = "10:00" + self.cores = 42 + self.mem_per_node = 32 + self.nb_cores_per_node = 1 + self.nb_gpus_per_node = 2 + self.queue = Queue("slurm", "mila", self.walltime, self.nb_cores_per_node, self.nb_gpus_per_node, self.mem_per_node) + + self.commands = ["echo 1", "echo 2", "echo 3", "echo 4"] + job_generator = SlurmJobGenerator(self.queue, self.commands) + self.pbs = job_generator.pbs_list + + def test_ppn_ncpus(self): + assert_true("ppn" not in str(self.pbs[0])) + assert_true("ncpus" in str(self.pbs[0])) + + def test_gpus_naccelerators(self): + assert_true("gpus" not in str(self.pbs[0])) + assert_true("naccelerators" in str(self.pbs[0])) + class TestJobGeneratorFactory(object): def setUp(self): diff --git a/smartdispatch/utils.py b/smartdispatch/utils.py index 3c598ad..4839c00 100644 --- a/smartdispatch/utils.py +++ b/smartdispatch/utils.py @@ -115,6 +115,7 @@ def detect_cluster(): output = Popen(["qstat", "-B"], stdout=PIPE).communicate()[0] except OSError: # If qstat is not available we assume that the cluster is unknown. + # TODO: handle MILA, CEDAR and GRAHAM cluster_name = get_slurm_cluster_name() return None # Get server name from status