Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement event ownership #423

Merged
merged 11 commits into from
Dec 14, 2023
41 changes: 34 additions & 7 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,10 @@ async def post_node(node: Node,
node.owner = current_user.username
obj = await db.create(node)
data = _get_node_event_data('created', obj)
await pubsub.publish_cloudevent('node', data)
attributes = {}
if data.get('owner', None):
attributes['owner'] = data['owner']
await pubsub.publish_cloudevent('node', data, attributes)
JenySadadia marked this conversation as resolved.
Show resolved Hide resolved
return obj


Expand Down Expand Up @@ -578,7 +581,10 @@ async def put_node(node_id: str, node: Node,

obj = await db.update(node)
data = _get_node_event_data('updated', obj)
await pubsub.publish_cloudevent('node', data)
attributes = {}
if data.get('owner', None):
attributes['owner'] = data['owner']
await pubsub.publish_cloudevent('node', data, attributes)
return obj


Expand All @@ -600,7 +606,10 @@ async def put_nodes(
await _set_node_ownership_recursively(user, nodes)
obj_list = await db.create_hierarchy(nodes, Node)
data = _get_node_event_data('updated', obj_list[0])
await pubsub.publish_cloudevent('node', data)
attributes = {}
if data.get('owner', None):
attributes['owner'] = data['owner']
await pubsub.publish_cloudevent('node', data, attributes)
return obj_list


Expand All @@ -610,26 +619,31 @@ async def put_nodes(
@app.post('/subscribe/{channel}', response_model=Subscription)
async def subscribe(channel: str, user: User = Depends(get_current_user)):
"""Subscribe handler for Pub/Sub channel"""
return await pubsub.subscribe(channel)
return await pubsub.subscribe(channel, user.username)


@app.post('/unsubscribe/{sub_id}')
async def unsubscribe(sub_id: int, user: User = Depends(get_current_user)):
"""Unsubscribe handler for Pub/Sub channel"""
try:
await pubsub.unsubscribe(sub_id)
await pubsub.unsubscribe(sub_id, user.username)
except KeyError as error:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Subscription id not found: {str(error)}"
) from error
except RuntimeError as error:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(error)
) from error


@app.get('/listen/{sub_id}')
async def listen(sub_id: int, user: User = Depends(get_current_user)):
"""Listen messages from a subscribed Pub/Sub channel"""
try:
return await pubsub.listen(sub_id)
return await pubsub.listen(sub_id, user.username)
except KeyError as error:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -646,7 +660,20 @@ async def listen(sub_id: int, user: User = Depends(get_current_user)):
async def publish(event: PublishEvent, channel: str,
user: User = Depends(get_current_user)):
"""Publish an event on the provided Pub/Sub channel"""
await pubsub.publish_cloudevent(channel, event.data, event.attributes)
event_dict = PublishEvent.dict(event)
# 1 - Extract data and attributes from the event
# 2 - Add the owner as an extra attribute
# 3 - Collect all the other extra attributes, if available, without
# overwriting any of the standard ones in the dict
data = event_dict.pop('data')
extra_attributes = event_dict.pop("attributes")
attributes = event_dict
attributes['owner'] = user.username
if extra_attributes:
for k in extra_attributes:
if k not in attributes:
attributes[k] = extra_attributes[k]
await pubsub.publish_cloudevent(channel, data, attributes)


@app.post('/push/{list_name}')
Expand Down
23 changes: 9 additions & 14 deletions api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,22 +369,17 @@ def get_model_from_kind(kind: str):
return models[kind]


class PublishAttributes(BaseModel):
"""API model for the attributes of a Publish operation"""
type: str = Field(
default='api.kernelci.org',
class PublishEvent(BaseModel):
"""API model for the data of a <publish> event"""
data: Any = Field(
description="Event payload"
)
type: Optional[str] = Field(
description="Type of the <publish> event"
)
source: str = Field(
source: Optional[str] = Field(
description="Source of the <publish> event"
)


class PublishEvent(BaseModel):
"""API model for the data of a <publish> event"""
attributes: Optional[PublishAttributes] = Field(
description="Event attributes"
)
data: Dict = Field(
description="Event payload"
attributes: Optional[Dict] = Field(
description="Extra Cloudevents Extension Context Attributes"
)
65 changes: 46 additions & 19 deletions api/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class Subscription(BaseModel):
channel: str = Field(
description='Subscription channel name'
)
user: str = Field(
description=("Username of the user that created the "
"subscription (owner)")
)


class PubSub:
Expand Down Expand Up @@ -49,6 +53,12 @@ def __init__(self, host=None, db_number=None):
if db_number is None:
db_number = self._settings.redis_db_number
self._redis = aioredis.from_url(f'redis://{host}/{db_number}')
# self._subscriptions is a dict that matches a subscription id
# (key) with a Subscription object ('sub') and a redis
# PubSub object ('redis_sub'). For instance:
# {1 : {'sub': <Subscription>, 'redis_sub': <PubSub>}}
#
# Note that this matching is kept in this dict only.
self._subscriptions = {}
self._channels = set()
self._lock = asyncio.Lock()
Expand Down Expand Up @@ -78,38 +88,44 @@ async def _keep_alive(self):
def _update_channels(self):
self._channels = set()
for sub in self._subscriptions.values():
for channel in sub.channels.keys():
for channel in sub['redis_sub'].channels.keys():
self._channels.add(channel.decode())

async def subscribe(self, channel):
async def subscribe(self, channel, user):
"""Subscribe to a Pub/Sub channel

Subscribe to a given channel and return a Subscription object
containing the subscription id which can then be used again in other
methods.
Subscribe to a given channel and return a Subscription object.
"""
sub_id = await self._redis.incr(self.ID_KEY)
async with self._lock:
sub = self._redis.pubsub()
self._subscriptions[sub_id] = sub
await sub.subscribe(channel)
redis_sub = self._redis.pubsub()
sub = Subscription(id=sub_id, channel=channel, user=user)
self._subscriptions[sub_id] = {'redis_sub': redis_sub,
'sub': sub}
await redis_sub.subscribe(channel)
self._update_channels()
self._start_keep_alive_timer()
return Subscription(id=sub_id, channel=channel)
return sub

async def unsubscribe(self, sub_id):
async def unsubscribe(self, sub_id, user=None):
"""Unsubscribe from a Pub/Sub channel

Unsubscribe from a channel using the provided subscription id as found
in a Subscription object.
"""
async with self._lock:
sub = self._subscriptions[sub_id]
r-c-n marked this conversation as resolved.
Show resolved Hide resolved
# Only allow a user to unsubscribe its own
# subscriptions. One exception: let an anonymous (internal)
# call to this function to unsubscribe any subscription
if user and user != sub['sub'].user:
raise RuntimeError(f"Subscription {sub_id} "
f"not owned by {user}")
self._subscriptions.pop(sub_id)
self._update_channels()
await sub.unsubscribe()
await sub['redis_sub'].unsubscribe()

async def listen(self, sub_id):
async def listen(self, sub_id, user=None):
"""Listen for Pub/Sub messages

Listen on a given subscription id asynchronously and return a message
Expand All @@ -118,12 +134,22 @@ async def listen(self, sub_id):
async with self._lock:
sub = self._subscriptions[sub_id]

# Only allow a user to listen to its own subscriptions. One
# exception: let an anonymous (internal) call to this function
# to listen to any subscription
if user and user != sub['sub'].user:
raise RuntimeError(f"Subscription {sub_id} "
f"not owned by {user}")
while True:
msg = await sub.get_message(
msg = await sub['redis_sub'].get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if msg is not None:
return msg
if msg is None:
continue
msg_data = json.loads(msg['data'])
if 'owner' in msg_data and msg_data['owner'] != sub['sub'].user:
r-c-n marked this conversation as resolved.
Show resolved Hide resolved
continue
return msg

async def publish(self, channel, message):
"""Publish a message on a channel
Expand Down Expand Up @@ -161,10 +187,11 @@ async def publish_cloudevent(self, channel, data, attributes=None):
for more details.
"""
if not attributes:
attributes = {
"type": "api.kernelci.org",
"source": self._settings.cloud_events_source,
}
attributes = {}
if not attributes.get('type'):
attributes['type'] = "api.kernelci.org"
if not attributes.get('source'):
attributes['source'] = self._settings.cloud_events_source
r-c-n marked this conversation as resolved.
Show resolved Hide resolved
event = CloudEvent(attributes=attributes, data=data)
await self.publish(channel, to_json(event))

Expand Down
6 changes: 3 additions & 3 deletions tests/e2e_tests/test_subscribe_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
response = test_client.post(
"subscribe/node",
headers={
"Authorization": f"Bearer {pytest.BEARER_TOKEN}"

Check failure on line 26 in tests/e2e_tests/test_subscribe_handler.py

View workflow job for this annotation

GitHub Actions / Lint

Module 'pytest' has no 'BEARER_TOKEN' member
},
)
pytest.node_channel_subscription_id = response.json()['id']
assert response.status_code == 200
assert ('id', 'channel') == tuple(response.json().keys())
assert ('id', 'channel', 'user') == tuple(response.json().keys())
assert response.json().get('channel') == 'node'


Expand All @@ -46,12 +46,12 @@
response = test_client.post(
"subscribe/test_channel",
headers={
"Authorization": f"Bearer {pytest.BEARER_TOKEN}"

Check failure on line 49 in tests/e2e_tests/test_subscribe_handler.py

View workflow job for this annotation

GitHub Actions / Lint

Module 'pytest' has no 'BEARER_TOKEN' member
},
)
pytest.test_channel_subscription_id = response.json()['id']
assert response.status_code == 200
assert ('id', 'channel') == tuple(response.json().keys())
assert ('id', 'channel', 'user') == tuple(response.json().keys())
assert response.json().get('channel') == 'test_channel'


Expand All @@ -70,10 +70,10 @@
response = test_client.post(
"subscribe/user_group",
headers={
"Authorization": f"Bearer {pytest.BEARER_TOKEN}"

Check failure on line 73 in tests/e2e_tests/test_subscribe_handler.py

View workflow job for this annotation

GitHub Actions / Lint

Module 'pytest' has no 'BEARER_TOKEN' member
},
)
pytest.user_group_channel_subscription_id = response.json()['id']
assert response.status_code == 200
assert ('id', 'channel') == tuple(response.json().keys())
assert ('id', 'channel', 'user') == tuple(response.json().keys())
assert response.json().get('channel') == 'user_group'
6 changes: 4 additions & 2 deletions tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from api.models import UserGroup
from api.user_models import User
from api.pubsub import PubSub
from api.pubsub import PubSub, Subscription

BEARER_TOKEN = "Bearer \
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJib2IifQ.\
Expand Down Expand Up @@ -223,8 +223,10 @@ def mock_pubsub_subscriptions(mocker):
"""Mocks `_redis` and `_subscriptions` member of PubSub class instance"""
pubsub = PubSub()
redis_mock = fakeredis.aioredis.FakeRedis()
sub = Subscription(id=1, channel='test', user='test')
mocker.patch.object(pubsub, '_redis', redis_mock)
subscriptions_mock = dict({1: pubsub._redis.pubsub()})
subscriptions_mock = dict(
{1: {'sub': sub, 'redis_sub': pubsub._redis.pubsub()}})
mocker.patch.object(pubsub, '_subscriptions', subscriptions_mock)
return pubsub

Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
PubSub._subscriptions dict should have one entry. This entry's
key should be equal 1.
"""
result = await mock_pubsub.subscribe('CHANNEL')
result = await mock_pubsub.subscribe('CHANNEL', 'test')
assert result.channel == 'CHANNEL'
assert result.id == 1
assert len(mock_pubsub._subscriptions) == 1
Expand All @@ -48,7 +48,7 @@
await mock_pubsub._redis.set(mock_pubsub.ID_KEY, 0)
channels = ((1, 'CHANNEL1'), (2, 'CHANNEL2'), (3, 'CHANNEL3'))
for expected_id, expected_channel in channels:
result = await mock_pubsub.subscribe(expected_channel)
result = await mock_pubsub.subscribe(expected_channel, 'test')
assert result.channel == expected_channel
assert result.id == expected_id
assert len(mock_pubsub._subscriptions) == 3
Expand Down Expand Up @@ -93,20 +93,20 @@
published in the channel by the redis publisher.

Expected Results:
Validate that a json is sent to the channel and assert the json values from

Check warning on line 96 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

line too long (83 > 79 characters)
data and attributes parameters in Pubsub.publish_cloudevent(). There's no

Check warning on line 97 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

line too long (81 > 79 characters)
return value, but a json to be published in a channel.
"""

data = 'validate json'
attributes = { "specversion": "1.0", "id": "6878b661-96dc-4e93-8c92-26eb9ff8db64",

Check warning on line 102 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

whitespace after '{'

Check warning on line 102 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

line too long (87 > 79 characters)
"source": "https://api.kernelci.org/", "type": "api.kernelci.org",

Check warning on line 103 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

continuation line under-indented for visual indent
"time": "2022-01-31T21:29:29.675593+00:00"}

Check warning on line 104 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

continuation line under-indented for visual indent

await mock_pubsub_publish.publish_cloudevent('CHANNEL1', data, attributes)

expected_json = str.encode('{"specversion": "1.0", '\

Check warning on line 108 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

the backslash is redundant between brackets
'"id": "6878b661-96dc-4e93-8c92-26eb9ff8db64", "source": "https://api.kernelci.org/", '\

Check warning on line 109 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

continuation line under-indented for visual indent

Check warning on line 109 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

line too long (92 > 79 characters)

Check warning on line 109 in tests/unit_tests/test_pubsub.py

View workflow job for this annotation

GitHub Actions / Lint

the backslash is redundant between brackets
'"type": "api.kernelci.org", "time": "2022-01-31T21:29:29.675593+00:00", '\
'"data": "validate json"}')

Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_subscribe_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_subscribe_endpoint(mock_subscribe, test_client):
HTTP Response Code 200 OK
JSON with 'id' and 'channel' keys
"""
subscribe = Subscription(id=1, channel='abc')
subscribe = Subscription(id=1, channel='abc', user='test')
mock_subscribe.return_value = subscribe

response = test_client.post(
Expand All @@ -29,4 +29,4 @@ def test_subscribe_endpoint(mock_subscribe, test_client):
)
print("response.json()", response.json())
assert response.status_code == 200
assert ('id', 'channel') == tuple(response.json().keys())
assert ('id', 'channel', 'user') == tuple(response.json().keys())
Loading