Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Otu update #190

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
9 changes: 6 additions & 3 deletions lib/stormpy/pomdp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,18 @@ def create_nondeterminstic_belief_tracker(model, reduction_timeout, track_timeou
return pomdp.NondeterministicBeliefTrackerDoubleSparse(model, opts)


def create_observation_trace_unfolder(model, risk_assessment, expr_manager):
def create_observation_trace_unfolder(model, risk_assessment, expr_manager, rejection_sampling = True):
"""

:param model:
:param risk_assessment:
:param expr_manager:
:param rejection_sampling:
:return:
"""
options = pomdp.ObservationTraceUnfolderOptions()
options.rejection_sampling = rejection_sampling
if model.is_exact:
return pomdp.ObservationTraceUnfolderExact(model, risk_assessment, expr_manager)
return pomdp.ObservationTraceUnfolderExact(model, risk_assessment, expr_manager, options)
else:
return pomdp.ObservationTraceUnfolderDouble(model, risk_assessment, expr_manager)
return pomdp.ObservationTraceUnfolderDouble(model, risk_assessment, expr_manager, options)
19 changes: 13 additions & 6 deletions src/pomdp/transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ std::shared_ptr<storm::models::sparse::Model<storm::RationalFunction>> apply_unk
}

template<typename ValueType>
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> unfold_trace(storm::models::sparse::Pomdp<ValueType> const& pomdp, std::shared_ptr<storm::expressions::ExpressionManager>& exprManager, std::vector<uint32_t> const& observationTrace, std::vector<ValueType> const& riskDef ) {
storm::pomdp::ObservationTraceUnfolder<ValueType> transformer(pomdp, exprManager);
return transformer.transform(observationTrace, riskDef);
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> unfold_trace(storm::models::sparse::Pomdp<ValueType> const& pomdp, std::shared_ptr<storm::expressions::ExpressionManager>& exprManager, std::vector<uint32_t> const& observationTrace, std::vector<ValueType> const& riskDef, bool rejectionSampling=true) {
storm::pomdp::ObservationTraceUnfolderOptions options = storm::pomdp::ObservationTraceUnfolderOptions();
options.rejectionSampling = rejectionSampling;
storm::pomdp::ObservationTraceUnfolder<ValueType> transformer(pomdp, riskDef, exprManager, options);
return transformer.transform(observationTrace);
}

// STANDARD, SIMPLE_LINEAR, SIMPLE_LINEAR_INVERSE, SIMPLE_LOG, FULL
Expand All @@ -47,6 +49,11 @@ void define_transformations_nt(py::module &m) {
.value("full", storm::transformer::PomdpFscApplicationMode::FULL)
;

py::class_<storm::pomdp::ObservationTraceUnfolderOptions> unfolderOptions(m, "ObservationTraceUnfolderOptions", "Options for the ObservationTraceUnfolder");
unfolderOptions.def(py::init<>());
unfolderOptions.def_readwrite("rejection_sampling", &storm::pomdp::ObservationTraceUnfolderOptions::rejectionSampling);


}

template<typename ValueType>
Expand All @@ -55,12 +62,12 @@ void define_transformations(py::module& m, std::string const& vtSuffix) {
m.def(("_unfold_memory_" + vtSuffix).c_str(), &unfold_memory<ValueType>, "Unfold memory into a POMDP", py::arg("pomdp"), py::arg("memorystructure"), py::arg("memorylabels") = false, py::arg("keep_state_valuations")=false);
m.def(("_make_simple_"+ vtSuffix).c_str(), &make_simple<ValueType>, "Make POMDP simple", py::arg("pomdp"), py::arg("keep_state_valuations")=false);
m.def(("_apply_unknown_fsc_" + vtSuffix).c_str(), &apply_unknown_fsc<ValueType>, "Apply unknown FSC",py::arg("pomdp"), py::arg("application_mode")=storm::transformer::PomdpFscApplicationMode::SIMPLE_LINEAR);
//m.def(("_unfold_trace_" + vtSuffix).c_str(), &unfold_trace<ValueType>, "Unfold observed trace", py::arg("pomdp"), py::arg("expression_manager"),py::arg("observation_trace"), py::arg("risk_definition"));

py::class_<storm::pomdp::ObservationTraceUnfolder<ValueType>> unfolder(m, ("ObservationTraceUnfolder" + vtSuffix).c_str(), "Unfolds observation traces in models");
unfolder.def(py::init<storm::models::sparse::Pomdp<ValueType> const&, std::vector<ValueType> const&, std::shared_ptr<storm::expressions::ExpressionManager>&>(), py::arg("model"), py::arg("risk"), py::arg("expression_manager"));
unfolder.def(py::init<storm::models::sparse::Pomdp<ValueType> const&, std::vector<ValueType> const&, std::shared_ptr<storm::expressions::ExpressionManager>&, storm::pomdp::ObservationTraceUnfolderOptions const&>(), py::arg("model"), py::arg("risk"), py::arg("expression_manager"), py::arg("options"));
unfolder.def("is_rejection_sampling_set", &storm::pomdp::ObservationTraceUnfolder<ValueType>::isRejectionSamplingSet);
unfolder.def("transform", &storm::pomdp::ObservationTraceUnfolder<ValueType>::transform, py::arg("trace"));
unfolder.def("extend", &storm::pomdp::ObservationTraceUnfolder<ValueType>::extend, py::arg("new_observation"));
unfolder.def("extend", &storm::pomdp::ObservationTraceUnfolder<ValueType>::extend, py::arg("new_observations"));
unfolder.def("reset", &storm::pomdp::ObservationTraceUnfolder<ValueType>::reset, py::arg("new_observation"));
}

Expand Down