Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update process_commits example #374

Merged
merged 2 commits into from
Sep 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 68 additions & 38 deletions examples/firehose/process_commits.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import multiprocessing
import signal
import time
from collections import defaultdict
from types import FrameType
from typing import Any

from atproto import CAR, AtUri, FirehoseSubscribeReposClient, firehose_models, models, parse_subscribe_repos_message

_INTERESTED_RECORDS = {
models.ids.AppBskyFeedLike: models.AppBskyFeedLike,
models.ids.AppBskyFeedPost: models.AppBskyFeedPost,
models.ids.AppBskyGraphFollow: models.AppBskyGraphFollow,
}

def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict: # noqa: C901
operation_by_type = {
'posts': {'created': [], 'deleted': []},
'reposts': {'created': [], 'deleted': []},
'likes': {'created': [], 'deleted': []},
'follows': {'created': [], 'deleted': []},
}

def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> defaultdict:
operation_by_type = defaultdict(lambda: {'created': [], 'deleted': []})

car = CAR.from_bytes(commit.blocks)
for op in commit.ops:
uri = AtUri.from_str(f'at://{commit.repo}/{op.path}')

if op.action == 'update':
# not supported yet
continue

uri = AtUri.from_str(f'at://{commit.repo}/{op.path}')

if op.action == 'create':
if not op.cid:
continue
Expand All @@ -30,37 +36,19 @@ def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> dict
continue

record = models.get_or_create(record_raw_data, strict=False)
if uri.collection == models.ids.AppBskyFeedLike and models.is_record_type(
record, models.ids.AppBskyFeedLike
):
operation_by_type['likes']['created'].append({'record': record, **create_info})
elif uri.collection == models.ids.AppBskyFeedPost and models.is_record_type(
record, models.ids.AppBskyFeedPost
):
operation_by_type['posts']['created'].append({'record': record, **create_info})
elif uri.collection == models.ids.AppBskyFeedRepost and models.is_record_type(
record, models.ids.AppBskyFeedRepost
):
operation_by_type['reposts']['created'].append({'record': record, **create_info})
elif uri.collection == models.ids.AppBskyGraphFollow and models.is_record_type(
record, models.ids.AppBskyGraphFollow
):
operation_by_type['follows']['created'].append({'record': record, **create_info})
record_type = _INTERESTED_RECORDS.get(uri.collection)
if record_type and models.is_record_type(record, record_type):
operation_by_type[uri.collection]['created'].append({'record': record, **create_info})

if op.action == 'delete':
if uri.collection == models.ids.AppBskyFeedLike:
operation_by_type['likes']['deleted'].append({'uri': str(uri)})
elif uri.collection == models.ids.AppBskyFeedPost:
operation_by_type['posts']['deleted'].append({'uri': str(uri)})
elif uri.collection == models.ids.AppBskyFeedRepost:
operation_by_type['reposts']['deleted'].append({'uri': str(uri)})
elif uri.collection == models.ids.AppBskyGraphFollow:
operation_by_type['follows']['deleted'].append({'uri': str(uri)})
operation_by_type[uri.collection]['deleted'].append({'uri': str(uri)})

return operation_by_type


def worker_main(cursor_value: multiprocessing.Value, pool_queue: multiprocessing.Queue) -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN) # we handle it in the main process

while True:
message = pool_queue.get()

Expand All @@ -75,17 +63,58 @@ def worker_main(cursor_value: multiprocessing.Value, pool_queue: multiprocessing
continue

ops = _get_ops_by_type(commit)
for post in ops['posts']['created']:
post_msg = post['record'].text
post_langs = post['record'].langs
print(f'New post in the network! Langs: {post_langs}. Text: {post_msg}')
for created_post in ops[models.ids.AppBskyFeedPost]['created']:
author = created_post['author']
record = created_post['record']

inlined_text = record.text.replace('\n', ' ')
print(f'NEW POST [CREATED_AT={record.created_at}][AUTHOR={author}]: {inlined_text}')


def get_firehose_params(cursor_value: multiprocessing.Value) -> models.ComAtprotoSyncSubscribeRepos.Params:
    """Build subscribe-repos params seeded with the shared cursor's current value."""
    current_cursor = cursor_value.value
    return models.ComAtprotoSyncSubscribeRepos.Params(cursor=current_cursor)


def measure_events_per_second(func: callable) -> callable:
    """Decorator that reports how often *func* is called, in events/second.

    A call counter lives on the wrapper; once at least one second has elapsed
    since the last report, the count is printed and the window resets. The
    wrapped function's return value is passed through unchanged.

    :param func: the callable to instrument.
    :return: the instrumented wrapper.
    """
    from functools import wraps  # local import: keeps the example's top-level imports untouched

    @wraps(func)  # fix: preserve func's __name__/__doc__ for introspection
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        wrapper.calls += 1
        cur_time = time.time()

        if cur_time - wrapper.start_time >= 1:
            print(f'NETWORK LOAD: {wrapper.calls} events/second')
            wrapper.start_time = cur_time
            wrapper.calls = 0

        # fix: forward keyword arguments too (the original dropped **kwargs)
        return func(*args, **kwargs)

    wrapper.calls = 0
    wrapper.start_time = time.time()

    return wrapper


def signal_handler(_: int, __: FrameType) -> None:
    """SIGINT handler for the main process: drain the queue, then shut down.

    Relies on the module-level ``client``, ``queue`` and ``pool`` created in
    the ``__main__`` block. Worker processes ignore SIGINT themselves, so this
    is the single place where Ctrl+C is handled.
    """
    print('Keyboard interrupt received. Waiting for the queue to empty before terminating processes...')

    # Stop receiving new messages
    client.stop()

    # Drain the messages queue
    while not queue.empty():
        print('Waiting for the queue to empty...')
        time.sleep(0.2)

    print('Queue is empty. Gracefully terminating processes...')

    pool.terminate()
    pool.join()

    # fix: the builtin `exit` is injected by the `site` module and is not
    # guaranteed to exist (e.g. under `python -S`); raising SystemExit is the
    # portable equivalent with identical behavior.
    raise SystemExit(0)


if __name__ == '__main__':
signal.signal(signal.SIGINT, signal_handler)

start_cursor = None

params = None
Expand All @@ -97,11 +126,12 @@ def get_firehose_params(cursor_value: multiprocessing.Value) -> models.ComAtprot
client = FirehoseSubscribeReposClient(params)

workers_count = multiprocessing.cpu_count() * 2 - 1
max_queue_size = 500
max_queue_size = 10000

queue = multiprocessing.Queue(maxsize=max_queue_size)
pool = multiprocessing.Pool(workers_count, worker_main, (cursor, queue))

@measure_events_per_second
def on_message_handler(message: firehose_models.MessageFrame) -> None:
if cursor.value:
            # we are updating the cursor state here because of multiprocessing
Expand Down
Loading