From 9a7da6f2bd31d60cb9265d954ce97097aafbd3bd Mon Sep 17 00:00:00 2001 From: Roman Andriushchenko Date: Thu, 14 Dec 2023 13:53:13 +0100 Subject: [PATCH] delegate synthesizer selection --- paynt/cli.py | 41 ++++------------------- paynt/synthesizer/policy_tree.py | 2 +- paynt/synthesizer/synthesizer.py | 38 +++++++++++++++++++-- paynt/synthesizer/synthesizer_onebyone.py | 4 +-- 4 files changed, 46 insertions(+), 39 deletions(-) diff --git a/paynt/cli.py b/paynt/cli.py index f6a1f8164..7a25be8d0 100644 --- a/paynt/cli.py +++ b/paynt/cli.py @@ -4,18 +4,10 @@ import paynt.quotient import paynt.quotient.pomdp -import paynt.quotient.mdp_family -import paynt.synthesizer.policy_tree +import paynt.quotient.storm_pomdp_control -from .synthesizer.synthesizer import Synthesizer -from .synthesizer.synthesizer_onebyone import SynthesizerOneByOne -from .synthesizer.synthesizer_ar import SynthesizerAR -from .synthesizer.synthesizer_cegis import SynthesizerCEGIS -from .synthesizer.synthesizer_hybrid import SynthesizerHybrid -from .synthesizer.synthesizer_pomdp import SynthesizerPOMDP -from .synthesizer.synthesizer_multicore_ar import SynthesizerMultiCoreAR - -from .quotient.storm_pomdp_control import StormPOMDPControl +import paynt.synthesizer.synthesizer +import paynt.synthesizer.synthesizer_cegis import click import sys @@ -141,9 +133,9 @@ def paynt_run( logger.info("This is Paynt version {}.".format(version())) # set CLI parameters - Synthesizer.incomplete_search = incomplete_search + paynt.synthesizer.synthesizer.Synthesizer.incomplete_search = incomplete_search paynt.quotient.quotient.Quotient.compute_expected_visits = not disable_expected_visits - SynthesizerCEGIS.conflict_generator_type = ce_generator + paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS.conflict_generator_type = ce_generator paynt.quotient.pomdp.PomdpQuotient.initial_memory_size = pomdp_memory_size paynt.quotient.pomdp.PomdpQuotient.export_optimal_result = fsc_export_result paynt.quotient.pomdp.PomdpQuotient.posterior_aware = posterior_aware @@ -157,7 +149,7 @@ def paynt_run( storm_control = None if storm_pomdp: - storm_control = StormPOMDPControl() + storm_control = paynt.quotient.storm_pomdp_control.StormPOMDPControl() storm_control.storm_options = storm_options if get_storm_result is not None: storm_control.get_result = get_storm_result @@ -168,27 +160,8 @@ def paynt_run( storm_control.export_fsc_storm = export_fsc_storm storm_control.export_fsc_paynt = export_fsc_paynt - if isinstance(quotient, paynt.quotient.pomdp_family.PomdpFamilyQuotient): - logger.info("nothing to do with the POMDP sketch, aborting...") - exit(0) - # choose the synthesis method and run the corresponding synthesizer - if isinstance(quotient, paynt.quotient.pomdp.PomdpQuotient) and fsc_synthesis: - synthesizer = SynthesizerPOMDP(quotient, method, storm_control) - elif isinstance(quotient, paynt.quotient.mdp_family.MdpFamilyQuotient): - synthesizer = paynt.synthesizer.policy_tree.SynthesizerPolicyTree(quotient) - elif method == "onebyone": - synthesizer = SynthesizerOneByOne(quotient) - elif method == "ar": - synthesizer = SynthesizerAR(quotient) - elif method == "cegis": - synthesizer = SynthesizerCEGIS(quotient) - elif method == "hybrid": - synthesizer = SynthesizerHybrid(quotient) - elif method == "ar_multicore": - synthesizer = SynthesizerMultiCoreAR(quotient) - else: - pass + synthesizer = paynt.synthesizer.synthesizer.Synthesizer.choose_synthesizer(quotient, method, fsc_synthesis, storm_control) if storm_pomdp: if prune_storm: diff --git a/paynt/synthesizer/policy_tree.py b/paynt/synthesizer/policy_tree.py index a6baff8bc..babbb4a78 100644 --- a/paynt/synthesizer/policy_tree.py +++ b/paynt/synthesizer/policy_tree.py @@ -738,7 +738,7 @@ def synthesize_policy_tree(self, family): prop = self.quotient.get_property() game_solver = self.quotient.build_game_abstraction_solver(prop) policy_tree = PolicyTree(family) - self.create_action_coloring() + # self.create_action_coloring() if False: self.quotient.build(policy_tree.root.family) diff --git a/paynt/synthesizer/synthesizer.py b/paynt/synthesizer/synthesizer.py index d6218320c..1cb6d35fb 100644 --- a/paynt/synthesizer/synthesizer.py +++ b/paynt/synthesizer/synthesizer.py @@ -1,4 +1,4 @@ -from .statistic import Statistic +import paynt.synthesizer.statistic import logging logger = logging.getLogger(__name__) @@ -8,10 +8,44 @@ class Synthesizer: # if True, some subfamilies can be discarded and some holes can be generalized incomplete_search = False + + @staticmethod + def choose_synthesizer(quotient, method, fsc_synthesis, storm_control): + + # hiding imports here to avoid mutual top-level imports + import paynt.quotient.pomdp + import paynt.quotient.mdp_family + import paynt.synthesizer.synthesizer_onebyone + import paynt.synthesizer.synthesizer_ar + import paynt.synthesizer.synthesizer_cegis + import paynt.synthesizer.synthesizer_hybrid + import paynt.synthesizer.synthesizer_multicore_ar + import paynt.synthesizer.synthesizer_pomdp + import paynt.synthesizer.policy_tree + + if isinstance(quotient, paynt.quotient.pomdp_family.PomdpFamilyQuotient): + logger.info("nothing to do with the POMDP sketch, aborting...") + exit(0) + if isinstance(quotient, paynt.quotient.pomdp.PomdpQuotient) and fsc_synthesis: + return paynt.synthesizer.synthesizer_pomdp.SynthesizerPOMDP(quotient, method, storm_control) + if isinstance(quotient, paynt.quotient.mdp_family.MdpFamilyQuotient): + return paynt.synthesizer.policy_tree.SynthesizerPolicyTree(quotient) + if method == "onebyone": + return paynt.synthesizer.synthesizer_onebyone.SynthesizerOneByOne(quotient) + if method == "ar": + return paynt.synthesizer.synthesizer_ar.SynthesizerAR(quotient) + if method == "cegis": + return paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS(quotient) + if method == "hybrid": + return paynt.synthesizer.synthesizer_hybrid.SynthesizerHybrid(quotient) + if method == "ar_multicore": + return paynt.synthesizer.synthesizer_multicore_ar.SynthesizerMultiCoreAR(quotient) + raise ValueError("invalid method name") + def __init__(self, quotient): self.quotient = quotient - self.stat = Statistic(self) + self.stat = paynt.synthesizer.statistic.Statistic(self) self.explored = 0 @property diff --git a/paynt/synthesizer/synthesizer_onebyone.py b/paynt/synthesizer/synthesizer_onebyone.py index 47359e6b8..acd43c88c 100644 --- a/paynt/synthesizer/synthesizer_onebyone.py +++ b/paynt/synthesizer/synthesizer_onebyone.py @@ -1,10 +1,10 @@ -from .synthesizer import Synthesizer +import paynt.synthesizer.synthesizer import logging logger = logging.getLogger(__name__) -class SynthesizerOneByOne(Synthesizer): +class SynthesizerOneByOne(paynt.synthesizer.synthesizer.Synthesizer): @property def method_name(self):