Skip to content

Commit

Permalink
Updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aalitaiga committed Oct 6, 2017
1 parent d7d0300 commit c7bb250
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 6 deletions.
2 changes: 1 addition & 1 deletion scripts/smart-dispatch
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def parse_arguments():

parser.add_argument('-p', '--pool', type=int, help="Number of workers that will be consuming commands. Default: Nb commands")
parser.add_argument('--pbsFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of PBS flags. Ex:--pbsFlags="-lfeature=k80 -t0-4"')
parser.add_argument('--sbatchFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of SBATCH flags. Ex:--sbatchFlags="-qos=high --output=file.out"')
parser.add_argument('--sbatchFlags', type=str, help='ADVANCED USAGE: Allow to pass a space seperated list of SBATCH flags. Ex:--sbatchFlags="--qos=high --ofile.out"')
subparsers = parser.add_subparsers(dest="mode")

launch_parser = subparsers.add_parser('launch', help="Launch jobs.")
Expand Down
10 changes: 6 additions & 4 deletions smartdispatch/job_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ def add_sbatch_flags(self, flags):
for flag in flags:
split = flag.find('=')
if flag.startswith('--'):
options[flag[2:split]] = flag[split+1:]
elif flag.startswith('-'):
options[flag[1:split]] = flag[split+1:]
if split == -1:
raise ValueError("Invalid SBATCH flag ({})".format(flag))
options[flag[:split].lstrip("-")] = flag[split+1:]
elif flag.startswith('-') and split == -1:
options[flag[1:2]] = flag[2:]
else:
raise ValueError("Invalid SBATCH flag ({})".format(flag))
raise ValueError("Invalid SBATCH flag ({}, is it a PBS flag?)".format(flag))

for pbs in self.pbs_list:
pbs.add_sbatch_options(**options)
Expand Down
5 changes: 4 additions & 1 deletion smartdispatch/pbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ def __str__(self):
pbs += ["#PBS -l {0}={1}".format(resource_name, resource_value)]

for option_name, option_value in self.sbatch_options.items():
pbs += ["#SBATCH {0}={1}".format(option_name, option_value)]
if option_name.startswith('--'):
pbs += ["#SBATCH {0}={1}".format(option_name, option_value)]
else:
pbs += ["#SBATCH {0} {1}".format(option_name, option_value)]

pbs += ["\n# Modules #"]
for module in self.modules:
Expand Down
49 changes: 49 additions & 0 deletions smartdispatch/tests/test_job_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class TestJobGenerator(object):
pbs_flags = ['-lfeature=k80', '-lwalltime=42:42', '-lnodes=6:gpus=66', '-m', '-A123-asd-11', '-t10,20,30']
sbatch_flags = ['--qos=high', '--output=file.out', '-Cminmemory']

def setUp(self):
self.testing_dir = tempfile.mkdtemp()
Expand Down Expand Up @@ -129,6 +130,32 @@ def test_add_pbs_flags_invalid(self):
def test_add_pbs_flags_invalid_resource(self):
assert_raises(ValueError, self._test_add_pbs_flags, '-l weeee')

def _test_add_sbatch_flags(self, flags):
job_generator = JobGenerator(self.queue, self.commands)
job_generator.add_sbatch_flags(flags)
options = []

for flag in flags:
if flag.startswith('--'):
options += [flag]
elif flag.startswith('-'):
options += [(flag[:2] + ' ' + flag[2:]).strip()]

for pbs in job_generator.pbs_list:
pbs_str = pbs.__str__()
for flag in options:
assert_equal(pbs_str.count(flag), 1)

def test_add_sbatch_flags(self):
for flag in self.sbatch_flags:
yield self._test_add_sbatch_flags, [flag]

yield self._test_add_sbatch_flags, [flag]

def test_add_sbatch_flag_invalid(self):
invalid_flags = ["--qos high", "gpu", "-lfeature=k80"]
for flag in invalid_flags:
assert_raises(ValueError, self._test_add_sbatch_flags, "--qos high")

class TestGuilliminQueue(object):

Expand Down Expand Up @@ -244,6 +271,28 @@ def test_pbs_split_2_job_nb_commands(self):
assert_true("ppn=6" in str(self.pbs8[0]))
assert_true("ppn=2" in str(self.pbs8[1]))

class TestSlurmQueue(object):

def setUp(self):
self.walltime = "10:00"
self.cores = 42
self.mem_per_node = 32
self.nb_cores_per_node = 1
self.nb_gpus_per_node = 2
self.queue = Queue("slurm", "mila", self.walltime, self.nb_cores_per_node, self.nb_gpus_per_node, self.mem_per_node)

self.commands = ["echo 1", "echo 2", "echo 3", "echo 4"]
job_generator = SlurmJobGenerator(self.queue, self.commands)
self.pbs = job_generator.pbs_list

def test_ppn_ncpus(self):
assert_true("ppn" not in str(self.pbs[0]))
assert_true("ncpus" in str(self.pbs[0]))

def test_gpus_naccelerators(self):
assert_true("gpus" not in str(self.pbs[0]))
assert_true("naccelerators" in str(self.pbs[0]))

class TestJobGeneratorFactory(object):

def setUp(self):
Expand Down
1 change: 1 addition & 0 deletions smartdispatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def detect_cluster():
output = Popen(["qstat", "-B"], stdout=PIPE).communicate()[0]
except OSError:
# If qstat is not available we assume that the cluster is unknown.
# TODO: handle MILA + CEDAR + GRAHAM
cluster_name = get_slurm_cluster_name()
return None
# Get server name from status
Expand Down

0 comments on commit c7bb250

Please sign in to comment.