Tweaks to concurrency
hellais committed Sep 4, 2024
1 parent 370c309 commit 31decb7
Showing 1 changed file with 96 additions and 102 deletions.
198 changes: 96 additions & 102 deletions oonipipeline/src/oonipipeline/temporal/activities/observations.py
@@ -43,27 +43,6 @@ class MakeObservationsParams:
bucket_date: str


def write_observations_to_db(
msmt: SupportedDataformats,
netinfodb: NetinfoDB,
db: ClickhouseConnection,
bucket_date: str,
):
for observations in measurement_to_observations(
msmt=msmt, netinfodb=netinfodb, bucket_date=bucket_date
):
if len(observations) == 0:
continue

column_names = [f.name for f in dataclasses.fields(observations[0])]
table_name, rows = make_db_rows(
bucket_date=bucket_date,
dc_list=observations,
column_names=column_names,
)
db.write_rows(table_name=table_name, rows=rows, column_names=column_names)


FileEntryBatchType = Tuple[str, str, str, int]


@@ -80,7 +59,58 @@ class MakeObservationsFileEntryBatch:
fast_fail: bool


def make_observations_for_file_entry_batch(
def make_observations_for_file_entry(
clickhouse: str,
data_dir: pathlib.Path,
bucket_date: str,
bucket_name: str,
s3path: str,
ext: str,
ccs: set,
fast_fail: bool,
write_batch_size: int,
):
with ClickhouseConnection(clickhouse, write_batch_size=write_batch_size) as db:
netinfodb = NetinfoDB(datadir=data_dir, download=False)
# initialize the counters returned at the end of this function
idx = 0
failure_count = 0
for msmt_dict in stream_measurements(
bucket_name=bucket_name, s3path=s3path, ext=ext
):
# Legacy cans don't allow us to pre-filter on the probe_cc, so
# we need to check for probe_cc consistency in here.
if ccs and msmt_dict["probe_cc"] not in ccs:
continue
msmt = None
try:
msmt = load_measurement(msmt_dict)
if not msmt.test_keys:
log.error(
f"measurement with empty test_keys: ({msmt.measurement_uid})",
exc_info=True,
)
continue
obs_tuple = measurement_to_observations(
msmt=msmt,
netinfodb=netinfodb,
bucket_date=bucket_date,
)
for obs_list in obs_tuple:
db.write_table_model_rows(obs_list)
idx += 1
except Exception as exc:
msmt_str = msmt_dict.get("report_id", None)
if msmt:
msmt_str = msmt.measurement_uid
log.error(f"failed at idx: {idx} ({msmt_str})", exc_info=True)
failure_count += 1

if fast_fail:
db.close()
raise exc
log.debug(f"done processing file s3://{bucket_name}/{s3path}")
return idx, failure_count


async def make_observations_for_file_entry_batch(
file_entry_batch: List[FileEntryBatchType],
bucket_date: str,
probe_cc: List[str],
@@ -89,72 +119,44 @@ def make_observations_for_file_entry_batch(
write_batch_size: int,
fast_fail: bool = False,
) -> int:
netinfodb = NetinfoDB(datadir=data_dir, download=False)

tracer = trace.get_tracer(__name__)
loop = asyncio.get_running_loop()

tbatch = PerfTimer()
total_failure_count = 0
with ClickhouseConnection(clickhouse, write_batch_size=write_batch_size) as db:
ccs = ccs_set(probe_cc)
idx = 0
for bucket_name, s3path, ext, fe_size in file_entry_batch:
failure_count = 0
# Nest the traced span within the current span
with tracer.start_span("MakeObservations:stream_file_entry") as span:
log.debug(f"processing file s3://{bucket_name}/{s3path}")
t = PerfTimer()
try:
for msmt_dict in stream_measurements(
bucket_name=bucket_name, s3path=s3path, ext=ext
):
# Legacy cans don't allow us to pre-filter on the probe_cc, so
# we need to check for probe_cc consistency in here.
if ccs and msmt_dict["probe_cc"] not in ccs:
continue
msmt = None
try:
msmt = load_measurement(msmt_dict)
if not msmt.test_keys:
log.error(
f"measurement with empty test_keys: ({msmt.measurement_uid})",
exc_info=True,
)
continue
obs_tuple = measurement_to_observations(
msmt=msmt,
netinfodb=netinfodb,
bucket_date=bucket_date,
)
for obs_list in obs_tuple:
db.write_table_model_rows(obs_list)
idx += 1
except Exception as exc:
msmt_str = msmt_dict.get("report_id", None)
if msmt:
msmt_str = msmt.measurement_uid
log.error(
f"failed at idx: {idx} ({msmt_str})", exc_info=True
)
failure_count += 1

if fast_fail:
db.close()
raise exc
log.debug(f"done processing file s3://{bucket_name}/{s3path}")
except Exception as exc:
log.error(
f"failed to stream measurements from s3://{bucket_name}/{s3path}"
)
log.error(exc)
# TODO(art): figure out if the rate of these metrics is too
# much. For each processed file a telemetry event is generated.
span.set_attribute("kb_per_sec", fe_size / 1024 / t.s)
span.set_attribute("fe_size", fe_size)
span.set_attribute("failure_count", failure_count)
span.add_event(f"s3_path: s3://{bucket_name}/{s3path}")
total_failure_count += failure_count
ccs = ccs_set(probe_cc)
total_measurement_count = 0
awaitables = []
for bucket_name, s3path, ext, fe_size in file_entry_batch:
failure_count = 0
log.debug(f"processing file s3://{bucket_name}/{s3path}")
awaitables.append(
loop.run_in_executor(
None,
functools.partial(
make_observations_for_file_entry,
clickhouse=clickhouse,
data_dir=data_dir,
bucket_date=bucket_date,
bucket_name=bucket_name,
s3path=s3path,
ext=ext,
fast_fail=fast_fail,
write_batch_size=write_batch_size,
ccs=ccs,
),
)
)

results = await asyncio.gather(*awaitables)
for measurement_count, failure_count in results:
total_measurement_count += measurement_count
total_failure_count += failure_count

return idx
log.info(
f"finished batch ({len(file_entry_batch)} entries) in {tbatch.s:.3f} seconds"
)
log.info(f"msmt/s: {total_measurement_count / tbatch.s}")
return total_measurement_count
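
For readers who want the concurrency pattern above in isolation, here is a minimal, self-contained sketch of the same approach (not part of the diff): blocking per-file work is wrapped in functools.partial, scheduled on the default thread-pool executor via loop.run_in_executor(None, ...), and the per-file results are collected with asyncio.gather. The names process_file_entry and process_batch, and the example bucket/path, are hypothetical stand-ins for make_observations_for_file_entry and make_observations_for_file_entry_batch.

    import asyncio
    import functools
    from typing import List, Tuple


    def process_file_entry(bucket_name: str, s3path: str) -> Tuple[int, int]:
        # Stand-in for make_observations_for_file_entry: blocking work that
        # streams measurements from one file entry and writes observations.
        measurement_count, failure_count = 0, 0
        # ... stream, parse, write ...
        return measurement_count, failure_count


    async def process_batch(file_entries: List[Tuple[str, str]]) -> int:
        loop = asyncio.get_running_loop()
        # Each blocking call is wrapped in functools.partial and scheduled on
        # the default thread-pool executor, so file entries run concurrently.
        awaitables = [
            loop.run_in_executor(
                None,
                functools.partial(process_file_entry, bucket_name=bucket, s3path=path),
            )
            for bucket, path in file_entries
        ]
        results = await asyncio.gather(*awaitables)
        return sum(count for count, _failures in results)


    if __name__ == "__main__":
        total = asyncio.run(process_batch([("example-bucket", "path/to/file.jsonl.gz")]))
        print(total)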


ObservationBatches = TypedDict(
@@ -205,26 +207,18 @@ async def make_observations(params: MakeObservationsParams) -> MakeObservationsR
bucket_date=params.bucket_date,
),
)
awaitables = []
measurement_count = 0
for file_entry_batch in batches["batches"]:
awaitables.append(
loop.run_in_executor(
None,
functools.partial(
make_observations_for_file_entry_batch,
file_entry_batch=file_entry_batch,
bucket_date=params.bucket_date,
probe_cc=params.probe_cc,
data_dir=pathlib.Path(params.data_dir),
clickhouse=params.clickhouse,
write_batch_size=1_000_000,
fast_fail=False,
),
)
measurement_count += await make_observations_for_file_entry_batch(
file_entry_batch=file_entry_batch,
bucket_date=params.bucket_date,
probe_cc=params.probe_cc,
data_dir=pathlib.Path(params.data_dir),
clickhouse=params.clickhouse,
write_batch_size=1_000_000,
fast_fail=False,
)

measurement_count = sum(await asyncio.gather(*awaitables))

current_span.set_attribute("total_runtime_ms", tbatch.ms)
# current_span.set_attribute("total_failure_count", total_failure_count)
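
Note that run_in_executor(None, ...) schedules work on the event loop's default thread pool, whose size is chosen by the Python runtime. A minimal sketch of one way to bound that concurrency explicitly, should a fixed cap be preferable; the blocking_work helper and the max_workers=4 value are illustrative assumptions, not part of the commit.

    import asyncio
    from concurrent.futures import ThreadPoolExecutor


    def blocking_work(i: int) -> None:
        # Placeholder for the blocking per-file-entry processing.
        pass


    async def main() -> None:
        loop = asyncio.get_running_loop()
        # run_in_executor(None, ...) uses the loop's default thread pool;
        # passing a dedicated executor instead caps how many entries run at once.
        with ThreadPoolExecutor(max_workers=4) as pool:
            futures = [loop.run_in_executor(pool, blocking_work, i) for i in range(10)]
            await asyncio.gather(*futures)


    if __name__ == "__main__":
        asyncio.run(main())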

