Skip to content

Commit

Permalink
Refactor locals
Browse files Browse the repository at this point in the history
  • Loading branch information
lmpick committed Jun 7, 2022
1 parent 0b83027 commit e38ac0b
Show file tree
Hide file tree
Showing 7 changed files with 655 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pcontainment.runtime.machine.eventhandlers.EventHandlerReturnReason;

import java.util.*;
import java.util.stream.Collectors;

public class Checker {
private final Context ctx = new Context();
Expand All @@ -21,8 +20,10 @@ public class Checker {
private final List<Message> concreteSends = new ArrayList<>();
private IntExpr currentState = ctx.mkIntConst("state_-1");
private IntExpr nextState = ctx.mkIntConst("state_0");
private Locals locals = new Locals();
private Locals locals = new Locals(this);
private int localMergeCount = 0;
@Getter
private Model lastModel = null;

public Expr<?> getExprFor(PValue<?> value) {
if (value instanceof PInt) {
Expand Down Expand Up @@ -66,14 +67,14 @@ public void declLocal(String name, Object default_value) {
if (default_value instanceof PValue) {
PValue<?> value = (PValue<?>) default_value;
val = getExprFor(value);
locals.put(name, getConstWithSameType(val, name + "_" + depth));
locals.assign(name, getConstWithSameType(val, name + "_" + depth));
solver.add(mkEq(val, locals.get(name)));
} else if (default_value instanceof Expr) {
val = (Expr<?>) default_value;
locals.put(name, getConstWithSameType(val, name + "_" + depth));
locals.assign(name, getConstWithSameType(val, name + "_" + depth));
solver.add(mkEq(val, locals.get(name)));
} else {
locals.put(name, mkInt(-1));
locals.assign(name, mkInt(-1));
}
}

Expand All @@ -86,7 +87,7 @@ private Expr<?> getNextLocal(String name) {
}

private Expr<?> getLocal(String name, int depth) {
if (!locals.containsKey(name))
if (!locals.contains(name))
throw new RuntimeException("Tried to access undeclared local: " + name + "");
if (depth != getDepth()) {
if (locals.get(name) == null) return null;
Expand All @@ -97,7 +98,7 @@ private Expr<?> getLocal(String name, int depth) {

private List<Expr<?>> getNextLocals() {
List<Expr<?>> nextLocals = new ArrayList<>();
for (String localName : locals.keySet()) {
for (String localName : locals.symbolicVarSet()) {
//System.out.println(localName + ": " + getNextLocal(localName));
nextLocals.add(getNextLocal(localName));
}
Expand Down Expand Up @@ -169,8 +170,12 @@ public <T extends Sort> Expr<T> mkGet(SeqExpr<T> seq, IntExpr idx) {
return ctx.mkNth(seq, idx);
}

public <T extends Sort> SeqExpr<T> mkAdd(SeqExpr<T> seq, Expr<T> toAdd) {
return ctx.mkConcat(seq, ctx.mkUnit(toAdd));
//public <T extends Sort> SeqExpr<T> mkAdd(SeqExpr<T> seq, Expr<T> toAdd) {
// return ctx.mkConcat(seq, ctx.mkUnit(toAdd));
//}

public <T extends Sort> SeqExpr<T> mkAdd(SeqExpr<T> seq, Expr<?> toAdd) {
return ctx.mkConcat(seq, (SeqExpr<T>) ctx.mkUnit(toAdd));
}

public <T extends Sort> SeqExpr<T> mkSubseq(SeqExpr<T> seq, Expr<IntSort> idx) {
Expand All @@ -181,12 +186,13 @@ public <T extends Sort> BoolExpr mkContains(SeqExpr<T> seq, Expr<T> val) {
return ctx.mkContains(seq, ctx.mkUnit(val));
}

public ArrayExpr<IntSort, IntSort> mkMap(String name) {
return ctx.mkArrayConst(name, ctx.getIntSort(), ctx.getIntSort());

public ArrayExpr<IntSort, IntSort> mkMap() {
return mkMap(ctx.getIntSort(), mkInt(0));
}

public <K extends Sort, V extends Sort> ArrayExpr<K, V> mkMap(String name, K keySort, V valSort) {
return ctx.mkArrayConst(name, keySort, valSort);
public <K extends Sort, V extends Sort> ArrayExpr<K, V> mkMap(K keySort, Expr<V> defaultVal) {
return ctx.mkConstArray(keySort, defaultVal);
}

public <K extends Sort, V extends Sort> ArrayExpr<K, V> mkAdd(ArrayExpr<K, V> map, Expr<K> key, Expr<V> val) {
Expand Down Expand Up @@ -253,8 +259,8 @@ public void started() {
depth++;
currentState = nextState;
nextState = ctx.mkIntConst("state_" + (depth + 1));
for (String k : locals.keySet()) {
locals.put(k, getConstWithSameType(locals.get(k), k + "_" + depth));
for (String k : locals.symbolicVarSet()) {
locals.assign(k, getConstWithSameType(locals.get(k), k + "_" + depth));
}
}

Expand All @@ -266,8 +272,8 @@ public void nextDepth() {
sendTgtIds.clear();
payloads.clear();
concreteSends.clear();
for (String k : locals.keySet()) {
locals.put(k, getConstWithSameType(locals.get(k), k + "_" + depth));
for (String k : locals.symbolicVarSet()) {
locals.assign(k, getConstWithSameType(locals.get(k), k + "_" + depth));
}
}

Expand Down Expand Up @@ -355,24 +361,13 @@ public void noMoreSends() {
}
}

public void check() {
BoolExpr[] asserts = solver.getAssertions();
/*
for (BoolExpr asst : asserts) {
System.out.println(asst.simplify().toString());
}
*/
Status res = solver.check();
if (res != Status.SATISFIABLE) {
throw new RuntimeException("Trace containment check failed: " + res.name());
}
Model m = solver.getModel();
FuncDecl<?>[] decls = m.getConstDecls();
boolean addOld = false;
public void determinize(Model model) {
List<BoolExpr> boolExprs = new ArrayList<>();
Expr<?> interp = m.getConstInterp(nextState);
if (interp != null) {
BoolExpr eq = ctx.mkEq(nextState, interp);
boolean addOld = false;

Expr<?> stateInterp = model.getConstInterp(nextState);
if (stateInterp != null) {
BoolExpr eq = ctx.mkEq(nextState, stateInterp);
if (solver.check(ctx.mkNot(eq)) != Status.SATISFIABLE) {
//System.out.println("Value for state " + nextState + " is unique!! " + interp.toString());
boolExprs.add(eq);
Expand All @@ -383,39 +378,21 @@ public void check() {
} else {
addOld = true;
}
for (Expr<?> e : getNextLocals()) {
if (e == null) continue;
interp = m.getConstInterp(e);
if (interp == null) continue;
BoolExpr eq = ctx.mkEq(e, interp);

Interpretation interp = locals.getInterpretation(model);
for (String varName : locals.symbolicVarSet()) {
Expr<?> nextLocal = getNextLocal(varName);
Expr<?> interpretation = interp.getInterpretation(varName);
if (interpretation == null) continue;
BoolExpr eq = ctx.mkEq(nextLocal, interpretation);
if (solver.check(ctx.mkNot(eq)) != Status.SATISFIABLE) {
//System.out.println("Value for " + e.toString() + " is unique!!");
boolExprs.add(eq);
} else {
addOld = true;
}
}
/*
for (BoolExpr hasSendPred : hasSendPreds) {
BoolExpr eq = ctx.mkEq(hasSendPred, m.getConstInterp(hasSendPred));
if (solver.check(ctx.mkNot(eq)) != Status.SATISFIABLE) {
boolExprs.add(eq);
} else {
addOld = true;
}
}
*/
/*
for (IntExpr sendTgtId : sendTgtIds) {
BoolExpr eq = ctx.mkEq(sendTgtId, m.getConstInterp(sendTgtId));
if (solver.check(ctx.mkNot(eq)) != Status.SATISFIABLE) {
boolExprs.add(eq);
} else {
addOld = true;
}
}
*/
//solver = ctx.mkSolver();
BoolExpr[] asserts = solver.getAssertions();
if (addOld) {
solver.reset();
for (BoolExpr a : asserts) {
Expand All @@ -429,6 +406,15 @@ public void check() {
}
}

public void check() {
Status res = solver.check();
if (res != Status.SATISFIABLE) {
throw new RuntimeException("Trace containment check failed: " + res.name());
}
Model m = solver.getModel();
determinize(m);
}

private IntExpr getCurrentState(int n) {
if (n == 0) return currentState;
return ctx.mkIntConst("state_" + depth + "_" + n);
Expand All @@ -455,13 +441,13 @@ private Pair<BoolExpr, Locals> encodeOutcomesToCompletion(int sends, int states,
}

private Pair<BoolExpr, Locals> frameRule(Locals locals) {
Locals newLocals = new Locals();
BoolExpr[] eqs = new BoolExpr[locals.size()];
Locals newLocals = new Locals(this);
BoolExpr[] eqs = new BoolExpr[locals.symbolicVarSet().size()];
int i = 0;
for (Map.Entry<String, Expr<?>> entry : locals.entrySet()) {
Expr<?> nextLocal = getNextLocal(entry.getKey());
eqs[i] = ctx.mkEq(entry.getValue(), nextLocal);
newLocals.put(entry.getKey(), nextLocal);
for (String varName : locals.symbolicVarSet()) {
Expr<?> nextLocal = getNextLocal(varName);
eqs[i] = ctx.mkEq(locals.get(varName), nextLocal);
newLocals.assign(varName, nextLocal);
i++;
}
return new Pair<>(ctx.mkAnd(eqs), newLocals);
Expand All @@ -487,55 +473,50 @@ private Pair<BoolExpr, Locals> encodeOutcomesToCompletion(int sends, int states,
}

private Locals mergeLocals(Set<Pair<BoolExpr, Locals>> localsSet) {
Locals mergedLocalMap= new Locals();
for (Map.Entry<String, Expr<?>> entry: locals.entrySet()) {
String key = entry.getKey();
String mergedName = key + "_" + depth + "_merge_" + localMergeCount;
Expr<?> mergedVal = getConstWithSameType(entry.getValue(), mergedName);
mergedLocalMap.put(key, mergedVal);
BoolExpr mergedEncoding = ctx.mkTrue();
Locals mergedLocalMap= new Locals(this);
locals.convertConcreteToSymbolic();
// unflatten sequences, sets, maps as needed
Set<String> unflatten = new HashSet<>();
for (String seq : locals.seqSet()) {
for (Pair<BoolExpr, Locals> locals : localsSet) {
if (locals.second.containsKey(key)) {
mergedEncoding = mkOr(mergedEncoding,
mkAnd(locals.first, mkEq(mergedVal, locals.second.get(key))));
if (locals.second.symbolicVarSet().contains(seq)) {
unflatten.add(seq);
break;
}
}
solver.add(mergedEncoding);
}
/*
// sequences
for (DeterministicSeq<Expr<?>> entry : locals.getDetSeqs()) {
String name = entry.getName();
boolean deterministicSize = true;
int size = -1;
for (String seq : locals.setSet()) {
for (Pair<BoolExpr, Locals> locals : localsSet) {
if (locals.second.hasSeq(name)) {
if (size == -1) {
size = locals.second.getSeq(name).size();
} else if (size != locals.second.getSeq(name).size()) {
deterministicSize = false;
}
} else {
deterministicSize = false;
if (locals.second.symbolicVarSet().contains(seq)) {
unflatten.add(seq);
break;
}
}
if (!deterministicSize) {
// TODO: convert to symbolic
throw new RuntimeException("Sequence " + name + " is not deterministic!");
}
for (String seq : locals.mapSet()) {
for (Pair<BoolExpr, Locals> locals : localsSet) {
if (locals.second.symbolicVarSet().contains(seq)) {
unflatten.add(seq);
break;
}
}
for (int i = 0; i < size; i++) {
String mergedName = name + "_idx_" + i + "_" + depth + "_merge_" + localMergeCount;
Expr<?> mergedVal = getConstWithSameType(entry.get(i), mergedName);
mergedLocalMap.getSeq(name).add(i, mergedVal);
BoolExpr mergedEncoding = ctx.mkTrue();
for (Pair<BoolExpr, Locals> locals : localsSet) {
}
for (String toUnflatten : unflatten) {
locals.unflatten(toUnflatten);
}
for (String varName: locals.symbolicVarSet()) {
String mergedName = varName + "_" + depth + "_merge_" + localMergeCount;
Expr<?> mergedVal = getConstWithSameType(locals.get(varName), mergedName);
mergedLocalMap.assign(varName, mergedVal);
BoolExpr mergedEncoding = ctx.mkTrue();
for (Pair<BoolExpr, Locals> locals : localsSet) {
if (locals.second.contains(varName)) {
mergedEncoding = mkOr(mergedEncoding,
mkAnd(locals.first, mkEq(mergedVal, locals.second.get(name, i))));
mkAnd(locals.first, mkEq(mergedVal, locals.second.get(varName))));
}
solver.add(mergedEncoding);
}
solver.add(mergedEncoding);
}
*/
localMergeCount++;
return mergedLocalMap;
}
Expand All @@ -556,33 +537,24 @@ private Locals mergeLocals(Set<Pair<BoolExpr, Locals>> localsSet) {
private Pair<BoolExpr, Locals> encodeOutcomesToCompletion(int sends, int states, Locals locals,
Machine target, EventHandlerReturnReason.Goto goTo) {
Map<BoolExpr, Integer> sendCounts = new HashMap<>();
Locals mergedLocals = new Locals();
for (Map.Entry<String, Expr<?>> entry: locals.entrySet()) {
String key = entry.getKey();
String mergedName = key + "_" + depth + "_" + states;
mergedLocals.put(key, getConstWithSameType(entry.getValue(), mergedName));
}
Set<Pair<BoolExpr, Locals>> localsToMerge = new HashSet<>();
for (State state : target.getStates()) {
BoolExpr stateExit = ctx.mkEq(getCurrentState(states), ctx.mkInt(state.getId()));
Map<BoolExpr, Pair<Integer, Locals>> exitRes =
state.getExitEncoding(sends, locals, this, target);
BoolExpr branches = ctx.mkFalse();
//BoolExpr branches = ctx.mkFalse();
for (Map.Entry<BoolExpr, Pair<Integer, Locals>> branch : exitRes.entrySet()) {
sendCounts.put(ctx.mkAnd(stateExit, branch.getKey()), branch.getValue().first);
BoolExpr merge = ctx.mkTrue();
for (Map.Entry<String, Expr<?>> entry : branch.getValue().second.entrySet()) {
merge = ctx.mkAnd(merge, ctx.mkEq(entry.getValue(),
mergedLocals.get(entry.getKey())));
}
branches = ctx.mkOr(branches, ctx.mkAnd(branch.getKey(), merge));
localsToMerge.add(new Pair<>(branch.getKey(), branch.getValue().second));
}
}
Locals mergedLocals = mergeLocals(localsToMerge);
states++;
BoolExpr entries = ctx.mkFalse();
System.out.println("state update for " + getCurrentState(states).toString());
System.out.println("equal to " + goTo.getGoTo().getId());
BoolExpr stateUpdate = ctx.mkEq(getCurrentState(states), ctx.mkInt(goTo.getGoTo().getId()));
Set<Pair<BoolExpr, Locals>> localsToMerge = new HashSet<>();
localsToMerge = new HashSet<>();
for (Map.Entry<BoolExpr, Integer> branch : sendCounts.entrySet()) {
Pair<BoolExpr, Locals> thisEntry = runEncoding(target, states,
goTo.getGoTo().getEntryEncoding(branch.getValue(),this, mergedLocals,
Expand Down
Loading

0 comments on commit e38ac0b

Please sign in to comment.