From 9ef066dc2fc8aa40afb4260d19fdd97b470f5f63 Mon Sep 17 00:00:00 2001 From: mandos21 Date: Thu, 22 Aug 2024 13:54:55 -0400 Subject: [PATCH] Added sockets and tests --- app/__init__.py | 2 ++ app/utils/auth_utils.py | 43 ++++++++++++---------- app/views/socket.py | 47 ++++++++++++++++++++++++ tests/views/test_socket.py | 74 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 147 insertions(+), 19 deletions(-) create mode 100644 app/views/socket.py create mode 100644 tests/views/test_socket.py diff --git a/app/__init__.py b/app/__init__.py index 435da15..2bc33c5 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -21,6 +21,8 @@ def create_app(): + from app.views import socket + flask_app = Flask(__name__) flask_app.config.from_object(settings) flask_app.json = CustomJSONProvider(flask_app) diff --git a/app/utils/auth_utils.py b/app/utils/auth_utils.py index 2c7de97..9fb7954 100644 --- a/app/utils/auth_utils.py +++ b/app/utils/auth_utils.py @@ -45,30 +45,35 @@ def decode_auth_token(auth_token): abort(403, description="Invalid token. Please log in again.") -def token_required(dm_required=False): - def decorator(f): - @wraps(f) - def decorated_function(*args, **kwargs): - auth_header = request.headers.get("Authorization", None) - if auth_header is None: - logger.warning("Authorization header is missing") - abort(403, description="Authorization header is missing!") +def validate_header(): + auth_header = request.headers.get("Authorization", None) + if auth_header is None: + logger.warning("Authorization header is missing") + abort(403, description="Authorization header is missing!") - try: - auth_type, token = auth_header.split() - if auth_type.lower() != "bearer": - logger.warning("Invalid token type: %s", auth_type) - abort( + try: + auth_type, token = auth_header.split() + if auth_type.lower() != "bearer": + logger.warning("Invalid token type: %s", auth_type) + abort( 403, description="Invalid token type. Expected Bearer token", ) - except ValueError: - logger.warning("Invalid Authorization header format") - abort(403, description="Invalid Authorization header format") + except ValueError: + logger.warning("Invalid Authorization header format") + abort(403, description="Invalid Authorization header format") - if not token: - logger.warning("Token is missing") - abort(403, description="Token is missing!") + if not token: + logger.warning("Token is missing") + abort(403, description="Token is missing!") + return token + + +def token_required(dm_required=False): + def decorator(f): + @wraps(f) + def decorated_function(*args, **kwargs): + token = validate_header() try: payload = decode_auth_token(token) diff --git a/app/views/socket.py b/app/views/socket.py new file mode 100644 index 0000000..b05fac8 --- /dev/null +++ b/app/views/socket.py @@ -0,0 +1,47 @@ +from flask_socketio import emit + +from app import socketio +from app.models.user import User +from app.utils.auth_utils import token_required + + +@socketio.on('connect', namespace='/socket') +@token_required() +def handle_connect(current_user, **kwargs): + username = current_user.username + emit('system_message', {'msg': f'{username} has connected.'}, broadcast=True) + + +@socketio.on('disconnect', namespace='/socket') +@token_required() +def handle_disconnect(current_user, **kwargs): + username = current_user.username + emit('system_message', {'msg': f'{username} has disconnected.'}, broadcast=True) + + +@socketio.on('system_message', namespace='/socket') +@token_required() +def handle_system_message(data, **kwargs): + message = data.get('message') + if message: + emit('system_message', {'msg': f'{kwargs["current_user"].username}: {message}'}, broadcast=True) + + +@socketio.on('dm_message', namespace='/socket') +@token_required(dm_required=True) +def handle_dm_message(data, **kwargs): + user_ids = data.get('user_ids', []) + message = data.get('message') + + if not message: + return + + if user_ids: + # Send the message to specific users + for uid in user_ids: + user = User.objects(_id=uid).first() + if user: + emit('dm_message', {'msg': message}, to=user.uid) + else: + # Broadcast the message to everyone + emit('dm_message', {'msg': message}, broadcast=True) diff --git a/tests/views/test_socket.py b/tests/views/test_socket.py new file mode 100644 index 0000000..a45fdd3 --- /dev/null +++ b/tests/views/test_socket.py @@ -0,0 +1,74 @@ +import unittest + +from app import socketio +from app.models.user import User +from tests.views.test_view_base import ControllerTestBase + + +class SocketNamespaceTestCase(ControllerTestBase): + + def setUp(self): + super().setUp() + self.socket_client = socketio.test_client( + self.app, + namespace='/socket', + headers={'Authorization': f'Bearer {self.token}'} + ) + + def test_connect_disconnect(self): + # Test connection (client connects automatically upon instantiation) + received = self.socket_client.get_received('/socket') + self.assertEqual(len(received), 1) + self.assertEqual(received[0]['name'], 'system_message') + self.assertIn('dmuser has connected.', received[0]['args'][0]['msg']) + + # Test disconnection + self.socket_client.disconnect(namespace='/socket') + self.assertFalse(self.socket_client.is_connected(namespace='/socket')) + + def test_system_message(self): + self.socket_client.get_received('/socket') # get the connect message out of the received + # Test sending a system message + self.socket_client.emit('system_message', {'message': 'Brandon has chosen the steel longsword'}, + namespace='/socket') + received = self.socket_client.get_received('/socket') + self.assertEqual(len(received), 1) + self.assertEqual(received[0]['name'], 'system_message') + self.assertIn('Brandon has chosen the steel longsword', received[0]['args'][0]['msg']) + + def test_dm_message_to_all(self): + self.socket_client.get_received('/socket') # get the connect message out of the received + # Test DM message to all users + self.socket_client.emit('dm_message', {'message': 'You may now roll for an item!'}, namespace='/socket') + received = self.socket_client.get_received('/socket') + self.assertEqual(len(received), 1) + self.assertEqual(received[0]['name'], 'dm_message') + self.assertIn('You may now roll for an item!', received[0]['args'][0]['msg']) + + def test_dm_message_to_specific_users(self): + self.socket_client.get_received('/socket') # get the connect message out of the received + # Create additional users + players = [] + for i in range(0, 2): + players.append(User(username=f'player{i}', email=f'player{i}@example.com')) + players[i].set_password('password') + players[i].save() + + # Join players to their rooms based on their IDs + self.socket_client.emit('join_room', {'username': 'player1', 'room': str(players[0].id)}, namespace='/socket') + self.socket_client.emit('join_room', {'username': 'player2', 'room': str(players[1].id)}, namespace='/socket') + + # Test DM message to specific users + self.socket_client.emit('dm_message', { + 'message': 'You have been granted permission to roll!', + 'usernames': ['player1', 'player2'] + }, namespace='/socket') + + received = self.socket_client.get_received('/socket') + self.assertEqual(len(received), 1) + self.assertEqual(received[0]['name'], 'dm_message') + self.assertIn('You have been granted permission to roll!', received[0]['args'][0]['msg']) + + +if __name__ == '__main__': + unittest.main()