From 02ceda23a251196bcfc1adbd80c1a5036547fd1f Mon Sep 17 00:00:00 2001 From: Jonathan Walker Date: Thu, 29 Aug 2024 14:03:45 -0400 Subject: [PATCH] feat: cleanup threadpool --- src/sasctl/core.py | 7 +++++++ tests/unit/test_pageiterator.py | 26 +++++++++++++------------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/sasctl/core.py b/src/sasctl/core.py index 3940ac2a..a26e2851 100644 --- a/src/sasctl/core.py +++ b/src/sasctl/core.py @@ -1529,6 +1529,13 @@ def __init__(self, obj, session=None, threads=4): # Store the current items to iterate over self._obj = obj + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._pool is not None: + self._pool.shutdown(wait=False, cancel_futures=True) + def __next__(self): if self._pool is None: self._pool = concurrent.futures.ThreadPoolExecutor( diff --git a/tests/unit/test_pageiterator.py b/tests/unit/test_pageiterator.py index 3050f986..4c49c7d2 100644 --- a/tests/unit/test_pageiterator.py +++ b/tests/unit/test_pageiterator.py @@ -62,16 +62,16 @@ def test_paging_required(paging): """Requests should be made to retrieve additional pages.""" obj, items, _ = paging - pager = PageIterator(obj) - init_count = pager._start - - for i, page in enumerate(pager): - for j, item in enumerate(page): - if i == 0: - item_idx = j - else: - # Account for initial page size not necessarily being same size - # as additional pages - item_idx = init_count + (i - 1) * pager._limit + j - target = RestObj(items[item_idx]) - assert item.name == target.name + with PageIterator(obj) as pager: + init_count = pager._start + + for i, page in enumerate(pager): + for j, item in enumerate(page): + if i == 0: + item_idx = j + else: + # Account for initial page size not necessarily being same size + # as additional pages + item_idx = init_count + (i - 1) * pager._limit + j + target = RestObj(items[item_idx]) + assert item.name == target.name