diff --git a/smartdispatch/job_generator.py b/smartdispatch/job_generator.py index b219ea1..af33e9c 100644 --- a/smartdispatch/job_generator.py +++ b/smartdispatch/job_generator.py @@ -80,7 +80,7 @@ def add_sbatch_flags(self, flags): split = flag.find('=') if flag.startswith('--'): if split == -1: - raise ValueError("Invalid SBATCH flag ({})".format(flag)) + raise ValueError("Invalid SBATCH flag ({}), no '=' character found' ".format(flag)) options[flag[:split].lstrip("-")] = flag[split+1:] elif flag.startswith('-') and split == -1: options[flag[1:2]] = flag[2:] diff --git a/smartdispatch/tests/test_job_generator.py b/smartdispatch/tests/test_job_generator.py index ce36d25..a82c43a 100644 --- a/smartdispatch/tests/test_job_generator.py +++ b/smartdispatch/tests/test_job_generator.py @@ -1,9 +1,14 @@ from nose.tools import assert_true, assert_false, assert_equal, assert_raises -import unittest - import os -import tempfile import shutil +import tempfile +import unittest + +try: + from mock import patch +except ImportError: + from unittest.mock import patch + from smartdispatch.queue import Queue from smartdispatch.job_generator import JobGenerator, job_generator_factory from smartdispatch.job_generator import HeliosJobGenerator, HadesJobGenerator @@ -155,7 +160,7 @@ def test_add_sbatch_flags(self): def test_add_sbatch_flag_invalid(self): invalid_flags = ["--qos high", "gpu", "-lfeature=k80"] for flag in invalid_flags: - assert_raises(ValueError, self._test_add_sbatch_flags, "--qos high") + assert_raises(ValueError, self._test_add_sbatch_flags, flag) class TestGuilliminQueue(object): @@ -285,9 +290,15 @@ def setUp(self): job_generator = SlurmJobGenerator(self.queue, self.commands) self.pbs = job_generator.pbs_list + with patch.object(SlurmJobGenerator,'_add_cluster_specific_rules', side_effect=lambda: None): + dummy_generator = SlurmJobGenerator(self.queue, self.commands) + self.dummy_pbs = dummy_generator.pbs_list + def test_ppn_ncpus(self): assert_true("ppn" not in str(self.pbs[0])) assert_true("ncpus" in str(self.pbs[0])) + assert_true("ppn" in str(self.dummy_pbs[0])) + assert_true("ncpus" not in str(self.dummy_pbs[0])) def test_gpus_naccelerators(self): assert_true("gpus" not in str(self.pbs[0])) diff --git a/smartdispatch/tests/test_utils.py b/smartdispatch/tests/test_utils.py index 4eaef4e..295498b 100644 --- a/smartdispatch/tests/test_utils.py +++ b/smartdispatch/tests/test_utils.py @@ -1,11 +1,16 @@ # -*- coding: utf-8 -*- import unittest - -from smartdispatch import utils - +try: + from mock import patch + import mock +except ImportError: + from unittest.mock import patch + import unittest.mock from nose.tools import assert_equal, assert_true from numpy.testing import assert_array_equal +import subprocess +from smartdispatch import utils class PrintBoxedTests(unittest.TestCase): @@ -49,3 +54,40 @@ def test_slugify(): for arg, expected in testing_arguments: assert_equal(utils.slugify(arg), expected) + +command_output = """\ +Server Max Tot Que Run Hld Wat Trn Ext Com Status +---------------- --- --- --- --- --- --- --- --- --- ---------- +gpu-srv1.{} 0 1674 524 121 47 0 0 22 960 Idle +""" + +slurm_command = """\ + Cluster ControlHost ControlPort RPC Share GrpJobs GrpTRES GrpSubmit MaxJobs MaxTRES MaxSubmit MaxWall QOS Def QOS +---------- --------------- ------------ ----- --------- ------- ------------- --------- ------- ------------- --------- ----------- -------------------- --------- + {} 132.204.24.224 6817 7680 1 normal +""" + + +class ClusterIdentificationTest(unittest.TestCase): + + def test_detect_cluster(self): + server_name = ["hades", "m", "guil", "helios", "hades"] + clusters = ["hades", "mammouth", "guillimin", "helios"] + + for name, cluster in zip(server_name, clusters): + with patch('smartdispatch.utils.Popen') as mock_communicate: + mock_communicate.return_value.communicate.return_value = (command_output.format(name),) + self.assertEquals(utils.detect_cluster(), cluster) + + # def test_detect_mila_cluster(self): + # with patch('smartdispatch.utils.Popen') as mock_communicate: + # mock_communicate.return_value.communicate.side_effect = OSError + # self.assertIsNone(utils.detect_cluster()) + + def test_get_slurm_cluster_name(self): + clusters = ["graham", "cedar", "mila"] + + for cluster in clusters: + with patch('smartdispatch.utils.Popen') as mock_communicate: + mock_communicate.return_value.communicate.return_value = (slurm_command.format(cluster),) + self.assertEquals(utils.get_slurm_cluster_name(), cluster)