diff --git a/db/ticket.py b/db/ticket.py index 30c861c..84dc13d 100644 --- a/db/ticket.py +++ b/db/ticket.py @@ -1,4 +1,5 @@ import logging +from threading import Lock import pendulum from mongoengine import ( @@ -16,6 +17,7 @@ class TicketPrice(Document): + lock = Lock() endpoint = "stake/diff" price = FloatField(required=True) @@ -49,6 +51,7 @@ def _fetch_new_ticket_price(cls): @classmethod def get_last(cls): + cls.lock.acquire() last_ticket_price = cls.objects.order_by('-datetime').first() try: @@ -59,4 +62,5 @@ def get_last(cls): finally: last_ticket_price.save() + cls.lock.release() return last_ticket_price diff --git a/db/update_message.py b/db/update_message.py index 2d2fc1c..54bff85 100644 --- a/db/update_message.py +++ b/db/update_message.py @@ -1,4 +1,5 @@ from json import loads +from functools import reduce import pendulum from mongoengine import ( @@ -19,6 +20,11 @@ class Amount(EmbeddedDocument): def __str__(self): return f"{self.value} DCR" + def __add__(self, other): + if not isinstance(other, Amount): + raise TypeError('Cannot add!') + return Amount(self._value + other._value) + def equal(self, other): if not isinstance(other, Amount): raise TypeError(f"Other must be {Amount}") @@ -41,12 +47,9 @@ class Session(EmbeddedDocument): def __str__(self): string = f"{self.hash[:32]}:\t[" - total = 0 - for index, amount in enumerate(self.amounts): - total += amount.value - string += f"{amount}" - string += ", " if index != len(self.amounts)-1 else "]" - string += f"\nTotal: {total} DCR" + string += ", ".join([f"{amount}" for amount in self.amounts]) + total = reduce(lambda a, b: a+b, self.amounts, Amount(0)) + string += f"]\nTotal: {total}" return string def equal(self, other): @@ -69,7 +72,7 @@ def from_data(cls, data): class UpdateMessage(Document): subject = ReferenceField(Subject, required=True) - sessions = EmbeddedDocumentListField(Session, required=True) + sessions = EmbeddedDocumentListField(Session) datetime = DateTimeField(default=pendulum.now, required=True) meta = { @@ -81,11 +84,10 @@ class UpdateMessage(Document): def __str__(self): string = f"{self.subject.header}\n\n" - string += f"Default session: {self.subject.default_session}\n\n" string += f"Ticket price: {TicketPrice.get_last()}\n\n" - for index, msg in enumerate(self.sessions): - string += f"{msg}" - string += "\n\n" if index != len(self.sessions) - 1 else "" + string += f"Default session: {self.subject.default_session}\n\n" + string += "\n\n".join([f"{session}" + for session in self.sessions]) return string def equal(self, other): diff --git a/sws/client.py b/sws/client.py index 5bfa1b9..1ebd93c 100644 --- a/sws/client.py +++ b/sws/client.py @@ -3,7 +3,6 @@ from threading import Thread, Lock from websocket import WebSocketApp -from mongoengine.errors import ValidationError from bot.jack import JackBot from db.subject import Subject @@ -80,7 +79,7 @@ def on_message(ws: WebSocketApp, data): sws.subject.notify(msg) sws.lock.release() logger.info(f'{sws.name} released lock!') - except (ValidationError, DuplicatedUpdateMessageError) as e: + except DuplicatedUpdateMessageError as e: logger.info(f"Supress {e} for creating {UpdateMessage} " f"from {data} on {sws}") diff --git a/tests/db/test_update_message.py b/tests/db/test_update_message.py index ef562b6..1a0325d 100644 --- a/tests/db/test_update_message.py +++ b/tests/db/test_update_message.py @@ -1,10 +1,11 @@ -from unittest import TestCase +from unittest import TestCase, mock import pytest from tests.fixtures import mongo # noqa F401 from db.update_message import UpdateMessage, Session, Amount from db.subject import Subject +from db.ticket import TicketPrice from utils.exceptions import DuplicatedUpdateMessageError @@ -28,29 +29,48 @@ def setUp(self) -> None: def test_init(self): self.assertEqual(UpdateMessage.objects.count(), 0) - UpdateMessage(self.subject, [Session('test', [Amount(10)])]).save() + UpdateMessage(self.subject, + [Session('test', [Amount(1000000000)])]).save() self.assertEqual(UpdateMessage.objects.count(), 1) instance = UpdateMessage.objects.first() self.assertEqual(instance.subject, self.subject) self.assertIsInstance(instance, UpdateMessage) + @mock.patch('db.ticket.TicketPrice.get_last') + def test_str(self, mocked_get_last): + ticket_price = TicketPrice(10) + mocked_get_last.return_value = ticket_price + + instance = UpdateMessage(self.subject, + [Session('test', [Amount(1000000000)])]).save() + self.assertEqual(instance.__str__(), + f"🇧🇷 Decred Brasil\n\n" + f"Ticket price: 10.00 DCR\n\n" + f"Default session: dcrbr1\n\n" + f"test:\t[10.0 DCR]\nTotal: 10.0 DCR") + def test_equal(self): - instance = UpdateMessage(self.subject, [Session('test', [Amount(10)])]) - other = UpdateMessage(self.subject, [Session('test', [Amount(10)])]) + instance = UpdateMessage(self.subject, + [Session('test', [Amount(1000000000)])]) + other = UpdateMessage(self.subject, + [Session('test', [Amount(1000000000)])]) self.assertTrue(instance.equal(other)) def test_equal_false(self): - instance = UpdateMessage(self.subject, [Session('test', [Amount(10)])]) - other = UpdateMessage(self.subject, [Session('test', [Amount(11)])]) + instance = UpdateMessage(self.subject, + [Session('test', [Amount(1000000000)])]) + other = UpdateMessage(self.subject, + [Session('test', [Amount(1100000000)])]) self.assertFalse(instance.equal(other)) def test_get_last_by_subject(self): - UpdateMessage(self.subject, [Session('test', [Amount(10)])]).save() + UpdateMessage(self.subject, [Session('test', + [Amount(1000000000)])]).save() other = UpdateMessage(self.subject, - [Session('test', [Amount(11)])]).save() + [Session('test', [Amount(1100000000)])]).save() last = UpdateMessage.get_last_by_subject(self.subject) self.assertEqual(other, last)