diff --git a/pytroll_collectors/segments.py b/pytroll_collectors/segments.py index a9b68ff..fabfab8 100644 --- a/pytroll_collectors/segments.py +++ b/pytroll_collectors/segments.py @@ -37,10 +37,11 @@ import datetime as dt import logging.handlers +import os +import signal from abc import ABCMeta, abstractmethod from collections import OrderedDict from enum import Enum -import os import trollsift from posttroll import message as pmessage @@ -610,6 +611,7 @@ def __init__(self, config): self._group_by_minutes = self._config.get('group_by_minutes', None) self._loop = False + self._sigterm_caught = False self._providing_server = self._config.get('providing_server') self._is_first_message_after_start = True @@ -711,9 +713,10 @@ def _collect_publisher_config(self): def run(self): """Run SegmentGatherer.""" self._setup_messaging() + signal.signal(signal.SIGTERM, self._handle_sigterm) self._loop = True - while self._loop: + while self._keep_running(): self.triage_slots() # Check listener for new messages @@ -722,8 +725,7 @@ def run(self): except AttributeError: msg = self._listener.queue.get(True, 1) except KeyboardInterrupt: - self.stop() - continue + break except Empty: continue @@ -733,6 +735,16 @@ def run(self): continue logger.info("New message received: %s", str(msg)) self.process(msg) + self.stop() + + def _handle_sigterm(self, signum, frame): + logging.info("Caught SIGTERM, shutting down when all collections are finished.") + self._sigterm_caught = True + + def _keep_running(self): + if not self._loop or (self._sigterm_caught and not self.slots): + return False + return True def triage_slots(self): """Check if there are slots ready for publication.""" diff --git a/pytroll_collectors/tests/test_fsspec_to_message.py b/pytroll_collectors/tests/test_fsspec_to_message.py index 0444fbd..8e3edb3 100644 --- a/pytroll_collectors/tests/test_fsspec_to_message.py +++ b/pytroll_collectors/tests/test_fsspec_to_message.py @@ -121,8 +121,8 @@ def create_files_to_pack(self, tmp_path): @pytest.mark.parametrize( ("packing", "create_packfile", "filesystem_class"), [ - ("tar", create_tar_file, "fsspec.implementations.tar.TarFileSystem"), - ("zip", create_zip_file, "fsspec.implementations.zip.ZipFileSystem"), + ("tar", create_tar_file, "fsspec.implementations.tar:TarFileSystem"), + ("zip", create_zip_file, "fsspec.implementations.zip:ZipFileSystem"), ] ) def test_pack_file_extract(self, packing, create_packfile, filesystem_class, tmp_path): @@ -153,8 +153,8 @@ def test_pack_file_extract(self, packing, create_packfile, filesystem_class, tmp @pytest.mark.parametrize( ("packing", "create_packfile", "filesystem_class"), [ - ("tar", create_tar_file, "fsspec.implementations.tar.TarFileSystem"), - ("zip", create_zip_file, "fsspec.implementations.zip.ZipFileSystem"), + ("tar", create_tar_file, "fsspec.implementations.tar:TarFileSystem"), + ("zip", create_zip_file, "fsspec.implementations.zip:ZipFileSystem"), ] ) def test_pack_local_file_extract(self, packing, create_packfile, filesystem_class, tmp_path): @@ -184,8 +184,8 @@ def test_pack_local_file_extract(self, packing, create_packfile, filesystem_clas @pytest.mark.parametrize( ("packing", "create_packfile", "filesystem_class"), [ - ("tar", create_tar_file, "fsspec.implementations.tar.TarFileSystem"), - ("zip", create_zip_file, "fsspec.implementations.zip.ZipFileSystem"), + ("tar", create_tar_file, "fsspec.implementations.tar:TarFileSystem"), + ("zip", create_zip_file, "fsspec.implementations.zip:ZipFileSystem"), ] ) def test_pack_local_file_extract_filesystem(self, packing, create_packfile, filesystem_class, tmp_path): @@ -211,8 +211,8 @@ def check_filesystem_is_understood_by_fsspec(self, filesystem_info): @pytest.mark.parametrize( ("packing", "create_packfile", "filesystem_class"), [ - ("tar", create_tar_file, "fsspec.implementations.tar.TarFileSystem"), - ("zip", create_zip_file, "fsspec.implementations.zip.ZipFileSystem"), + ("tar", create_tar_file, "fsspec.implementations.tar:TarFileSystem"), + ("zip", create_zip_file, "fsspec.implementations.zip:ZipFileSystem"), ] ) def test_pack_local_file_extract_with_custom_options(self, packing, create_packfile, filesystem_class, tmp_path): diff --git a/pytroll_collectors/tests/test_segments.py b/pytroll_collectors/tests/test_segments.py index c56874c..de10f06 100644 --- a/pytroll_collectors/tests/test_segments.py +++ b/pytroll_collectors/tests/test_segments.py @@ -747,6 +747,57 @@ def test_listener_use_first_nameserver(self): self.msg0deg._setup_listener() assert_messaging(None, None, None, 'localhost', None, ListenerContainer) + def test_sigterm(self): + """Test that SIGTERM signal is handled.""" + import os + import signal + import time + from multiprocessing import Process + + with patch('pytroll_collectors.segments.ListenerContainer'): + col = SegmentGatherer(CONFIG_SINGLE) + proc = Process(target=col.run) + proc.start() + time.sleep(1) + os.kill(proc.pid, signal.SIGTERM) + proc.join() + + assert proc.exitcode == 0 + + def test_sigterm_nonempty_slots(self): + """Test that SIGTERM signal is handled properly when there are active slots present.""" + import os + import signal + import time + from multiprocessing import Process + + with patch('pytroll_collectors.segments.ListenerContainer'): + with patch('pytroll_collectors.segments.SegmentGatherer.triage_slots', + new=_fake_triage_slots): + col = SegmentGatherer(CONFIG_SINGLE) + proc = Process(target=col.run) + proc.start() + time.sleep(1) + tic = time.time() + os.kill(proc.pid, signal.SIGTERM) + proc.join() + + assert proc.exitcode == 0 + # Triage after the kill signal takes 1 s + assert time.time() - tic > 1. + + +def _fake_triage_slots(self): + """Fake the triage_slots() method. + + The fake triage adds a new slot if SIGTERM has not been caught, and removes it when the signal comes. + """ + import time + self.slots["foo"] = "bar" + if self._sigterm_caught: + del self.slots["foo"] + time.sleep(1) + def _get_message_from_metadata_and_patterns(mda, patterns): fake_message = FakeMessage(mda) diff --git a/setup.cfg b/setup.cfg index 2bf78ea..4d02abc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,3 +19,4 @@ tag_prefix = v omit = pytroll_collectors/_version.py versioneer.py +relative_files = True