Skip to content

Commit

Permalink
Allow including lemmas dynamically into APRProver (#4681)
Browse files Browse the repository at this point in the history
In order to include lemmas dynamically in the Kontrol prover, we need an
option for passing the extra module into the `APRProver` so that it can
be fed into the booster's `add-module` endpoint. This also requires
being able to translate functional rules to Kore in pyk using
`krule_to_kore`. This PR:

- Adds the `extra_module` parameter to the `APRProver` initialization,
which allows passing a new module with lemmas into the prover at
initialization time. It will pass this module on to the `add-module`
endpoint and make sure it's used on all future requests.
- Adds the `symbolic` and `smt-lemma` attributes to the list of
attributes recognized by pyk.
- Renames `KProve.get_claims_modules` to `KProve.parse_modules`, and
renames some of its parameters.
- Adds support to specifying the sort to use in `bool_to_ml_pred` for
converting boolean predicates to matching logic ones.
- Makes adjustments to `_krule_to_kore` to make it support translating
some functional/simplification rules to Kore which were not supported
before (and adds tests):
- Factors out `_krule_att_to_kore` method for converting attributes of
`KRule` to kore.
- Adds a pass to process the rule attributes a bit in `krule_to_kore`
(eg. handle `owise => priority(200)` and `simplification` vs
`priority`).
- Sort the attributes before sending to Kore (for reliable ordering of
the Kore output).
- Refactors the production of Kore `axiom` to differentiate between
semantic and functional rules, and to convert functional rules as
matching-logic `implies` axioms.
  • Loading branch information
ehildenb authored Nov 21, 2024
1 parent 75f108f commit 8aade39
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 65 deletions.
2 changes: 2 additions & 0 deletions pyk/src/pyk/kast/att.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,11 @@ class Atts:
SEQSTRICT: Final = AttKey('seqstrict', type=_ANY)
SORT: Final = AttKey('org.kframework.kore.Sort', type=_ANY)
SOURCE: Final = AttKey('org.kframework.attributes.Source', type=_PATH)
SMTLEMMA: Final = AttKey('smt-lemma', type=_NONE)
STRICT: Final = AttKey('strict', type=_ANY)
SYMBOL: Final = AttKey('symbol', type=_STR)
SYNTAX_MODULE: Final = AttKey('syntaxModule', type=_STR)
SYMBOLIC: Final = AttKey('symbolic', type=OptionalType(_STR))
TERMINALS: Final = AttKey('terminals', type=_STR)
TOKEN: Final = AttKey('token', type=_NONE)
TOTAL: Final = AttKey('total', type=_NONE)
Expand Down
10 changes: 5 additions & 5 deletions pyk/src/pyk/kast/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ def fun(term: KInner) -> KInner:
return fun


def bool_to_ml_pred(kast: KInner) -> KInner:
def bool_to_ml_pred(kast: KInner, sort: str | KSort = GENERATED_TOP_CELL) -> KInner:
def _bool_constraint_to_ml(_kast: KInner) -> KInner:
if _kast == TRUE:
return mlTop()
return mlTop(sort=sort)
if _kast == FALSE:
return mlBottom()
return mlEqualsTrue(_kast)
return mlBottom(sort=sort)
return mlEqualsTrue(_kast, sort=sort)

return mlAnd([_bool_constraint_to_ml(cond) for cond in flatten_label('_andBool_', kast)])
return mlAnd([_bool_constraint_to_ml(cond) for cond in flatten_label('_andBool_', kast)], sort=sort)


def ml_pred_to_bool(kast: KInner, unsafe: bool = False) -> KInner:
Expand Down
9 changes: 7 additions & 2 deletions pyk/src/pyk/kast/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ..prelude.kbool import TRUE
from ..prelude.ml import ML_QUANTIFIERS
from ..utils import FrozenDict, POSet, filter_none, single, unique
from ..utils import FrozenDict, POSet, filter_none, not_none, single, unique
from .att import EMPTY_ATT, Atts, Format, KAst, KAtt, WithKAtt
from .inner import (
KApply,
Expand Down Expand Up @@ -1117,6 +1117,11 @@ def functions(self) -> tuple[KProduction, ...]:
"""Returns the `KProduction` which are function declarations transitively imported by the main module of this definition."""
return tuple(func for module in self.modules for func in module.functions)

@cached_property
def function_labels(self) -> tuple[str, ...]:
"""Returns the label names of all the `KProduction` which are function symbols for all modules in this definition."""
return tuple(not_none(func.klabel).name for func in self.functions)

@cached_property
def constructors(self) -> tuple[KProduction, ...]:
"""Returns the `KProduction` which are constructor declarations transitively imported by the main module of this definition."""
Expand Down Expand Up @@ -1394,7 +1399,7 @@ def _add_ksequence_under_k_productions(_kast: KInner) -> KInner:
return top_down(_add_ksequence_under_k_productions, kast)

def sort_vars(self, kast: KInner, sort: KSort | None = None) -> KInner:
"""Return the original term with all the variables having there sorts added or specialized, failing if recieving conflicting sorts for a given variable."""
"""Return the original term with all the variables having the sorts added or specialized, failing if recieving conflicting sorts for a given variable."""
if type(kast) is KVariable and kast.sort is None and sort is not None:
return kast.let(sort=sort)

Expand Down
172 changes: 135 additions & 37 deletions pyk/src/pyk/konvert/_kast_to_kore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,33 @@
from typing import TYPE_CHECKING

from ..kast import Atts
from ..kast.inner import KApply, KLabel, KSequence, KSort, KToken, KVariable, top_down
from ..kast.manip import bool_to_ml_pred, extract_lhs, extract_rhs, flatten_label
from ..kast.inner import KApply, KLabel, KRewrite, KSequence, KSort, KToken, KVariable, top_down
from ..kast.manip import bool_to_ml_pred, extract_lhs, extract_rhs, flatten_label, ml_pred_to_bool, var_occurrences
from ..kast.outer import KRule
from ..kore.prelude import BOOL as KORE_BOOL
from ..kore.prelude import SORT_K
from ..kore.syntax import DV, And, App, Axiom, EVar, Import, MLPattern, MLQuant, Module, Rewrites, SortApp, String, Top
from ..kore.prelude import TRUE as KORE_TRUE
from ..kore.syntax import (
DV,
And,
App,
Axiom,
Equals,
EVar,
Implies,
Import,
MLPattern,
MLQuant,
Module,
Rewrites,
SortApp,
SortVar,
String,
Top,
)
from ..prelude.bytes import BYTES, pretty_bytes_str
from ..prelude.k import K_ITEM, K, inj
from ..prelude.kbool import TRUE
from ..prelude.kbool import BOOL, TRUE, andBool
from ..prelude.ml import mlAnd
from ..prelude.string import STRING, pretty_string
from ._utils import munge
Expand All @@ -21,6 +40,7 @@
from typing import Final

from ..kast import KInner
from ..kast.att import AttEntry
from ..kast.outer import KDefinition, KFlatModule, KImport
from ..kore.syntax import Pattern, Sentence, Sort

Expand Down Expand Up @@ -201,42 +221,101 @@ def krule_to_kore(definition: KDefinition, krule: KRule) -> Axiom:
krule_body = krule.body
krule_lhs_config = extract_lhs(krule_body)
krule_rhs_config = extract_rhs(krule_body)
krule_lhs_constraints = [bool_to_ml_pred(c) for c in flatten_label('_andBool_', krule.requires) if not c == TRUE]
krule_rhs_constraints = [bool_to_ml_pred(c) for c in flatten_label('_andBool_', krule.ensures) if not c == TRUE]
krule_lhs = mlAnd([krule_lhs_config] + krule_lhs_constraints)
krule_rhs = mlAnd([krule_rhs_config] + krule_rhs_constraints)

top_level_kore_sort = SortApp('SortGeneratedTopCell')
top_level_k_sort = KSort('GeneratedTopCell')
# The backend does not like rewrite rules without a precondition
if len(krule_lhs_constraints) > 0:
kore_lhs: Pattern = kast_to_kore(definition, krule_lhs, sort=top_level_k_sort)

is_functional = isinstance(krule_lhs_config, KApply) and krule_lhs_config.label.name in definition.function_labels

top_level_k_sort = KSort('GeneratedTopCell') if not is_functional else definition.sort_strict(krule_lhs_config)
top_level_kore_sort = _ksort_to_kore(top_level_k_sort)

# Do sort inference on the entire rule at once
kast_lhs = mlAnd(
[krule_lhs_config]
+ [
bool_to_ml_pred(constraint, sort=top_level_k_sort)
for constraint in flatten_label('_andBool_', krule.requires)
if not constraint == TRUE
],
sort=top_level_k_sort,
)
kast_rhs = mlAnd(
[krule_rhs_config]
+ [
bool_to_ml_pred(constraint, sort=top_level_k_sort)
for constraint in flatten_label('_andBool_', krule.ensures)
if not constraint == TRUE
],
sort=top_level_k_sort,
)
kast_rule_sorted = definition.sort_vars(KRewrite(kast_lhs, kast_rhs))

kast_lhs_body, *kast_lhs_constraints = flatten_label('#And', extract_lhs(kast_rule_sorted))
kast_rhs_body, *kast_rhs_constraints = flatten_label('#And', extract_rhs(kast_rule_sorted))
kore_lhs_body = kast_to_kore(definition, kast_lhs_body, sort=top_level_k_sort)
kore_rhs_body = kast_to_kore(definition, kast_rhs_body, sort=top_level_k_sort)

axiom_vars: tuple[SortVar, ...] = ()
kore_axiom: Pattern
if not is_functional:
kore_lhs_constraints = [
kast_to_kore(definition, kast_lhs_constraint, sort=top_level_k_sort)
for kast_lhs_constraint in kast_lhs_constraints
]
kore_rhs_constraints = [
kast_to_kore(definition, kast_rhs_constraint, sort=top_level_k_sort)
for kast_rhs_constraint in kast_rhs_constraints
]
kore_lhs_constraint: Pattern = Top(top_level_kore_sort)
if len(kore_lhs_constraints) == 1:
kore_lhs_constraint = kore_lhs_constraints[0]
elif len(kore_lhs_constraints) > 1:
kore_lhs_constraint = And(top_level_kore_sort, kore_lhs_constraints)
kore_lhs = And(top_level_kore_sort, [kore_lhs_body, kore_lhs_constraint])
kore_rhs = (
kore_rhs_body
if not kore_rhs_constraints
else And(top_level_kore_sort, [kore_rhs_body] + kore_rhs_constraints)
)
kore_axiom = Rewrites(sort=top_level_kore_sort, left=kore_lhs, right=kore_rhs)
else:
kore_lhs = And(
top_level_kore_sort,
(
kast_to_kore(definition, krule_lhs, sort=top_level_k_sort),
Top(top_level_kore_sort),
),
axiom_sort = SortVar('R')
axiom_vars = (axiom_sort,)
kast_lhs_constraints_bool = [
ml_pred_to_bool(kast_lhs_constraint) for kast_lhs_constraint in kast_lhs_constraints
]
kore_antecedent = Equals(
KORE_BOOL, axiom_sort, kast_to_kore(definition, andBool(kast_lhs_constraints_bool), sort=BOOL), KORE_TRUE
)
kore_ensures: Pattern = Top(top_level_kore_sort)
if kast_rhs_constraints:
kast_rhs_constraints_bool = [
ml_pred_to_bool(kast_rhs_constraint) for kast_rhs_constraint in kast_rhs_constraints
]
kore_ensures = Equals(
KORE_BOOL,
top_level_kore_sort,
kast_to_kore(definition, andBool(kast_rhs_constraints_bool), sort=BOOL),
KORE_TRUE,
)
kore_consequent = Equals(
top_level_kore_sort, axiom_sort, kore_lhs_body, And(top_level_kore_sort, [kore_rhs_body, kore_ensures])
)
kore_axiom = Implies(axiom_sort, kore_antecedent, kore_consequent)

kore_rhs: Pattern = kast_to_kore(definition, krule_rhs, sort=top_level_k_sort)

prio = krule.priority
attrs = [App(symbol='priority', sorts=(), args=(String(str(prio)),))]
if Atts.LABEL in krule.att:
label = krule.att[Atts.LABEL]
attrs.append(App(symbol='label', sorts=(), args=(String(label),)))
axiom = Axiom(
vars=(),
pattern=Rewrites(
sort=top_level_kore_sort,
left=kore_lhs,
right=kore_rhs,
),
attrs=attrs,
)
return axiom
# Make adjustments to Rule attributes
att = krule.att.discard([Atts.PRODUCTION, Atts.UNIQUE_ID, Atts.SOURCE, Atts.LOCATION])
if Atts.PRIORITY not in att:
if Atts.OWISE in att:
att = att.update([Atts.PRIORITY(200)])
att = att.discard([Atts.OWISE])
elif Atts.SIMPLIFICATION not in att:
att = att.update([Atts.PRIORITY(50)])

attrs = [
_krule_att_to_kore(att_entry, var_occurrences(kast_rule_sorted))
for att_entry in sorted(att.entries(), key=(lambda a: a.key.name))
]

return Axiom(vars=axiom_vars, pattern=kore_axiom, attrs=attrs)


def kflatmodule_to_kore(definition: KDefinition, kflatmodule: KFlatModule) -> Module:
Expand All @@ -249,6 +328,25 @@ def kflatmodule_to_kore(definition: KDefinition, kflatmodule: KFlatModule) -> Mo
return Module(name=kflatmodule.name, sentences=(imports + kore_axioms))


def _krule_att_to_kore(att_entry: AttEntry, kast_rule_vars: dict[str, list[KVariable]]) -> App:
match att_entry.key:
case Atts.LABEL | Atts.PRIORITY | Atts.SIMPLIFICATION:
return App(symbol=att_entry.key.name, sorts=(), args=(String(str(att_entry.value)),))
case Atts.SYMBOLIC | Atts.CONCRETE:
if not att_entry.value:
return App(symbol=att_entry.key.name, sorts=(), args=())
kore_vars = []
for var_name in att_entry.value.split(','):
if var_name not in kast_rule_vars:
raise ValueError(f'Variable in {att_entry.key} not present in rule: {var_name}')
kore_vars.append(_kvariable_to_kore(kast_rule_vars[var_name][0]))
return App(symbol=att_entry.key.name, sorts=(), args=tuple(kore_vars))
case Atts.SMTLEMMA:
return App(symbol=att_entry.key.name, sorts=(), args=())
case _:
raise ValueError(f'Do not know how to convert AttEntry to Kore: {att_entry}')


def _kimport_to_kore(kimport: KImport) -> Import:
return Import(module_name=kimport.name, attrs=())

Expand Down
6 changes: 3 additions & 3 deletions pyk/src/pyk/ktool/claim_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def load_claims(

if not cache_hit:
_LOGGER.info('Generating claim modules')
module_list = self._kprove.get_claim_modules(
spec_file=spec_file,
spec_module_name=spec_module_name,
module_list = self._kprove.parse_modules(
file_path=spec_file,
module_name=spec_module_name,
include_dirs=include_dirs,
md_selector=md_selector,
type_inference_mode=type_inference_mode,
Expand Down
18 changes: 9 additions & 9 deletions pyk/src/pyk/ktool/kprove.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,19 +255,19 @@ def prove_claim(
depth=depth,
)

def get_claim_modules(
def parse_modules(
self,
spec_file: Path,
spec_module_name: str | None = None,
file_path: Path,
module_name: str | None = None,
include_dirs: Iterable[Path] = (),
md_selector: str | None = None,
type_inference_mode: TypeInferenceMode | None = None,
) -> KFlatModuleList:
with self._temp_file(prefix=f'{spec_file.name}.parsed.json.') as ntf:
with self._temp_file(prefix=f'{file_path.name}.parsed.json.') as ntf:
_kprove(
spec_file=spec_file,
spec_file=file_path,
kompiled_dir=self.definition_dir,
spec_module_name=spec_module_name,
spec_module_name=module_name,
include_dirs=include_dirs,
md_selector=md_selector,
output=KProveOutput.JSON,
Expand All @@ -288,9 +288,9 @@ def get_claim_index(
md_selector: str | None = None,
type_inference_mode: TypeInferenceMode | None = None,
) -> ClaimIndex:
module_list = self.get_claim_modules(
spec_file=spec_file,
spec_module_name=spec_module_name,
module_list = self.parse_modules(
file_path=spec_file,
module_name=spec_module_name,
include_dirs=include_dirs,
md_selector=md_selector,
type_inference_mode=type_inference_mode,
Expand Down
12 changes: 11 additions & 1 deletion pyk/src/pyk/proof/reachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ class APRProver(Prover[APRProof, APRProofStep, APRProofResult]):
direct_subproof_rules: bool
assume_defined: bool
kcfg_explore: KCFGExplore
extra_module: KFlatModule | None

def __init__(
self,
Expand All @@ -725,6 +726,7 @@ def __init__(
fast_check_subsumption: bool = False,
direct_subproof_rules: bool = False,
assume_defined: bool = False,
extra_module: KFlatModule | None = None,
) -> None:

self.kcfg_explore = kcfg_explore
Expand All @@ -736,11 +738,19 @@ def __init__(
self.fast_check_subsumption = fast_check_subsumption
self.direct_subproof_rules = direct_subproof_rules
self.assume_defined = assume_defined
self.extra_module = extra_module

def close(self) -> None:
self.kcfg_explore.cterm_symbolic._kore_client.close()

def init_proof(self, proof: APRProof) -> None:
main_module_name = self.main_module_name
if self.extra_module:
_kore_module = kflatmodule_to_kore(self.kcfg_explore.cterm_symbolic._definition, self.extra_module)
_LOGGER.warning(f'_kore_module: {_kore_module.text}')
self.kcfg_explore.cterm_symbolic._kore_client.add_module(_kore_module, name_as_id=True)
main_module_name = self.extra_module.name

def _inject_module(module_name: str, import_name: str, sentences: list[KRule]) -> None:
_module = KFlatModule(module_name, sentences, [KImport(import_name)])
_kore_module = kflatmodule_to_kore(self.kcfg_explore.cterm_symbolic._definition, _module)
Expand All @@ -759,7 +769,7 @@ def _inject_module(module_name: str, import_name: str, sentences: list[KRule]) -
]
circularity_rule = proof.as_rule(priority=20)

_inject_module(proof.dependencies_module_name, self.main_module_name, dependencies_as_rules)
_inject_module(proof.dependencies_module_name, main_module_name, dependencies_as_rules)
_inject_module(proof.circularities_module_name, proof.dependencies_module_name, [circularity_rule])

for node_id in [proof.init, proof.target]:
Expand Down
Loading

0 comments on commit 8aade39

Please sign in to comment.