Skip to content

Commit

Permalink
Updated tests using mock
Browse files Browse the repository at this point in the history
  • Loading branch information
aalitaiga committed Oct 10, 2017
1 parent c7bb250 commit a20405b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 8 deletions.
2 changes: 1 addition & 1 deletion smartdispatch/job_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def add_sbatch_flags(self, flags):
split = flag.find('=')
if flag.startswith('--'):
if split == -1:
raise ValueError("Invalid SBATCH flag ({})".format(flag))
raise ValueError("Invalid SBATCH flag ({}), no '=' character found' ".format(flag))
options[flag[:split].lstrip("-")] = flag[split+1:]
elif flag.startswith('-') and split == -1:
options[flag[1:2]] = flag[2:]
Expand Down
19 changes: 15 additions & 4 deletions smartdispatch/tests/test_job_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from nose.tools import assert_true, assert_false, assert_equal, assert_raises
import unittest

import os
import tempfile
import shutil
import tempfile
import unittest

try:
from mock import patch
except ImportError:
from unittest.mock import patch

from smartdispatch.queue import Queue
from smartdispatch.job_generator import JobGenerator, job_generator_factory
from smartdispatch.job_generator import HeliosJobGenerator, HadesJobGenerator
Expand Down Expand Up @@ -155,7 +160,7 @@ def test_add_sbatch_flags(self):
def test_add_sbatch_flag_invalid(self):
invalid_flags = ["--qos high", "gpu", "-lfeature=k80"]
for flag in invalid_flags:
assert_raises(ValueError, self._test_add_sbatch_flags, "--qos high")
assert_raises(ValueError, self._test_add_sbatch_flags, flag)

class TestGuilliminQueue(object):

Expand Down Expand Up @@ -285,9 +290,15 @@ def setUp(self):
job_generator = SlurmJobGenerator(self.queue, self.commands)
self.pbs = job_generator.pbs_list

with patch.object(SlurmJobGenerator,'_add_cluster_specific_rules', side_effect=lambda: None):
dummy_generator = SlurmJobGenerator(self.queue, self.commands)
self.dummy_pbs = dummy_generator.pbs_list

def test_ppn_ncpus(self):
assert_true("ppn" not in str(self.pbs[0]))
assert_true("ncpus" in str(self.pbs[0]))
assert_true("ppn" in str(self.dummy_pbs[0]))

This comment has been minimized.

Copy link
@bouthilx

bouthilx Oct 10, 2017

Collaborator

It would be better to have those 2 before self.pbs. If those fail, we don't care about the first two (pbs), and if the first two fail (pbs), we want to know if it is because the last two (dummy_pbs).

assert_true("ncpus" not in str(self.dummy_pbs[0]))

def test_gpus_naccelerators(self):
assert_true("gpus" not in str(self.pbs[0]))
Expand Down
48 changes: 45 additions & 3 deletions smartdispatch/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# -*- coding: utf-8 -*-
import unittest

from smartdispatch import utils

try:
from mock import patch
import mock
except ImportError:
from unittest.mock import patch
import unittest.mock

This comment has been minimized.

Copy link
@bouthilx

bouthilx Oct 10, 2017

Collaborator

Why is this needed, don't you just use patch?

from nose.tools import assert_equal, assert_true
from numpy.testing import assert_array_equal
import subprocess

from smartdispatch import utils

class PrintBoxedTests(unittest.TestCase):

Expand Down Expand Up @@ -49,3 +54,40 @@ def test_slugify():

for arg, expected in testing_arguments:
assert_equal(utils.slugify(arg), expected)

command_output = """\
Server Max Tot Que Run Hld Wat Trn Ext Com Status
---------------- --- --- --- --- --- --- --- --- --- ----------
gpu-srv1.{} 0 1674 524 121 47 0 0 22 960 Idle
"""

slurm_command = """\
Cluster ControlHost ControlPort RPC Share GrpJobs GrpTRES GrpSubmit MaxJobs MaxTRES MaxSubmit MaxWall QOS Def QOS
---------- --------------- ------------ ----- --------- ------- ------------- --------- ------- ------------- --------- ----------- -------------------- ---------
{} 132.204.24.224 6817 7680 1 normal
"""


class ClusterIdentificationTest(unittest.TestCase):

def test_detect_cluster(self):
server_name = ["hades", "m", "guil", "helios", "hades"]
clusters = ["hades", "mammouth", "guillimin", "helios"]

for name, cluster in zip(server_name, clusters):
with patch('smartdispatch.utils.Popen') as mock_communicate:
mock_communicate.return_value.communicate.return_value = (command_output.format(name),)
self.assertEquals(utils.detect_cluster(), cluster)

# def test_detect_mila_cluster(self):
# with patch('smartdispatch.utils.Popen') as mock_communicate:
# mock_communicate.return_value.communicate.side_effect = OSError
# self.assertIsNone(utils.detect_cluster())

def test_get_slurm_cluster_name(self):
clusters = ["graham", "cedar", "mila"]

for cluster in clusters:
with patch('smartdispatch.utils.Popen') as mock_communicate:

This comment has been minimized.

Copy link
@bouthilx

bouthilx Oct 10, 2017

Collaborator

Can't you patch the method directly?

This comment has been minimized.

Copy link
@aalitaiga

aalitaiga Oct 10, 2017

Author

I didn't manage to, any idea how?

This comment has been minimized.

Copy link
@bouthilx

bouthilx Oct 10, 2017

Collaborator

I played with the tests. If we mock communicate method only, then Popen still process the given code which raises an error since we don't have qstat and sacctmgr locally. This is why we need to mock both Popen and communicate method. I refactored a bit the tests and added a commit to your branch.

mock_communicate.return_value.communicate.return_value = (slurm_command.format(cluster),)
self.assertEquals(utils.get_slurm_cluster_name(), cluster)

0 comments on commit a20405b

Please sign in to comment.