Skip to content

Commit

Permalink
Added type preservation for set operations.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Nov 6, 2023
1 parent 2a16b36 commit 12e2512
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.1.3'
__version__ = '1.1.4'
36 changes: 31 additions & 5 deletions src/random_events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ class Event(EventMapType):
A map of variables to values of their respective domains.
"""

def check_same_type(self, other: Any):
"""
Check that both self and other are of the same type.
:param other: The other object
"""
if type(self) is not type(other):
raise TypeError(f"Cannot use operation on {type(self)} with {type(other)}")

def intersection(self, other: 'Event') -> 'Event':
"""
Get the intersection of this and another event.
Expand All @@ -67,7 +76,10 @@ def intersection(self, other: 'Event') -> 'Event':
:return: The intersection
"""
result = Event()

self.check_same_type(other)

result = self.__class__()

variables = set(self.keys()) | set(other.keys())

Expand Down Expand Up @@ -119,7 +131,10 @@ def union(self, other: 'Event') -> 'Event':
If one variable is only in one of the events, the union is the corresponding element.
"""
result = Event()

self.check_same_type(other)

result = self.__class__()

variables = set(self.keys()) | set(other.keys())

Expand Down Expand Up @@ -171,7 +186,10 @@ def difference(self, other: 'Event') -> 'Event':
If a variable appears only in `other`, it is assumed that `self` has the entire domain as default value.
"""
result = Event()

self.check_same_type(other)

result = self.__class__()

variables = set(self.keys()) | set(other.keys())

Expand Down Expand Up @@ -219,7 +237,7 @@ def complement(self) -> 'Event':
"""
Get the complement of this event.
"""
return Event() - self
return self.__class__() - self

__invert__ = complement
"""Alias for complement."""
Expand All @@ -231,6 +249,7 @@ def __eq__(self, other: 'Event') -> bool:
If one variable is only in one of the events, it is assumed that the other event has the entire domain as
default value.
"""

variables = set(self.keys()) | set(other.keys())

equal = True
Expand Down Expand Up @@ -318,8 +337,15 @@ class EncodedEvent(Event):
@staticmethod
def check_element(variable: Variable, element: Any) -> Union[tuple, portion.Interval]:

# if the variable is continuous, don't process the element
# if the variable is continuous
if isinstance(variable, Continuous):

# if it's not an interval
if not isinstance(element, portion.Interval):

# try to convert it to one
element = portion.singleton(element)

return element

# if its any kind of iterable that's not an interval convert it to a tuple
Expand Down
19 changes: 16 additions & 3 deletions test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,9 @@ def test_encode(self):
"""
Test that events are correctly encoded.
"""
print(self.event)
encoded = self.event.encode()
print(encoded)
self.assertIsInstance(encoded, EncodedEvent)
decoded = encoded.decode()
print(decoded)
self.assertEqual(self.event, decoded)

def test_intersection(self):
Expand Down Expand Up @@ -187,6 +184,16 @@ def test_equality(self):
self.assertEqual(self.event, self.event)
self.assertNotEqual(self.event, Event())

def test_raises_on_operation_with_different_types(self):
with self.assertRaises(TypeError):
self.event & self.event.encode()

with self.assertRaises(TypeError):
self.event | self.event.encode()

with self.assertRaises(TypeError):
self.event - self.event.encode()


class EncodedEventTestCase(unittest.TestCase):

Expand Down Expand Up @@ -238,6 +245,12 @@ def test_dict_like_creation(self):
self.assertEqual(event[self.integer], (0, 1))
self.assertEqual(event[self.symbol], (0,))

def test_set_operations_return_type(self):
event = EncodedEvent(zip([self.integer, self.symbol], [1, 0]))
self.assertEqual(type(event & event), EncodedEvent)
self.assertEqual(type(event | event), EncodedEvent)
self.assertEqual(type(event - event), EncodedEvent)


if __name__ == '__main__':
unittest.main()

0 comments on commit 12e2512

Please sign in to comment.