Skip to content

Commit

Permalink
Add bindings for the Smg class (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGreatfpmK authored Jul 14, 2024
1 parent 7dde9a8 commit b6413db
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 2 deletions.
15 changes: 14 additions & 1 deletion src/storage/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "storm/models/sparse/Pomdp.h"
#include "storm/models/sparse/Ctmc.h"
#include "storm/models/sparse/MarkovAutomaton.h"
#include "storm/models/sparse/Smg.h"
#include "storm/models/sparse/StandardRewardModel.h"
#include "storm/models/symbolic/Model.h"
#include "storm/models/symbolic/Dtmc.h"
Expand Down Expand Up @@ -39,6 +40,7 @@ template<typename ValueType> using SparseMdp = storm::models::sparse::Mdp<ValueT
template<typename ValueType> using SparsePomdp = storm::models::sparse::Pomdp<ValueType>;
template<typename ValueType> using SparseCtmc = storm::models::sparse::Ctmc<ValueType>;
template<typename ValueType> using SparseMarkovAutomaton = storm::models::sparse::MarkovAutomaton<ValueType>;
template<typename ValueType> using SparseSmg = storm::models::sparse::Smg<ValueType>;
template<typename ValueType> using SparseRewardModel = storm::models::sparse::StandardRewardModel<ValueType>;

template<storm::dd::DdType DdType, typename ValueType> using SymbolicModel = storm::models::symbolic::Model<DdType, ValueType>;
Expand Down Expand Up @@ -114,6 +116,7 @@ void define_model(py::module& m) {
.value("POMDP", storm::models::ModelType::Pomdp)
.value("CTMC", storm::models::ModelType::Ctmc)
.value("MA", storm::models::ModelType::MarkovAutomaton)
.value("SMG", storm::models::ModelType::Smg)
;

// ModelBase
Expand Down Expand Up @@ -169,7 +172,10 @@ void define_model(py::module& m) {
.def("_as_sparse_pma", [](ModelBase &modelbase) {
return modelbase.as<SparseMarkovAutomaton<RationalFunction>>();
}, "Get model as sparse pMA")
.def("_as_symbolic_dtmc", [](ModelBase &modelbase) {
.def("_as_sparse_smg", [](ModelBase &modelbase) {
return modelbase.as<SparseSmg<double>>();
}, "Get model as sparse SMG")
.def("_as_symbolic_dtmc", [](ModelBase &modelbase) {
return modelbase.as<SymbolicDtmc<storm::dd::DdType::Sylvan, double>>();
}, "Get model as symbolic DTMC")
.def("_as_symbolic_pdtmc", [](ModelBase &modelbase) {
Expand Down Expand Up @@ -277,6 +283,13 @@ void define_sparse_model(py::module& m, std::string const& vtSuffix) {
.def("convert_to_ctmc", &SparseMarkovAutomaton<ValueType>::convertToCtmc, "Convert the MA into a CTMC.")
;

py::class_<SparseSmg<ValueType>, std::shared_ptr<SparseSmg<ValueType>>>(m, ("Sparse" + vtSuffix + "Smg").c_str(), "SMG in sparse representation", nondetModel)
.def(py::init<SparseSmg<ValueType>>(), py::arg("other_model"))
.def(py::init<ModelComponents<ValueType> const&>(), py::arg("components"))
.def("get_state_player_indications", &SparseSmg<ValueType>::getStatePlayerIndications, "Get for each state its corresponding player")
.def("get_player_of_state", &SparseSmg<ValueType>::getPlayerOfState, py::arg("state"), "Get player for the given state")
;

py::class_<SparseRewardModel<ValueType>>(m, ("Sparse" + vtSuffix + "RewardModel").c_str(), "Reward structure for sparse models")
.def(py::init<std::optional<std::vector<ValueType>> const&, std::optional<std::vector<ValueType>> const&,
std::optional<storm::storage::SparseMatrix<ValueType>> const&>(), py::arg("optional_state_reward_vector") = std::nullopt,
Expand Down
5 changes: 4 additions & 1 deletion src/storage/model_components.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ void define_sparse_model_components(py::module& m, std::string const& vtSuffix)
.def_readwrite("markovian_states", &SparseModelComponents<ValueType>::markovianStates, "A list that stores which states are Markovian (only for Markov Automata)")

// Stochastic two player game specific components:
.def_readwrite("player1_matrix", &SparseModelComponents<ValueType>::observabilityClasses, "Matrix of player 1 choices (needed for stochastic two player games")
.def_readwrite("player1_matrix", &SparseModelComponents<ValueType>::player1Matrix, "Matrix of player 1 choices (needed for stochastic two player games")

// Stochastic multiplayer game specific components:
.def_readwrite("state_player_indications", &SparseModelComponents<ValueType>::statePlayerIndications, "The vector mapping states to player indices")
;

}
Expand Down
94 changes: 94 additions & 0 deletions tests/storage/test_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,3 +875,97 @@ def create_number(num):

# Test choice labeling
assert not dtmc.has_choice_labeling()


def test_build_smg(self):
nr_states = 7
nr_choices = 10

# Build transition matrix
builder = stormpy.SparseMatrixBuilder(rows=0, columns=0, entries=0, force_dimensions=False,
has_custom_row_grouping=True, row_groups=0)

# Row group, state 0
builder.new_row_group(0)
builder.add_next_value(0, 0, 0.5)
builder.add_next_value(0, 1, 0.5)
builder.add_next_value(1, 1, 0.2)
builder.add_next_value(1, 2, 0.8)
# Row group, state 1
builder.new_row_group(2)
builder.add_next_value(2, 2, 1)
builder.add_next_value(3, 3, 0.5)
builder.add_next_value(3, 6, 0.5)
# Row group, state 2
builder.new_row_group(4)
builder.add_next_value(4, 3, 0.5)
builder.add_next_value(4, 4, 0.5)
builder.add_next_value(5, 0, 1)
# Row group, state 3
builder.new_row_group(6)
builder.add_next_value(6, 5, 1)
# Row group, state 4
builder.new_row_group(7)
builder.add_next_value(7, 5, 1)
# Row group, state 5, goal state
builder.new_row_group(8)
builder.add_next_value(8, 5, 1)
# Row group, state 6, deadlock state
builder.new_row_group(9)
builder.add_next_value(9, 6, 1)

transition_matrix = builder.build(nr_choices, nr_states)

# State labeling
state_labeling = stormpy.storage.StateLabeling(nr_states)
labels = {'init_p1', 'bad', 'done'}
for label in labels:
state_labeling.add_label(label)
state_labeling.add_label_to_state('init_p1', 0)
state_labeling.add_label_to_state('done', 5)
state_labeling.add_label_to_state('bad', 6)

# Choice labeling
choice_labeling = stormpy.storage.ChoiceLabeling(nr_choices)
choice_labels = {'a', 'b', 'c', 'd'}
for label in choice_labels:
choice_labeling.add_label(label)
choice_labeling.add_label_to_choice('a', 0)
choice_labeling.add_label_to_choice('b', 1)
choice_labeling.add_label_to_choice('c', 2)
choice_labeling.add_label_to_choice('d', 3)
choice_labeling.add_label_to_choice('a', 4)
choice_labeling.add_label_to_choice('b', 5)

# State player indications
state_player_indications = [0, 1, 2, 0, 2, 1, 0]

components = stormpy.SparseModelComponents(transition_matrix=transition_matrix, state_labeling=state_labeling)
components.choice_labeling = choice_labeling
components.state_player_indications = state_player_indications

# Build SMG
smg = stormpy.storage.SparseSmg(components)

assert type(smg) is stormpy.SparseSmg
assert smg.model_type == stormpy.storage.ModelType.SMG

# Test transition matrix
assert smg.nr_choices == nr_choices
assert smg.nr_states == nr_states
assert smg.nr_transitions == 14
assert smg.transition_matrix.nr_entries == smg.nr_transitions
for state in smg.states:
assert len(state.actions) <= 2

# Test state labeling
assert smg.labeling.get_labels() == {'init_p1', 'bad', 'done'}

# Test choice labeling
assert smg.has_choice_labeling()
assert smg.choice_labeling.get_labels() == {'a', 'b', 'c', 'd'}

# Test state player indications
assert smg.get_state_player_indications() == [0, 1, 2, 0, 2, 1, 0]
for state in range(smg.nr_states):
assert smg.get_player_of_state(state) == state_player_indications[state]

0 comments on commit b6413db

Please sign in to comment.