Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add process_commits_async example #377

Merged
merged 1 commit into from
Sep 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions examples/firehose/process_commits_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import asyncio
import signal
import time
from collections import defaultdict
from types import FrameType
from typing import Any

from atproto import (
CAR,
AsyncFirehoseSubscribeReposClient,
AtUri,
firehose_models,
models,
parse_subscribe_repos_message,
)

_INTERESTED_RECORDS = {
models.ids.AppBskyFeedLike: models.AppBskyFeedLike,
models.ids.AppBskyFeedPost: models.AppBskyFeedPost,
models.ids.AppBskyGraphFollow: models.AppBskyGraphFollow,
}


def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> defaultdict:
operation_by_type = defaultdict(lambda: {'created': [], 'deleted': []})

car = CAR.from_bytes(commit.blocks)
for op in commit.ops:
if op.action == 'update':
# not supported yet
continue

uri = AtUri.from_str(f'at://{commit.repo}/{op.path}')

if op.action == 'create':
if not op.cid:
continue

create_info = {'uri': str(uri), 'cid': str(op.cid), 'author': commit.repo}

record_raw_data = car.blocks.get(op.cid)
if not record_raw_data:
continue

record = models.get_or_create(record_raw_data, strict=False)
record_type = _INTERESTED_RECORDS.get(uri.collection)
if record_type and models.is_record_type(record, record_type):
operation_by_type[uri.collection]['created'].append({'record': record, **create_info})

if op.action == 'delete':
operation_by_type[uri.collection]['deleted'].append({'uri': str(uri)})

return operation_by_type


def measure_events_per_second(func: callable) -> callable:
def wrapper(*args) -> Any:
wrapper.calls += 1
cur_time = time.time()

if cur_time - wrapper.start_time >= 1:
print(f'NETWORK LOAD: {wrapper.calls} events/second')
wrapper.start_time = cur_time
wrapper.calls = 0

return func(*args)

wrapper.calls = 0
wrapper.start_time = time.time()

return wrapper


async def signal_handler(_: int, __: FrameType) -> None:
print('Keyboard interrupt received. Stopping...')

# Stop receiving new messages
await client.stop()


async def main(firehose_client: AsyncFirehoseSubscribeReposClient) -> None:
@measure_events_per_second
async def on_message_handler(message: firehose_models.MessageFrame) -> None:
commit = parse_subscribe_repos_message(message)
if not isinstance(commit, models.ComAtprotoSyncSubscribeRepos.Commit):
return

if commit.seq % 20 == 0:
firehose_client.update_params(models.ComAtprotoSyncSubscribeRepos.Params(cursor=commit.seq))

if not commit.blocks:
return

ops = _get_ops_by_type(commit)
for created_post in ops[models.ids.AppBskyFeedPost]['created']:
author = created_post['author']
record = created_post['record']

inlined_text = record.text.replace('\n', ' ')
print(f'NEW POST [CREATED_AT={record.created_at}][AUTHOR={author}]: {inlined_text}')

await client.start(on_message_handler)


if __name__ == '__main__':
signal.signal(signal.SIGINT, lambda _, __: asyncio.create_task(signal_handler(_, __)))

start_cursor = None

params = None
if start_cursor is not None:
params = models.ComAtprotoSyncSubscribeRepos.Params(cursor=start_cursor)

client = AsyncFirehoseSubscribeReposClient(params)

# use run() for a higher Python version
asyncio.get_event_loop().run_until_complete(main(client))
Loading