Skip to content

Commit

Permalink
Update process_commits example (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX authored Sep 1, 2024
1 parent 3915599 commit 6149b14
Showing 1 changed file with 68 additions and 38 deletions.
106 changes: 68 additions & 38 deletions examples/firehose/process_commits.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import functools
import multiprocessing
import signal
import time
from collections import defaultdict
from types import FrameType
from typing import Any

from atproto import CAR, AtUri, FirehoseSubscribeReposClient, firehose_models, models, parse_subscribe_repos_message

# Collections we care about: maps collection NSID -> record model class used
# to validate decoded records before they are grouped by type.
# NOTE(review): the pre-refactor version of this example also tracked
# AppBskyFeedRepost ('reposts' created/deleted) — confirm dropping repost
# handling here was intentional.
_INTERESTED_RECORDS = {
    models.ids.AppBskyFeedLike: models.AppBskyFeedLike,
    models.ids.AppBskyFeedPost: models.AppBskyFeedPost,
    models.ids.AppBskyGraphFollow: models.AppBskyGraphFollow,
}

def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict: # noqa: C901
operation_by_type = {
'posts': {'created': [], 'deleted': []},
'reposts': {'created': [], 'deleted': []},
'likes': {'created': [], 'deleted': []},
'follows': {'created': [], 'deleted': []},
}

def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> defaultdict:
operation_by_type = defaultdict(lambda: {'created': [], 'deleted': []})

car = CAR.from_bytes(commit.blocks)
for op in commit.ops:
uri = AtUri.from_str(f'at://{commit.repo}/{op.path}')

if op.action == 'update':
# not supported yet
continue

uri = AtUri.from_str(f'at://{commit.repo}/{op.path}')

if op.action == 'create':
if not op.cid:
continue
Expand All @@ -30,37 +36,19 @@ def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict
continue

record = models.get_or_create(record_raw_data, strict=False)
if uri.collection == models.ids.AppBskyFeedLike and models.is_record_type(
record, models.ids.AppBskyFeedLike
):
operation_by_type['likes']['created'].append({'record': record, **create_info})
elif uri.collection == models.ids.AppBskyFeedPost and models.is_record_type(
record, models.ids.AppBskyFeedPost
):
operation_by_type['posts']['created'].append({'record': record, **create_info})
elif uri.collection == models.ids.AppBskyFeedRepost and models.is_record_type(
record, models.ids.AppBskyFeedRepost
):
operation_by_type['reposts']['created'].append({'record': record, **create_info})
elif uri.collection == models.ids.AppBskyGraphFollow and models.is_record_type(
record, models.ids.AppBskyGraphFollow
):
operation_by_type['follows']['created'].append({'record': record, **create_info})
record_type = _INTERESTED_RECORDS.get(uri.collection)
if record_type and models.is_record_type(record, record_type):
operation_by_type[uri.collection]['created'].append({'record': record, **create_info})

if op.action == 'delete':
if uri.collection == models.ids.AppBskyFeedLike:
operation_by_type['likes']['deleted'].append({'uri': str(uri)})
elif uri.collection == models.ids.AppBskyFeedPost:
operation_by_type['posts']['deleted'].append({'uri': str(uri)})
elif uri.collection == models.ids.AppBskyFeedRepost:
operation_by_type['reposts']['deleted'].append({'uri': str(uri)})
elif uri.collection == models.ids.AppBskyGraphFollow:
operation_by_type['follows']['deleted'].append({'uri': str(uri)})
operation_by_type[uri.collection]['deleted'].append({'uri': str(uri)})

return operation_by_type


def worker_main(cursor_value: multiprocessing.Value, pool_queue: multiprocessing.Queue) -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN) # we handle it in the main process

while True:
message = pool_queue.get()

Expand All @@ -75,17 +63,58 @@ def worker_main(cursor_value: multiprocessing.Value, pool_queue: multiprocessing
continue

ops = _get_ops_by_type(commit)
for post in ops['posts']['created']:
post_msg = post['record'].text
post_langs = post['record'].langs
print(f'New post in the network! Langs: {post_langs}. Text: {post_msg}')
for created_post in ops[models.ids.AppBskyFeedPost]['created']:
author = created_post['author']
record = created_post['record']

inlined_text = record.text.replace('\n', ' ')
print(f'NEW POST [CREATED_AT={record.created_at}][AUTHOR={author}]: {inlined_text}')


def get_firehose_params(cursor_value: multiprocessing.Value) -> models.ComAtprotoSyncSubscribeRepos.Params:
    """Build subscribe-repos params seeded from the shared cursor value."""
    current_cursor = cursor_value.value
    return models.ComAtprotoSyncSubscribeRepos.Params(cursor=current_cursor)


def measure_events_per_second(func: callable) -> callable:
    """Decorator that reports how often the wrapped callable is invoked.

    Counts calls on the wrapper itself (``wrapper.calls``) and, whenever at
    least one second has elapsed since the last report, prints the number of
    events seen in that window and resets the counter.

    Fixes vs. the original:
    - ``functools.wraps`` preserves the wrapped function's metadata
      (``__name__``, ``__doc__``), which the bare wrapper lost.
    - Keyword arguments are now forwarded; the original forwarded ``*args``
      only and would raise ``TypeError`` for keyword calls.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        wrapper.calls += 1
        cur_time = time.time()

        # Report (and reset the window) roughly once per second.
        if cur_time - wrapper.start_time >= 1:
            print(f'NETWORK LOAD: {wrapper.calls} events/second')
            wrapper.start_time = cur_time
            wrapper.calls = 0

        return func(*args, **kwargs)

    # Counter state lives on the wrapper function object itself.
    wrapper.calls = 0
    wrapper.start_time = time.time()

    return wrapper


def signal_handler(_: int, __: FrameType) -> None:
    """SIGINT handler for the main process: drain the queue, then shut down.

    Relies on the module-level ``client``, ``queue`` and ``pool`` objects
    created in the ``__main__`` block.
    """
    print('Keyboard interrupt received. Waiting for the queue to empty before terminating processes...')

    # Stop receiving new messages from the firehose.
    client.stop()

    # Drain the messages queue so workers can finish in-flight work.
    # NOTE(review): multiprocessing.Queue.empty() is documented as unreliable;
    # this busy-wait is a best-effort drain, acceptable for an example script.
    while not queue.empty():
        print('Waiting for the queue to empty...')
        time.sleep(0.2)

    print('Queue is empty. Gracefully terminating processes...')

    pool.terminate()
    pool.join()

    # Fix: raise SystemExit instead of calling the interactive-only exit()
    # builtin, which is injected by the site module and not guaranteed to
    # exist (e.g. under `python -S`).
    raise SystemExit(0)


if __name__ == '__main__':
signal.signal(signal.SIGINT, signal_handler)

start_cursor = None

params = None
Expand All @@ -97,11 +126,12 @@ def get_firehose_params(cursor_value: multiprocessing.Value) -> models.ComAtprot
client = FirehoseSubscribeReposClient(params)

workers_count = multiprocessing.cpu_count() * 2 - 1
max_queue_size = 500
max_queue_size = 10000

queue = multiprocessing.Queue(maxsize=max_queue_size)
pool = multiprocessing.Pool(workers_count, worker_main, (cursor, queue))

@measure_events_per_second
def on_message_handler(message: firehose_models.MessageFrame) -> None:
if cursor.value:
# we update the cursor state here because of multiprocessing
Expand Down

0 comments on commit 6149b14

Please sign in to comment.