diff --git a/examples/firehose/process_commits.py b/examples/firehose/process_commits.py index 579812e8..9fcaad0b 100644 --- a/examples/firehose/process_commits.py +++ b/examples/firehose/process_commits.py @@ -1,24 +1,30 @@ import multiprocessing +import signal +import time +from collections import defaultdict +from types import FrameType +from typing import Any from atproto import CAR, AtUri, FirehoseSubscribeReposClient, firehose_models, models, parse_subscribe_repos_message +_INTERESTED_RECORDS = { + models.ids.AppBskyFeedLike: models.AppBskyFeedLike, + models.ids.AppBskyFeedPost: models.AppBskyFeedPost, + models.ids.AppBskyGraphFollow: models.AppBskyGraphFollow, +} -def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict: # noqa: C901 - operation_by_type = { - 'posts': {'created': [], 'deleted': []}, - 'reposts': {'created': [], 'deleted': []}, - 'likes': {'created': [], 'deleted': []}, - 'follows': {'created': [], 'deleted': []}, - } + +def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> defaultdict: + operation_by_type = defaultdict(lambda: {'created': [], 'deleted': []}) car = CAR.from_bytes(commit.blocks) for op in commit.ops: - uri = AtUri.from_str(f'at://{commit.repo}/{op.path}') - if op.action == 'update': # not supported yet continue + uri = AtUri.from_str(f'at://{commit.repo}/{op.path}') + if op.action == 'create': if not op.cid: continue @@ -30,37 +36,19 @@ def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict continue record = models.get_or_create(record_raw_data, strict=False) - if uri.collection == models.ids.AppBskyFeedLike and models.is_record_type( - record, models.ids.AppBskyFeedLike - ): - operation_by_type['likes']['created'].append({'record': record, **create_info}) - elif uri.collection == models.ids.AppBskyFeedPost and models.is_record_type( - record, models.ids.AppBskyFeedPost - ): - operation_by_type['posts']['created'].append({'record': record, **create_info}) - elif uri.collection == models.ids.AppBskyFeedRepost and models.is_record_type( - record, models.ids.AppBskyFeedRepost - ): - operation_by_type['reposts']['created'].append({'record': record, **create_info}) - elif uri.collection == models.ids.AppBskyGraphFollow and models.is_record_type( - record, models.ids.AppBskyGraphFollow - ): - operation_by_type['follows']['created'].append({'record': record, **create_info}) + record_type = _INTERESTED_RECORDS.get(uri.collection) + if record_type and models.is_record_type(record, record_type): + operation_by_type[uri.collection]['created'].append({'record': record, **create_info}) if op.action == 'delete': - if uri.collection == models.ids.AppBskyFeedLike: - operation_by_type['likes']['deleted'].append({'uri': str(uri)}) - elif uri.collection == models.ids.AppBskyFeedPost: - operation_by_type['posts']['deleted'].append({'uri': str(uri)}) - elif uri.collection == models.ids.AppBskyFeedRepost: - operation_by_type['reposts']['deleted'].append({'uri': str(uri)}) - elif uri.collection == models.ids.AppBskyGraphFollow: - operation_by_type['follows']['deleted'].append({'uri': str(uri)}) + operation_by_type[uri.collection]['deleted'].append({'uri': str(uri)}) return operation_by_type def worker_main(cursor_value: multiprocessing.Value, pool_queue: multiprocessing.Queue) -> None: + signal.signal(signal.SIGINT, signal.SIG_IGN) # we handle it in the main process + while True: message = pool_queue.get() @@ -75,17 +63,58 @@ def worker_main(cursor_value: multiprocessing.Value, pool_queue: multiprocessing continue ops = _get_ops_by_type(commit) - for post in ops['posts']['created']: - post_msg = post['record'].text - post_langs = post['record'].langs - print(f'New post in the network! Langs: {post_langs}. Text: {post_msg}') + for created_post in ops[models.ids.AppBskyFeedPost]['created']: + author = created_post['author'] + record = created_post['record'] + + inlined_text = record.text.replace('\n', ' ') + print(f'NEW POST [CREATED_AT={record.created_at}][AUTHOR={author}]: {inlined_text}') def get_firehose_params(cursor_value: multiprocessing.Value) -> models.ComAtprotoSyncSubscribeRepos.Params: return models.ComAtprotoSyncSubscribeRepos.Params(cursor=cursor_value.value) +def measure_events_per_second(func: callable) -> callable: + def wrapper(*args) -> Any: + wrapper.calls += 1 + cur_time = time.time() + + if cur_time - wrapper.start_time >= 1: + print(f'NETWORK LOAD: {wrapper.calls} events/second') + wrapper.start_time = cur_time + wrapper.calls = 0 + + return func(*args) + + wrapper.calls = 0 + wrapper.start_time = time.time() + + return wrapper + + +def signal_handler(_: int, __: FrameType) -> None: + print('Keyboard interrupt received. Waiting for the queue to empty before terminating processes...') + + # Stop receiving new messages + client.stop() + + # Drain the messages queue + while not queue.empty(): + print('Waiting for the queue to empty...') + time.sleep(0.2) + + print('Queue is empty. Gracefully terminating processes...') + + pool.terminate() + pool.join() + + exit(0) + + if __name__ == '__main__': + signal.signal(signal.SIGINT, signal_handler) + start_cursor = None params = None @@ -97,11 +126,12 @@ def get_firehose_params(cursor_value: multiprocessing.Value) -> models.ComAtprot client = FirehoseSubscribeReposClient(params) workers_count = multiprocessing.cpu_count() * 2 - 1 - max_queue_size = 500 + max_queue_size = 10000 queue = multiprocessing.Queue(maxsize=max_queue_size) pool = multiprocessing.Pool(workers_count, worker_main, (cursor, queue)) + @measure_events_per_second def on_message_handler(message: firehose_models.MessageFrame) -> None: if cursor.value: # we are using updating the cursor state here because of multiprocessing