From 31decb7ffa6c0cf106d5deb45159e7211205c287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arturo=20Filast=C3=B2?= Date: Wed, 4 Sep 2024 18:01:10 -0400 Subject: [PATCH] Tweaks to concurrency --- .../temporal/activities/observations.py | 198 +++++++++--------- 1 file changed, 96 insertions(+), 102 deletions(-) diff --git a/oonipipeline/src/oonipipeline/temporal/activities/observations.py b/oonipipeline/src/oonipipeline/temporal/activities/observations.py index 4149be05..65820d13 100644 --- a/oonipipeline/src/oonipipeline/temporal/activities/observations.py +++ b/oonipipeline/src/oonipipeline/temporal/activities/observations.py @@ -43,27 +43,6 @@ class MakeObservationsParams: bucket_date: str -def write_observations_to_db( - msmt: SupportedDataformats, - netinfodb: NetinfoDB, - db: ClickhouseConnection, - bucket_date: str, -): - for observations in measurement_to_observations( - msmt=msmt, netinfodb=netinfodb, bucket_date=bucket_date - ): - if len(observations) == 0: - continue - - column_names = [f.name for f in dataclasses.fields(observations[0])] - table_name, rows = make_db_rows( - bucket_date=bucket_date, - dc_list=observations, - column_names=column_names, - ) - db.write_rows(table_name=table_name, rows=rows, column_names=column_names) - - FileEntryBatchType = Tuple[str, str, str, int] @@ -80,7 +59,58 @@ class MakeObservationsFileEntryBatch: fast_fail: bool -def make_observations_for_file_entry_batch( +def make_observations_for_file_entry( + clickhouse: str, + data_dir: pathlib.Path, + bucket_date: str, + bucket_name: str, + s3path: str, + ext: str, + ccs: set, + fast_fail: bool, + write_batch_size: int, +): + with ClickhouseConnection(clickhouse, write_batch_size=write_batch_size) as db: + netinfodb = NetinfoDB(datadir=data_dir, download=False) + for msmt_dict in stream_measurements( + bucket_name=bucket_name, s3path=s3path, ext=ext + ): + # Legacy cans don't allow us to pre-filter on the probe_cc, so + # we need to check for probe_cc consistency in here. + if ccs and msmt_dict["probe_cc"] not in ccs: + continue + msmt = None + try: + msmt = load_measurement(msmt_dict) + if not msmt.test_keys: + log.error( + f"measurement with empty test_keys: ({msmt.measurement_uid})", + exc_info=True, + ) + continue + obs_tuple = measurement_to_observations( + msmt=msmt, + netinfodb=netinfodb, + bucket_date=bucket_date, + ) + for obs_list in obs_tuple: + db.write_table_model_rows(obs_list) + idx += 1 + except Exception as exc: + msmt_str = msmt_dict.get("report_id", None) + if msmt: + msmt_str = msmt.measurement_uid + log.error(f"failed at idx: {idx} ({msmt_str})", exc_info=True) + failure_count += 1 + + if fast_fail: + db.close() + raise exc + log.debug(f"done processing file s3://{bucket_name}/{s3path}") + return idx, failure_count + + +async def make_observations_for_file_entry_batch( file_entry_batch: List[FileEntryBatchType], bucket_date: str, probe_cc: List[str], @@ -89,72 +119,44 @@ def make_observations_for_file_entry_batch( write_batch_size: int, fast_fail: bool = False, ) -> int: - netinfodb = NetinfoDB(datadir=data_dir, download=False) - - tracer = trace.get_tracer(__name__) + loop = asyncio.get_running_loop() + tbatch = PerfTimer() total_failure_count = 0 - with ClickhouseConnection(clickhouse, write_batch_size=write_batch_size) as db: - ccs = ccs_set(probe_cc) - idx = 0 - for bucket_name, s3path, ext, fe_size in file_entry_batch: - failure_count = 0 - # Nest the traced span within the current span - with tracer.start_span("MakeObservations:stream_file_entry") as span: - log.debug(f"processing file s3://{bucket_name}/{s3path}") - t = PerfTimer() - try: - for msmt_dict in stream_measurements( - bucket_name=bucket_name, s3path=s3path, ext=ext - ): - # Legacy cans don't allow us to pre-filter on the probe_cc, so - # we need to check for probe_cc consistency in here. - if ccs and msmt_dict["probe_cc"] not in ccs: - continue - msmt = None - try: - msmt = load_measurement(msmt_dict) - if not msmt.test_keys: - log.error( - f"measurement with empty test_keys: ({msmt.measurement_uid})", - exc_info=True, - ) - continue - obs_tuple = measurement_to_observations( - msmt=msmt, - netinfodb=netinfodb, - bucket_date=bucket_date, - ) - for obs_list in obs_tuple: - db.write_table_model_rows(obs_list) - idx += 1 - except Exception as exc: - msmt_str = msmt_dict.get("report_id", None) - if msmt: - msmt_str = msmt.measurement_uid - log.error( - f"failed at idx: {idx} ({msmt_str})", exc_info=True - ) - failure_count += 1 - - if fast_fail: - db.close() - raise exc - log.debug(f"done processing file s3://{bucket_name}/{s3path}") - except Exception as exc: - log.error( - f"failed to stream measurements from s3://{bucket_name}/{s3path}" - ) - log.error(exc) - # TODO(art): figure out if the rate of these metrics is too - # much. For each processed file a telemetry event is generated. - span.set_attribute("kb_per_sec", fe_size / 1024 / t.s) - span.set_attribute("fe_size", fe_size) - span.set_attribute("failure_count", failure_count) - span.add_event(f"s3_path: s3://{bucket_name}/{s3path}") - total_failure_count += failure_count + ccs = ccs_set(probe_cc) + total_measurement_count = 0 + awaitables = [] + for bucket_name, s3path, ext, fe_size in file_entry_batch: + failure_count = 0 + log.debug(f"processing file s3://{bucket_name}/{s3path}") + awaitables.append( + loop.run_in_executor( + None, + functools.partial( + make_observations_for_file_entry, + clickhouse=clickhouse, + data_dir=data_dir, + bucket_date=bucket_date, + bucket_name=bucket_name, + s3path=s3path, + ext=ext, + fast_fail=fast_fail, + write_batch_size=write_batch_size, + ccs=ccs, + ), + ) + ) + + results = await asyncio.gather(*awaitables) + for measurement_count, failure_count in results: + total_measurement_count += measurement_count + total_failure_count += failure_count - return idx + log.info( + f"finished batch ({len(file_entry_batch)} entries) in {tbatch.s:.3f} seconds" + ) + log.info(f"msmt/s: {total_measurement_count / tbatch.s}") + return total_measurement_count ObservationBatches = TypedDict( @@ -205,26 +207,18 @@ async def make_observations(params: MakeObservationsParams) -> MakeObservationsR bucket_date=params.bucket_date, ), ) - awaitables = [] + measurement_count = 0 for file_entry_batch in batches["batches"]: - awaitables.append( - loop.run_in_executor( - None, - functools.partial( - make_observations_for_file_entry_batch, - file_entry_batch=file_entry_batch, - bucket_date=params.bucket_date, - probe_cc=params.probe_cc, - data_dir=pathlib.Path(params.data_dir), - clickhouse=params.clickhouse, - write_batch_size=1_000_000, - fast_fail=False, - ), - ) + measurement_count += await make_observations_for_file_entry_batch( + file_entry_batch=file_entry_batch, + bucket_date=params.bucket_date, + probe_cc=params.probe_cc, + data_dir=pathlib.Path(params.data_dir), + clickhouse=params.clickhouse, + write_batch_size=1_000_000, + fast_fail=False, ) - measurement_count = sum(await asyncio.gather(*awaitables)) - current_span.set_attribute("total_runtime_ms", tbatch.ms) # current_span.set_attribute("total_failure_count", total_failure_count)