Skip to content

Commit

Permalink
Added sockets and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mandos21 committed Aug 22, 2024
1 parent 3a0e15d commit 9ef066d
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 19 deletions.
2 changes: 2 additions & 0 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@


def create_app():
from app.views import socket

flask_app = Flask(__name__)
flask_app.config.from_object(settings)
flask_app.json = CustomJSONProvider(flask_app)
Expand Down
43 changes: 24 additions & 19 deletions app/utils/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,35 @@ def decode_auth_token(auth_token):
abort(403, description="Invalid token. Please log in again.")


def token_required(dm_required=False):
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
auth_header = request.headers.get("Authorization", None)
if auth_header is None:
logger.warning("Authorization header is missing")
abort(403, description="Authorization header is missing!")
def validate_header():
auth_header = request.headers.get("Authorization", None)
if auth_header is None:
logger.warning("Authorization header is missing")
abort(403, description="Authorization header is missing!")

try:
auth_type, token = auth_header.split()
if auth_type.lower() != "bearer":
logger.warning("Invalid token type: %s", auth_type)
abort(
try:
auth_type, token = auth_header.split()
if auth_type.lower() != "bearer":
logger.warning("Invalid token type: %s", auth_type)
abort(
403,
description="Invalid token type. Expected Bearer token",
)
except ValueError:
logger.warning("Invalid Authorization header format")
abort(403, description="Invalid Authorization header format")
except ValueError:
logger.warning("Invalid Authorization header format")
abort(403, description="Invalid Authorization header format")

if not token:
logger.warning("Token is missing")
abort(403, description="Token is missing!")
if not token:
logger.warning("Token is missing")
abort(403, description="Token is missing!")
return token


def token_required(dm_required=False):
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
token = validate_header()

try:
payload = decode_auth_token(token)
Expand Down
47 changes: 47 additions & 0 deletions app/views/socket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from flask_socketio import emit

from app import socketio
from app.models.user import User
from app.utils.auth_utils import token_required


@socketio.on('connect', namespace='/socket')
@token_required()
def handle_connect(current_user, **kwargs):
username = current_user.username
emit('system_message', {'msg': f'{username} has connected.'}, broadcast=True)


@socketio.on('disconnect', namespace='/socket')
@token_required()
def handle_disconnect(current_user, **kwargs):
username = current_user.username
emit('system_message', {'msg': f'{username} has disconnected.'}, broadcast=True)


@socketio.on('system_message', namespace='/socket')
@token_required()
def handle_system_message(data, **kwargs):
message = data.get('message')
if message:
emit('system_message', {'msg': f'{kwargs["current_user"].username}: {message}'}, broadcast=True)


@socketio.on('dm_message', namespace='/socket')
@token_required(dm_required=True)
def handle_dm_message(data, **kwargs):
user_ids = data.get('user_ids', [])
message = data.get('message')

if not message:
return

if user_ids:
# Send the message to specific users
for uid in user_ids:
user = User.objects(_id=uid).first()
if user:
emit('dm_message', {'msg': message}, to=user.uid)
else:
# Broadcast the message to everyone
emit('dm_message', {'msg': message}, broadcast=True)
74 changes: 74 additions & 0 deletions tests/views/test_socket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import unittest

from app import socketio
from app.models.user import User
from tests.views.test_view_base import ControllerTestBase


class SocketNamespaceTestCase(ControllerTestBase):

def setUp(self):
super().setUp()
self.socket_client = socketio.test_client(
self.app,
namespace='/socket',
headers={'Authorization': f'Bearer {self.token}'}
)

def test_connect_disconnect(self):
# Test connection (client connects automatically upon instantiation)
received = self.socket_client.get_received('/socket')
self.assertEqual(len(received), 1)
self.assertEqual(received[0]['name'], 'system_message')
self.assertIn('dmuser has connected.', received[0]['args'][0]['msg'])

# Test disconnection
self.socket_client.disconnect(namespace='/socket')
self.assertFalse(self.socket_client.is_connected(namespace='/socket'))

def test_system_message(self):
self.socket_client.get_received('/socket') # get the connect message out of the received
# Test sending a system message
self.socket_client.emit('system_message', {'message': 'Brandon has chosen the steel longsword'},
namespace='/socket')
received = self.socket_client.get_received('/socket')
self.assertEqual(len(received), 1)
self.assertEqual(received[0]['name'], 'system_message')
self.assertIn('Brandon has chosen the steel longsword', received[0]['args'][0]['msg'])

def test_dm_message_to_all(self):
self.socket_client.get_received('/socket') # get the connect message out of the received
# Test DM message to all users
self.socket_client.emit('dm_message', {'message': 'You may now roll for an item!'}, namespace='/socket')
received = self.socket_client.get_received('/socket')
self.assertEqual(len(received), 1)
self.assertEqual(received[0]['name'], 'dm_message')
self.assertIn('You may now roll for an item!', received[0]['args'][0]['msg'])

def test_dm_message_to_specific_users(self):
self.socket_client.get_received('/socket') # get the connect message out of the received
# Create additional users
players = []
for i in range(0, 2):
players.append(User(username=f'player{i}', email=f'player{i}@example.com'))
players[i].set_password('password')
players[i].save()

# Join players to their rooms based on their IDs
self.socket_client.emit('join_room', {'username': 'player1', 'room': str(players[0].id)}, namespace='/socket')
self.socket_client.emit('join_room', {'username': 'player2', 'room': str(players[1].id)}, namespace='/socket')

# Test DM message to specific users
self.socket_client.emit('dm_message', {
'message': 'You have been granted permission to roll!',
'usernames': ['player1', 'player2']
}, namespace='/socket')

received = self.socket_client.get_received('/socket')
self.assertEqual(len(received), 1)
self.assertEqual(received[0]['name'], 'dm_message')
self.assertIn('You have been granted permission to roll!', received[0]['args'][0]['msg'])


if __name__ == '__main__':
unittest.main()

0 comments on commit 9ef066d

Please sign in to comment.