diff --git a/db/ticket.py b/db/ticket.py
index 30c861c..84dc13d 100644
--- a/db/ticket.py
+++ b/db/ticket.py
@@ -1,4 +1,5 @@
import logging
+from threading import Lock
import pendulum
from mongoengine import (
@@ -16,6 +17,7 @@
class TicketPrice(Document):
+ lock = Lock()
endpoint = "stake/diff"
price = FloatField(required=True)
@@ -49,6 +51,7 @@ def _fetch_new_ticket_price(cls):
@classmethod
def get_last(cls):
+ cls.lock.acquire()
last_ticket_price = cls.objects.order_by('-datetime').first()
try:
@@ -59,4 +62,5 @@ def get_last(cls):
finally:
last_ticket_price.save()
+ cls.lock.release()
return last_ticket_price
diff --git a/db/update_message.py b/db/update_message.py
index 2d2fc1c..54bff85 100644
--- a/db/update_message.py
+++ b/db/update_message.py
@@ -1,4 +1,5 @@
from json import loads
+from functools import reduce
import pendulum
from mongoengine import (
@@ -19,6 +20,11 @@ class Amount(EmbeddedDocument):
def __str__(self):
return f"{self.value} DCR"
+ def __add__(self, other):
+ if not isinstance(other, Amount):
+ raise TypeError('Cannot add!')
+ return Amount(self._value + other._value)
+
def equal(self, other):
if not isinstance(other, Amount):
raise TypeError(f"Other must be {Amount}")
@@ -41,12 +47,9 @@ class Session(EmbeddedDocument):
def __str__(self):
string = f"{self.hash[:32]}:\t["
- total = 0
- for index, amount in enumerate(self.amounts):
- total += amount.value
- string += f"{amount}"
- string += ", " if index != len(self.amounts)-1 else "]"
- string += f"\nTotal: {total} DCR"
+ string += ", ".join([f"{amount}" for amount in self.amounts])
+ total = reduce(lambda a, b: a+b, self.amounts, Amount(0))
+ string += f"]\nTotal: {total}"
return string
def equal(self, other):
@@ -69,7 +72,7 @@ def from_data(cls, data):
class UpdateMessage(Document):
subject = ReferenceField(Subject, required=True)
- sessions = EmbeddedDocumentListField(Session, required=True)
+ sessions = EmbeddedDocumentListField(Session)
datetime = DateTimeField(default=pendulum.now, required=True)
meta = {
@@ -81,11 +84,10 @@ class UpdateMessage(Document):
def __str__(self):
string = f"{self.subject.header}\n\n"
- string += f"Default session: {self.subject.default_session}\n\n"
string += f"Ticket price: {TicketPrice.get_last()}\n\n"
- for index, msg in enumerate(self.sessions):
- string += f"{msg}
"
- string += "\n\n" if index != len(self.sessions) - 1 else ""
+ string += f"Default session: {self.subject.default_session}\n\n"
+ string += "\n\n".join([f"{session}
"
+ for session in self.sessions])
return string
def equal(self, other):
diff --git a/sws/client.py b/sws/client.py
index 5bfa1b9..1ebd93c 100644
--- a/sws/client.py
+++ b/sws/client.py
@@ -3,7 +3,6 @@
from threading import Thread, Lock
from websocket import WebSocketApp
-from mongoengine.errors import ValidationError
from bot.jack import JackBot
from db.subject import Subject
@@ -80,7 +79,7 @@ def on_message(ws: WebSocketApp, data):
sws.subject.notify(msg)
sws.lock.release()
logger.info(f'{sws.name} released lock!')
- except (ValidationError, DuplicatedUpdateMessageError) as e:
+ except DuplicatedUpdateMessageError as e:
logger.info(f"Supress {e} for creating {UpdateMessage} "
f"from {data} on {sws}")
diff --git a/tests/db/test_update_message.py b/tests/db/test_update_message.py
index ef562b6..1a0325d 100644
--- a/tests/db/test_update_message.py
+++ b/tests/db/test_update_message.py
@@ -1,10 +1,11 @@
-from unittest import TestCase
+from unittest import TestCase, mock
import pytest
from tests.fixtures import mongo # noqa F401
from db.update_message import UpdateMessage, Session, Amount
from db.subject import Subject
+from db.ticket import TicketPrice
from utils.exceptions import DuplicatedUpdateMessageError
@@ -28,29 +29,48 @@ def setUp(self) -> None:
def test_init(self):
self.assertEqual(UpdateMessage.objects.count(), 0)
- UpdateMessage(self.subject, [Session('test', [Amount(10)])]).save()
+ UpdateMessage(self.subject,
+ [Session('test', [Amount(1000000000)])]).save()
self.assertEqual(UpdateMessage.objects.count(), 1)
instance = UpdateMessage.objects.first()
self.assertEqual(instance.subject, self.subject)
self.assertIsInstance(instance, UpdateMessage)
+ @mock.patch('db.ticket.TicketPrice.get_last')
+ def test_str(self, mocked_get_last):
+ ticket_price = TicketPrice(10)
+ mocked_get_last.return_value = ticket_price
+
+ instance = UpdateMessage(self.subject,
+ [Session('test', [Amount(1000000000)])]).save()
+ self.assertEqual(instance.__str__(),
+ f"🇧🇷 Decred Brasil\n\n"
+ f"Ticket price: 10.00 DCR\n\n"
+ f"Default session: dcrbr1\n\n"
+ f"test:\t[10.0 DCR]\nTotal: 10.0 DCR
")
+
def test_equal(self):
- instance = UpdateMessage(self.subject, [Session('test', [Amount(10)])])
- other = UpdateMessage(self.subject, [Session('test', [Amount(10)])])
+ instance = UpdateMessage(self.subject,
+ [Session('test', [Amount(1000000000)])])
+ other = UpdateMessage(self.subject,
+ [Session('test', [Amount(1000000000)])])
self.assertTrue(instance.equal(other))
def test_equal_false(self):
- instance = UpdateMessage(self.subject, [Session('test', [Amount(10)])])
- other = UpdateMessage(self.subject, [Session('test', [Amount(11)])])
+ instance = UpdateMessage(self.subject,
+ [Session('test', [Amount(1000000000)])])
+ other = UpdateMessage(self.subject,
+ [Session('test', [Amount(1100000000)])])
self.assertFalse(instance.equal(other))
def test_get_last_by_subject(self):
- UpdateMessage(self.subject, [Session('test', [Amount(10)])]).save()
+ UpdateMessage(self.subject, [Session('test',
+ [Amount(1000000000)])]).save()
other = UpdateMessage(self.subject,
- [Session('test', [Amount(11)])]).save()
+ [Session('test', [Amount(1100000000)])]).save()
last = UpdateMessage.get_last_by_subject(self.subject)
self.assertEqual(other, last)