Skip to content

Commit

Permalink
delegate synthesizer selection
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Dec 14, 2023
1 parent bfa5b90 commit 9a7da6f
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 39 deletions.
41 changes: 7 additions & 34 deletions paynt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,10 @@

import paynt.quotient
import paynt.quotient.pomdp
import paynt.quotient.mdp_family
import paynt.synthesizer.policy_tree
import paynt.quotient.storm_pomdp_control

from .synthesizer.synthesizer import Synthesizer
from .synthesizer.synthesizer_onebyone import SynthesizerOneByOne
from .synthesizer.synthesizer_ar import SynthesizerAR
from .synthesizer.synthesizer_cegis import SynthesizerCEGIS
from .synthesizer.synthesizer_hybrid import SynthesizerHybrid
from .synthesizer.synthesizer_pomdp import SynthesizerPOMDP
from .synthesizer.synthesizer_multicore_ar import SynthesizerMultiCoreAR

from .quotient.storm_pomdp_control import StormPOMDPControl
import paynt.synthesizer.synthesizer
import paynt.synthesizer.synthesizer_cegis

import click
import sys
Expand Down Expand Up @@ -141,9 +133,9 @@ def paynt_run(
logger.info("This is Paynt version {}.".format(version()))

# set CLI parameters
Synthesizer.incomplete_search = incomplete_search
paynt.synthesizer.synthesizer.Synthesizer.incomplete_search = incomplete_search
paynt.quotient.quotient.Quotient.compute_expected_visits = not disable_expected_visits
SynthesizerCEGIS.conflict_generator_type = ce_generator
paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS.conflict_generator_type = ce_generator
paynt.quotient.pomdp.PomdpQuotient.initial_memory_size = pomdp_memory_size
paynt.quotient.pomdp.PomdpQuotient.export_optimal_result = fsc_export_result
paynt.quotient.pomdp.PomdpQuotient.posterior_aware = posterior_aware
Expand All @@ -157,7 +149,7 @@ def paynt_run(

storm_control = None
if storm_pomdp:
storm_control = StormPOMDPControl()
storm_control = paynt.quotient.storm_pomdp_control.StormPOMDPControl()
storm_control.storm_options = storm_options
if get_storm_result is not None:
storm_control.get_result = get_storm_result
Expand All @@ -168,27 +160,8 @@ def paynt_run(
storm_control.export_fsc_storm = export_fsc_storm
storm_control.export_fsc_paynt = export_fsc_paynt

if isinstance(quotient, paynt.quotient.pomdp_family.PomdpFamilyQuotient):
logger.info("nothing to do with the POMDP sketch, aborting...")
exit(0)

# choose the synthesis method and run the corresponding synthesizer
if isinstance(quotient, paynt.quotient.pomdp.PomdpQuotient) and fsc_synthesis:
synthesizer = SynthesizerPOMDP(quotient, method, storm_control)
elif isinstance(quotient, paynt.quotient.mdp_family.MdpFamilyQuotient):
synthesizer = paynt.synthesizer.policy_tree.SynthesizerPolicyTree(quotient)
elif method == "onebyone":
synthesizer = SynthesizerOneByOne(quotient)
elif method == "ar":
synthesizer = SynthesizerAR(quotient)
elif method == "cegis":
synthesizer = SynthesizerCEGIS(quotient)
elif method == "hybrid":
synthesizer = SynthesizerHybrid(quotient)
elif method == "ar_multicore":
synthesizer = SynthesizerMultiCoreAR(quotient)
else:
pass
synthesizer = paynt.synthesizer.synthesizer.Synthesizer.choose_synthesizer(quotient, method, fsc_synthesis, storm_control)

if storm_pomdp:
if prune_storm:
Expand Down
2 changes: 1 addition & 1 deletion paynt/synthesizer/policy_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ def synthesize_policy_tree(self, family):
prop = self.quotient.get_property()
game_solver = self.quotient.build_game_abstraction_solver(prop)
policy_tree = PolicyTree(family)
self.create_action_coloring()
# self.create_action_coloring()

if False:
self.quotient.build(policy_tree.root.family)
Expand Down
38 changes: 36 additions & 2 deletions paynt/synthesizer/synthesizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .statistic import Statistic
import paynt.synthesizer.statistic

import logging
logger = logging.getLogger(__name__)
Expand All @@ -8,10 +8,44 @@ class Synthesizer:

# if True, some subfamilies can be discarded and some holes can be generalized
incomplete_search = False

@staticmethod
def choose_synthesizer(quotient, method, fsc_synthesis, storm_control):

# hiding imports here to avoid mutual top-level imports
import paynt.quotient.pomdp
import paynt.quotient.mdp_family
import paynt.synthesizer.synthesizer_onebyone
import paynt.synthesizer.synthesizer_ar
import paynt.synthesizer.synthesizer_cegis
import paynt.synthesizer.synthesizer_hybrid
import paynt.synthesizer.synthesizer_multicore_ar
import paynt.synthesizer.synthesizer_pomdp
import paynt.synthesizer.policy_tree

if isinstance(quotient, paynt.quotient.pomdp_family.PomdpFamilyQuotient):
logger.info("nothing to do with the POMDP sketch, aborting...")
exit(0)
if isinstance(quotient, paynt.quotient.pomdp.PomdpQuotient) and fsc_synthesis:
return paynt.synthesizer.synthesizer_pomdp.SynthesizerPOMDP(quotient, method, storm_control)
if isinstance(quotient, paynt.quotient.mdp_family.MdpFamilyQuotient):
return paynt.synthesizer.policy_tree.SynthesizerPolicyTree(quotient)
if method == "onebyone":
return paynt.synthesizer.synthesizer_onebyone.SynthesizerOneByOne(quotient)
if method == "ar":
return paynt.synthesizer.synthesizer_ar.SynthesizerAR(quotient)
if method == "cegis":
return paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS(quotient)
if method == "hybrid":
return paynt.synthesizer.synthesizer_hybrid.SynthesizerHybrid(quotient)
if method == "ar_multicore":
return paynt.synthesizer.synthesizer_multicore_ar.SynthesizerMultiCoreAR(quotient)
raise ValueError("invalid method name")


def __init__(self, quotient):
self.quotient = quotient
self.stat = Statistic(self)
self.stat = paynt.synthesizer.statistic.Statistic(self)
self.explored = 0

@property
Expand Down
4 changes: 2 additions & 2 deletions paynt/synthesizer/synthesizer_onebyone.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .synthesizer import Synthesizer
import paynt.synthesizer.synthesizer

import logging
logger = logging.getLogger(__name__)


class SynthesizerOneByOne(Synthesizer):
class SynthesizerOneByOne(paynt.synthesizer.synthesizer.Synthesizer):

@property
def method_name(self):
Expand Down

0 comments on commit 9a7da6f

Please sign in to comment.