NOTE(review): loop.set_default_executor() only supports a ThreadPoolExecutor —
passing a ProcessPoolExecutor has been deprecated since Python 3.8 and raises
TypeError on 3.11+. make_observations_for_file_entry_batch() now calls
loop.run_in_executor(None, ...), which relies on that default executor; either
pass the process pool explicitly (run_in_executor(pool, ...)) or keep a
ThreadPoolExecutor as the loop default. Also confirm that every Temporal
activity and its arguments are picklable before switching activity_executor to
a ProcessPoolExecutor.

diff --git a/oonipipeline/src/oonipipeline/temporal/activities/observations.py b/oonipipeline/src/oonipipeline/temporal/activities/observations.py
index 4bc63549..8660db15 100644
--- a/oonipipeline/src/oonipipeline/temporal/activities/observations.py
+++ b/oonipipeline/src/oonipipeline/temporal/activities/observations.py
@@ -130,27 +130,26 @@ async def make_observations_for_file_entry_batch(
     ccs = ccs_set(probe_cc)
     total_measurement_count = 0
     awaitables = []
-    with concurrent.futures.ProcessPoolExecutor() as pool:
-        for bucket_name, s3path, ext, fe_size in file_entry_batch:
-            failure_count = 0
-            log.debug(f"processing file s3://{bucket_name}/{s3path}")
-            awaitables.append(
-                loop.run_in_executor(
-                    pool,
-                    functools.partial(
-                        make_observations_for_file_entry,
-                        clickhouse=clickhouse,
-                        data_dir=data_dir,
-                        bucket_date=bucket_date,
-                        bucket_name=bucket_name,
-                        s3path=s3path,
-                        ext=ext,
-                        fast_fail=fast_fail,
-                        write_batch_size=write_batch_size,
-                        ccs=ccs,
-                    ),
-                )
+    for bucket_name, s3path, ext, fe_size in file_entry_batch:
+        failure_count = 0
+        log.debug(f"processing file s3://{bucket_name}/{s3path}")
+        awaitables.append(
+            loop.run_in_executor(
+                None,
+                functools.partial(
+                    make_observations_for_file_entry,
+                    clickhouse=clickhouse,
+                    data_dir=data_dir,
+                    bucket_date=bucket_date,
+                    bucket_name=bucket_name,
+                    s3path=s3path,
+                    ext=ext,
+                    fast_fail=fast_fail,
+                    write_batch_size=write_batch_size,
+                    ccs=ccs,
+                ),
             )
+        )
     results = await asyncio.gather(*awaitables)
     for measurement_count, failure_count in results:
         total_measurement_count += measurement_count
diff --git a/oonipipeline/src/oonipipeline/temporal/workers.py b/oonipipeline/src/oonipipeline/temporal/workers.py
index a74cd952..814c1cd0 100644
--- a/oonipipeline/src/oonipipeline/temporal/workers.py
+++ b/oonipipeline/src/oonipipeline/temporal/workers.py
@@ -28,7 +28,7 @@
 
 log = logging.getLogger("oonipipeline.workers")
 
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, Executor
 
 interrupt_event = asyncio.Event()
 
@@ -51,7 +51,7 @@
 
 
 async def worker_main(
-    temporal_config: TemporalConfig, max_workers: int, thread_pool: ThreadPoolExecutor
+    temporal_config: TemporalConfig, max_workers: int, executor: Executor
 ):
     client = await temporal_connect(temporal_config=temporal_config)
     async with Worker(
@@ -59,7 +59,7 @@
         task_queue=TASK_QUEUE_NAME,
         workflows=WORKFLOWS,
         activities=ACTIVTIES,
-        activity_executor=thread_pool,
+        activity_executor=executor,
         max_concurrent_activities=max_workers,
         max_concurrent_workflow_tasks=max_workers,
     ):
@@ -71,9 +71,10 @@
 def start_workers(temporal_config: TemporalConfig):
     max_workers = max(os.cpu_count() or 4, 4)
     log.info(f"starting workers with max_workers={max_workers}")
-    thread_pool = ThreadPoolExecutor(max_workers=max_workers + 2)
+    executor = ProcessPoolExecutor(max_workers=max_workers + 2)
+
     loop = asyncio.new_event_loop()
-    loop.set_default_executor(thread_pool)
+    loop.set_default_executor(executor)
     # TODO(art): Investigate if we want to upgrade to python 3.12 and use this
     # instead
     # loop.set_task_factory(asyncio.eager_task_factory)
@@ -82,11 +83,11 @@
             worker_main(
                 temporal_config=temporal_config,
                 max_workers=max_workers,
-                thread_pool=thread_pool,
+                executor=executor,
             )
         )
     except KeyboardInterrupt:
         interrupt_event.set()
         loop.run_until_complete(loop.shutdown_asyncgens())
-        thread_pool.shutdown(wait=True, cancel_futures=True)
+        executor.shutdown(wait=True, cancel_futures=True)
-        log.info("shut down thread pool")
+        log.info("shut down executor")