Move observation generation parallelism back into activity
hellais committed Sep 4, 2024
1 parent 6d16ced commit 3481d4a
Showing 4 changed files with 110 additions and 113 deletions.
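In short: before this commit the ObservationsWorkflow scheduled one Temporal activity per file-entry batch; now a single make_observations activity lists the batches itself and fans them out onto the worker's thread pool with loop.run_in_executor. A minimal sketch of that fan-out pattern, with a hypothetical process_batch standing in for make_observations_for_file_entry_batch:

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor
from typing import List

def process_batch(batch: List[str], write_batch_size: int = 1_000_000) -> int:
    # Hypothetical stand-in for make_observations_for_file_entry_batch:
    # pretend every entry in the batch yields one measurement.
    return len(batch)

async def process_all(batches: List[List[str]]) -> int:
    loop = asyncio.get_running_loop()
    # Passing None as the executor picks the loop's default executor, which
    # the worker configures via loop.set_default_executor (see workers.py below).
    futures = [
        loop.run_in_executor(None, functools.partial(process_batch, batch))
        for batch in batches
    ]
    return sum(await asyncio.gather(*futures))

if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    loop.set_default_executor(ThreadPoolExecutor(max_workers=4))
    print(loop.run_until_complete(process_all([["a", "b"], ["c"], ["d", "e", "f"]])))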
111 changes: 78 additions & 33 deletions oonipipeline/src/oonipipeline/temporal/activities/observations.py
@@ -1,5 +1,7 @@
import asyncio
from dataclasses import dataclass
import dataclasses
import functools
from typing import Any, Dict, List, Sequence, Tuple, TypedDict
from oonidata.dataclient import (
ccs_set,
@@ -78,33 +80,22 @@ class MakeObservationsFileEntryBatch:
fast_fail: bool


@activity.defn
def make_observations_for_file_entry_batch(
params: MakeObservationsFileEntryBatch,
file_entry_batch: List[FileEntryBatchType],
bucket_date: str,
probe_cc: List[str],
data_dir: pathlib.Path,
clickhouse: str,
write_batch_size: int,
fast_fail: bool = False,
) -> int:
day = datetime.strptime(params.bucket_date, "%Y-%m-%d").date()
data_dir = pathlib.Path(params.data_dir)

netinfodb = NetinfoDB(datadir=data_dir, download=False)
tbatch = PerfTimer()

tracer = trace.get_tracer(__name__)

file_entry_batches, _ = list_file_entries_batches(
probe_cc=params.probe_cc,
test_name=params.test_name,
start_day=day,
end_day=day + timedelta(days=1),
)
file_entry_batch = file_entry_batches[params.batch_idx]

activity.heartbeat(f"running idx {params.batch_idx}")
total_failure_count = 0
current_span = trace.get_current_span()
with ClickhouseConnection(
params.clickhouse, write_batch_size=params.write_batch_size
) as db:
ccs = ccs_set(params.probe_cc)
with ClickhouseConnection(clickhouse, write_batch_size=write_batch_size) as db:
ccs = ccs_set(probe_cc)
idx = 0
for bucket_name, s3path, ext, fe_size in file_entry_batch:
failure_count = 0
@@ -133,13 +124,11 @@ def make_observations_for_file_entry_batch(
obs_tuple = measurement_to_observations(
msmt=msmt,
netinfodb=netinfodb,
bucket_date=params.bucket_date,
bucket_date=bucket_date,
)
for obs_list in obs_tuple:
db.write_table_model_rows(obs_list)
idx += 1
if idx % 10_000 == 0:
activity.heartbeat(f"processing idx: {idx}")
except Exception as exc:
msmt_str = msmt_dict.get("report_id", None)
if msmt:
@@ -149,7 +138,7 @@
)
failure_count += 1

if params.fast_fail:
if fast_fail:
db.close()
raise exc
log.debug(f"done processing file s3://{bucket_name}/{s3path}")
@@ -166,30 +155,86 @@ def make_observations_for_file_entry_batch(
span.add_event(f"s3_path: s3://{bucket_name}/{s3path}")
total_failure_count += failure_count

current_span.set_attribute("total_runtime_ms", tbatch.ms)
current_span.set_attribute("total_failure_count", total_failure_count)
return idx


ObservationBatches = TypedDict(
"ObservationBatches",
{"batch_count": int, "total_size": int},
{"batches": List[List[FileEntryBatchType]], "total_size": int},
)


@activity.defn
def make_observation_batches(params: MakeObservationsParams) -> ObservationBatches:
day = datetime.strptime(params.bucket_date, "%Y-%m-%d").date()
def make_observation_batches(
bucket_date: str, probe_cc: List[str], test_name: List[str]
) -> ObservationBatches:
day = datetime.strptime(bucket_date, "%Y-%m-%d").date()

t = PerfTimer()
file_entry_batches, total_size = list_file_entries_batches(
probe_cc=params.probe_cc,
test_name=params.test_name,
probe_cc=probe_cc,
test_name=test_name,
start_day=day,
end_day=day + timedelta(days=1),
)
log.info(f"listing {len(file_entry_batches)} batches took {t.pretty}")
return {"batch_count": len(file_entry_batches), "total_size": total_size}
return {"batches": file_entry_batches, "total_size": total_size}


MakeObservationsResult = TypedDict(
"MakeObservationsResult",
{
"measurement_count": int,
"measurement_per_sec": float,
"mb_per_sec": float,
"total_size": int,
},
)


@activity.defn
async def make_observations(params: MakeObservationsParams) -> MakeObservationsResult:
loop = asyncio.get_running_loop()

tbatch = PerfTimer()
current_span = trace.get_current_span()
batches = await loop.run_in_executor(
None,
functools.partial(
make_observation_batches,
probe_cc=params.probe_cc,
test_name=params.test_name,
bucket_date=params.bucket_date,
),
)
awaitables = []
for file_entry_batch in batches["batches"]:
awaitables.append(
loop.run_in_executor(
None,
functools.partial(
make_observations_for_file_entry_batch,
file_entry_batch=file_entry_batch,
bucket_date=params.bucket_date,
probe_cc=params.probe_cc,
data_dir=pathlib.Path(params.data_dir),
clickhouse=params.clickhouse,
write_batch_size=1_000_000,
fast_fail=False,
),
)
)

measurement_count = sum(await asyncio.gather(*awaitables))

current_span.set_attribute("total_runtime_ms", tbatch.ms)
# current_span.set_attribute("total_failure_count", total_failure_count)

return {
"measurement_count": measurement_count,
"mb_per_sec": float(batches["total_size"]) / 1024 / 1024 / tbatch.s,
"measurement_per_sec": measurement_count / tbatch.s,
"total_size": batches["total_size"],
}


@dataclass
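The new make_observations activity also takes over the throughput accounting that the workflow used to do, returning a MakeObservationsResult. A quick sanity check of the arithmetic in its return value, with made-up numbers:

total_size = 512 * 1024 * 1024   # bytes listed by make_observation_batches (made up)
measurement_count = 450_000      # sum of the per-batch return values (made up)
elapsed_s = 300.0                # what PerfTimer().s would report (made up)

mb_per_sec = float(total_size) / 1024 / 1024 / elapsed_s   # ~1.71 MB/s
measurement_per_sec = measurement_count / elapsed_s        # 1500.0 msmt/s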
7 changes: 3 additions & 4 deletions oonipipeline/src/oonipipeline/temporal/workers.py
@@ -14,8 +14,7 @@
from oonipipeline.temporal.activities.observations import (
delete_previous_range,
get_previous_range,
make_observation_batches,
make_observations_for_file_entry_batch,
make_observations,
)
from oonipipeline.temporal.client_operations import (
TemporalConfig,
@@ -42,8 +41,7 @@
ACTIVTIES = [
delete_previous_range,
get_previous_range,
make_observation_batches,
make_observations_for_file_entry_batch,
make_observations,
make_ground_truths_in_day,
make_analysis_in_a_day,
optimize_all_tables,
@@ -75,6 +73,7 @@ def start_workers(temporal_config: TemporalConfig):
log.info(f"starting workers with max_workers={max_workers}")
thread_pool = ThreadPoolExecutor(max_workers=max_workers + 2)
loop = asyncio.new_event_loop()
loop.set_default_executor(thread_pool)
# TODO(art): Investigate if we want to upgrade to python 3.12 and use this
# instead
# loop.set_task_factory(asyncio.eager_task_factory)
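The only worker-side change is wiring the existing thread pool into the event loop, so that the run_in_executor(None, ...) calls inside make_observations share it instead of falling back to a separate default executor. Roughly (the surrounding Temporal worker setup is elided; max_workers is a placeholder for the configured value):

import asyncio
from concurrent.futures import ThreadPoolExecutor

max_workers = 8  # placeholder; the real value comes from the worker configuration
thread_pool = ThreadPoolExecutor(max_workers=max_workers + 2)
loop = asyncio.new_event_loop()
# Every run_in_executor(None, ...) call made on this loop now lands on thread_pool.
loop.set_default_executor(thread_pool)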
45 changes: 9 additions & 36 deletions oonipipeline/src/oonipipeline/temporal/workflows/observations.py
@@ -18,12 +18,10 @@
from oonipipeline.temporal.activities.observations import (
DeletePreviousRangeParams,
GetPreviousRangeParams,
MakeObservationsFileEntryBatch,
MakeObservationsParams,
delete_previous_range,
get_previous_range,
make_observation_batches,
make_observations_for_file_entry_batch,
make_observations,
)
from oonipipeline.temporal.workflows.common import (
TASK_QUEUE_NAME,
@@ -87,40 +85,15 @@ async def run(self, params: ObservationsWorkflowParams) -> dict:
retry_policy=RetryPolicy(maximum_attempts=10),
)

batch_res = await workflow.execute_activity(
make_observation_batches,
obs_res = await workflow.execute_activity(
make_observations,
params_make_observations,
start_to_close_timeout=timedelta(minutes=30),
start_to_close_timeout=timedelta(hours=48),
retry_policy=RetryPolicy(maximum_attempts=3),
)

coroutine_list = []
for batch_idx in range(batch_res["batch_count"]):
batch_params = MakeObservationsFileEntryBatch(
batch_idx=batch_idx,
clickhouse=params.clickhouse,
write_batch_size=1_000_000,
data_dir=params.data_dir,
bucket_date=params.bucket_date,
probe_cc=params.probe_cc,
test_name=params.test_name,
fast_fail=params.fast_fail,
)
coroutine_list.append(
workflow.execute_activity(
make_observations_for_file_entry_batch,
batch_params,
task_queue=TASK_QUEUE_NAME,
start_to_close_timeout=timedelta(hours=10),
retry_policy=RetryPolicy(maximum_attempts=10),
)
)
total_msmt_count = sum(await asyncio.gather(*coroutine_list))

mb_per_sec = round(batch_res["total_size"] / total_t.s / 10**6, 1)
msmt_per_sec = round(total_msmt_count / total_t.s)
workflow.logger.info(
f"finished processing all batches in {total_t.pretty} speed: {mb_per_sec}MB/s ({msmt_per_sec}msmt/s)"
f"finished processing all batches in {total_t.pretty} speed: {obs_res['mb_per_sec']}MB/s ({obs_res['measurement_per_sec']}msmt/s)"
)

await workflow.execute_activity(
Expand All @@ -141,9 +114,9 @@ async def run(self, params: ObservationsWorkflowParams) -> dict:
)

return {
"measurement_count": total_msmt_count,
"size": batch_res["total_size"],
"mb_per_sec": mb_per_sec,
"measurement_count": obs_res["measurement_count"],
"size": obs_res["total_size"],
"mb_per_sec": obs_res["mb_per_sec"],
"bucket_date": params.bucket_date,
"msmt_per_sec": msmt_per_sec,
"measurement_per_sec": obs_res["measurement_per_sec"],
}
60 changes: 20 additions & 40 deletions oonipipeline/tests/test_workflows.py
@@ -29,6 +29,7 @@
GetPreviousRangeParams,
MakeObservationsFileEntryBatch,
MakeObservationsParams,
MakeObservationsResult,
ObservationBatches,
make_observations_for_file_entry_batch,
)
@@ -178,15 +179,13 @@ def test_make_file_entry_batch(datadir, db):
)
]
obs_msmt_count = make_observations_for_file_entry_batch(
params=MakeObservationsFileEntryBatch(
file_entry_batch=file_entry_batch,
clickhouse=db.clickhouse_url,
write_batch_size=1,
data_dir=datadir,
bucket_date="2023-10-31",
probe_cc=["IR"],
fast_fail=False,
)
file_entry_batch=file_entry_batch,
clickhouse=db.clickhouse_url,
write_batch_size=1,
data_dir=datadir,
bucket_date="2023-10-31",
probe_cc=["IR"],
fast_fail=False,
)
assert obs_msmt_count == 453
# Flush buffer table
@@ -377,34 +376,16 @@ async def get_obs_count_by_cc_mocked(params: ObsCountParams):
}


@activity.defn(name="make_observations_for_file_entry_batch")
async def make_observations_for_file_entry_batch_mocked(
params: MakeObservationsFileEntryBatch,
) -> int:
return 100


@activity.defn(name="make_observation_batches")
async def make_observation_batches_mocked(
@activity.defn(name="make_observations")
async def make_observations_mocked(
params: MakeObservationsParams,
) -> ObservationBatches:
return ObservationBatches(
batches=[
[
# FileEntryBatchType = Tuple[str, str, str, int]
# ((fe.bucket_name, fe.s3path, fe.ext, fe.size))
("ooni-data-eu-fra", "/dummy", ".tar.gz", 200),
("ooni-data-eu-fra", "/dummy-2", ".tar.gz", 1000),
],
[
# FileEntryBatchType = Tuple[str, str, str, int]
# ((fe.bucket_name, fe.s3path, fe.ext, fe.size))
("ooni-data-eu-fra", "/dummy", ".tar.gz", 200),
("ooni-data-eu-fra", "/dummy-2", ".tar.gz", 1000),
],
],
total_size=1200 * 2,
)
) -> MakeObservationsResult:
return {
"measurement_count": 100,
"measurement_per_sec": 3.0,
"mb_per_sec": 1.0,
"total_size": 2000,
}


@activity.defn(name="make_analysis_in_a_day")
@@ -436,8 +417,7 @@ async def test_temporal_workflows():
make_ground_truths_in_day_mocked,
get_obs_count_by_cc_mocked,
make_analysis_in_a_day_mocked,
make_observation_batches_mocked,
make_observations_for_file_entry_batch_mocked,
make_observations_mocked,
get_previous_range_mocked,
delete_previous_range_mocked,
],
@@ -448,8 +428,8 @@
id="obs-wf",
task_queue=TASK_QUEUE_NAME,
)
assert res["size"] == 1200 * 2
assert res["measurement_count"] == 100 * 2
assert res["size"] == 2000
assert res["measurement_count"] == 100
assert res["bucket_date"] == "2024-01-02"

res = await env.client.execute_workflow(
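Because the workflow now calls a single activity, the test suite only needs one mock, registered under that activity's name. For reference, a sketch of how such a mock plugs into a time-skipping test environment (assuming the standard temporalio testing setup; the workflow class and parameter handling here are illustrative, not copied from the repository):

from temporalio import activity
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker

@activity.defn(name="make_observations")
async def make_observations_mocked(params) -> dict:
    # Shape mirrors MakeObservationsResult above; the values are arbitrary.
    return {
        "measurement_count": 100,
        "measurement_per_sec": 3.0,
        "mb_per_sec": 1.0,
        "total_size": 2000,
    }

async def run_with_mocks(workflow_cls, params, task_queue: str) -> dict:
    async with await WorkflowEnvironment.start_time_skipping() as env:
        async with Worker(
            env.client,
            task_queue=task_queue,
            workflows=[workflow_cls],
            activities=[make_observations_mocked],
        ):
            return await env.client.execute_workflow(
                workflow_cls.run,
                params,
                id="obs-wf",
                task_queue=task_queue,
            )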
