diff --git a/src/storage/model.cpp b/src/storage/model.cpp index bfdb8c74e..f2fbb8430 100644 --- a/src/storage/model.cpp +++ b/src/storage/model.cpp @@ -283,7 +283,7 @@ void define_sparse_model(py::module& m, std::string const& vtSuffix) { .def("convert_to_ctmc", &SparseMarkovAutomaton::convertToCtmc, "Convert the MA into a CTMC.") ; - py::class_, std::shared_ptr>>(m, ("Sparse" + vtSuffix + "SMG").c_str(), "SMG in sparse representation", nondetModel) + py::class_, std::shared_ptr>>(m, ("Sparse" + vtSuffix + "Smg").c_str(), "SMG in sparse representation", nondetModel) .def(py::init>(), py::arg("other_model")) .def(py::init const&>(), py::arg("components")) .def("get_state_player_indications", &SparseSmg::getStatePlayerIndications, "Get for each state its corresponding player") diff --git a/tests/storage/test_model_components.py b/tests/storage/test_model_components.py index 2546f3cf1..d8d537a09 100644 --- a/tests/storage/test_model_components.py +++ b/tests/storage/test_model_components.py @@ -875,3 +875,97 @@ def create_number(num): # Test choice labeling assert not dtmc.has_choice_labeling() + + + def test_build_smg(self): + nr_states = 7 + nr_choices = 10 + + # Build transition matrix + builder = stormpy.SparseMatrixBuilder(rows=0, columns=0, entries=0, force_dimensions=False, + has_custom_row_grouping=True, row_groups=0) + + # Row group, state 0 + builder.new_row_group(0) + builder.add_next_value(0, 0, 0.5) + builder.add_next_value(0, 1, 0.5) + builder.add_next_value(1, 1, 0.2) + builder.add_next_value(1, 2, 0.8) + # Row group, state 1 + builder.new_row_group(2) + builder.add_next_value(2, 2, 1) + builder.add_next_value(3, 3, 0.5) + builder.add_next_value(3, 6, 0.5) + # Row group, state 2 + builder.new_row_group(4) + builder.add_next_value(4, 3, 0.5) + builder.add_next_value(4, 4, 0.5) + builder.add_next_value(5, 0, 1) + # Row group, state 3 + builder.new_row_group(6) + builder.add_next_value(6, 5, 1) + # Row group, state 4 + builder.new_row_group(7) + builder.add_next_value(7, 5, 1) + # Row group, state 5, goal state + builder.new_row_group(8) + builder.add_next_value(8, 5, 1) + # Row group, state 6, deadlock state + builder.new_row_group(9) + builder.add_next_value(9, 6, 1) + + transition_matrix = builder.build(nr_choices, nr_states) + + # State labeling + state_labeling = stormpy.storage.StateLabeling(nr_states) + labels = {'init_p1', 'bad', 'done'} + for label in labels: + state_labeling.add_label(label) + state_labeling.add_label_to_state('init_p1', 0) + state_labeling.add_label_to_state('done', 5) + state_labeling.add_label_to_state('bad', 6) + + # Choice labeling + choice_labeling = stormpy.storage.ChoiceLabeling(nr_choices) + choice_labels = {'a', 'b', 'c', 'd'} + for label in choice_labels: + choice_labeling.add_label(label) + choice_labeling.add_label_to_choice('a', 0) + choice_labeling.add_label_to_choice('b', 1) + choice_labeling.add_label_to_choice('c', 2) + choice_labeling.add_label_to_choice('d', 3) + choice_labeling.add_label_to_choice('a', 4) + choice_labeling.add_label_to_choice('b', 5) + + # State player indications + state_player_indications = [0, 1, 2, 0, 2, 1, 0] + + components = stormpy.SparseModelComponents(transition_matrix=transition_matrix, state_labeling=state_labeling) + components.choice_labeling = choice_labeling + components.state_player_indications = state_player_indications + + # Build SMG + smg = stormpy.storage.SparseSmg(components) + + assert type(smg) is stormpy.SparseSmg + assert smg.model_type == stormpy.storage.ModelType.SMG + + # Test transition matrix + assert smg.nr_choices == nr_choices + assert smg.nr_states == nr_states + assert smg.nr_transitions == 14 + assert smg.transition_matrix.nr_entries == smg.nr_transitions + for state in smg.states: + assert len(state.actions) <= 2 + + # Test state labeling + assert smg.labeling.get_labels() == {'init_p1', 'bad', 'done'} + + # Test choice labeling + assert smg.has_choice_labeling() + assert smg.choice_labeling.get_labels() == {'a', 'b', 'c', 'd'} + + # Test state player indications + assert smg.get_state_player_indications() == [0, 1, 2, 0, 2, 1, 0] + for state in range(smg.nr_states): + assert smg.get_player_of_state(state) == state_player_indications[state]