Skip to content

Commit

Permalink
Added behaviour for correct subclass initialization from json.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Nov 13, 2023
1 parent 12e2512 commit b2d6591
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.1.4'
__version__ = '1.1.5'
16 changes: 16 additions & 0 deletions src/random_events/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def get_full_class_name(cls):
"""
Returns the full name of a class, including the module name.
:param cls: The class.
:return: The full name of the class
"""
return cls.__module__ + "." + cls.__name__


def recursive_subclasses(cls):
"""
:param cls: The class.
:return: A list of the classes subclasses.
"""
return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in recursive_subclasses(s)]
24 changes: 23 additions & 1 deletion src/random_events/variables.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from typing import Any, Union, Iterable
from typing import Any, Union, Iterable, Dict

import portion
import pydantic

from . import utils


class Variable(pydantic.BaseModel):
"""
Expand All @@ -20,8 +22,14 @@ class Variable(pydantic.BaseModel):
The set of possible events of the variable.
"""

type: str = pydantic.Field(repr=False, init_var=False, default=None)
"""
The type of the variable. This is used for de-serialization and set automatically in the constructor.
"""

def __init__(self, name: str, domain: Any):
super().__init__(name=name, domain=domain)
self.type = utils.get_full_class_name(self.__class__)

def __lt__(self, other: "Variable") -> bool:
"""
Expand Down Expand Up @@ -74,6 +82,20 @@ def decode_many(self, elements: Iterable) -> Iterable[Any]:
"""
return elements

@staticmethod
def from_json(data: Dict[str, Any]) -> 'Variable':
"""
Create the correct instanceof the subclass from a json dict.
:param data: The json dict
:return: The correct instance of the subclass
"""
for subclass in utils.recursive_subclasses(Variable):
if utils.get_full_class_name(subclass) == data["type"]:
return subclass(**{key: value for key, value in data.items() if key != "type"})

raise ValueError("Unknown type for variable. Type is {}".format(data["type"]))


class Continuous(Variable):
"""
Expand Down
32 changes: 18 additions & 14 deletions test/test_variables.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import unittest

import portion

from random_events.variables import Integer, Symbolic, Continuous
from random_events.variables import Integer, Symbolic, Continuous, Variable


class VariablesTestCase(unittest.TestCase):
Expand Down Expand Up @@ -66,19 +67,6 @@ def test_to_json(self):
self.assertTrue(self.integer.model_dump_json())
self.assertTrue(self.real.model_dump_json())

def test_from_json(self):
"""
Test that the variables can be loaded from json.
"""
real = Continuous.model_validate_json(self.real.model_dump_json())
self.assertEqual(real, self.real)

integer = Integer.model_validate_json(self.integer.model_dump_json())
self.assertEqual(integer, self.integer)

symbol = Symbolic.model_validate_json(self.symbol.model_dump_json())
self.assertEqual(symbol, self.symbol)

def test_encode(self):
"""
Test that the variables can be encoded.
Expand All @@ -95,6 +83,22 @@ def test_decode(self):
self.assertEqual(self.symbol.decode(1), "b")
self.assertEqual(self.real.decode(1.0), 1.0)

def test_type_setting(self):
self.assertEqual(self.real.type, "random_events.variables.Continuous")
self.assertEqual(self.integer.type, "random_events.variables.Integer")
self.assertEqual(self.symbol.type, "random_events.variables.Symbolic")

def test_polymorphic_serialization(self):
real = Variable.from_json(json.loads(self.real.model_dump_json()))
self.assertEqual(real, self.real)

integer = Variable.from_json(json.loads(self.integer.model_dump_json()))
print(integer)
self.assertEqual(integer, self.integer)

symbol = Variable.from_json(json.loads(self.symbol.model_dump_json()))
self.assertEqual(symbol, self.symbol)


if __name__ == '__main__':
unittest.main()

0 comments on commit b2d6591

Please sign in to comment.