Skip to content

Commit

Permalink
Fix a bug with patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
valis committed Feb 29, 2024
1 parent 4317075 commit 8dffb0b
Showing 1 changed file with 144 additions and 112 deletions.
256 changes: 144 additions & 112 deletions meta/src/main/java/org/arend/lib/pattern/PatternUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
import org.arend.ext.core.body.CorePattern;
import org.arend.ext.core.context.CoreBinding;
import org.arend.ext.core.context.CoreParameter;
import org.arend.ext.core.definition.CoreConstructor;
import org.arend.ext.core.definition.CoreDefinition;
import org.arend.ext.core.definition.CoreFunctionDefinition;
import org.arend.ext.core.expr.*;
import org.arend.ext.core.level.LevelSubstitution;
import org.arend.ext.core.ops.NormalizationMode;
import org.arend.ext.reference.ArendRef;
import org.arend.ext.typechecking.ExpressionTypechecker;
import org.arend.ext.typechecking.TypedExpression;
Expand Down Expand Up @@ -266,148 +263,183 @@ public static CoreExpression eval(CoreElimBody body, LevelSubstitution levelSubs
return null;
}


/**
* Checks coverage for a list of patterns with {@code type} as their type.
* @return indices of rows from {@code actualRows} that cover {@code row}, or {@code null} if {@code actualRows} do not cover {@code row}
*/
private static boolean checkCoverage(List<Pair<Integer, CorePattern>> patterns, CoreExpression type, Set<Integer> result) {
for (Pair<Integer, CorePattern> pair : patterns) {
if (pair.proj2.isAbsurd() || pair.proj2.getBinding() != null || pair.proj2.getConstructor() instanceof CoreFunctionDefinition) {
result.add(pair.proj1);
return true;
public static List<Integer> computeCovering(List<? extends List<? extends CorePattern>> actualRows, List<? extends CorePattern> row) {
for (CorePattern pattern : row) {
if (pattern.isAbsurd()) {
return Collections.emptyList();
}
}

type = type.normalize(NormalizationMode.WHNF);

if (patterns.isEmpty()) {
if (!(type instanceof CoreDataCallExpression)) {
return false;
}
List<CoreConstructor> constructors = ((CoreDataCallExpression) type).computeMatchedConstructors();
return constructors != null && constructors.isEmpty();
if (actualRows.isEmpty()) {
return null;
}

boolean isTuple = patterns.get(0).proj2.getConstructor() == null;
for (Pair<Integer, CorePattern> pair : patterns) {
if (isTuple != (pair.proj2.getConstructor() == null)) {
return false;
}
List<Pair<? extends List<? extends CorePattern>, Integer>> rows = new ArrayList<>(actualRows.size());
for (int i = 0; i < actualRows.size(); i++) {
rows.add(new Pair<>(actualRows.get(i), i));
}

if (isTuple) {
CoreParameter parameters;
if (type instanceof CoreSigmaExpression) {
parameters = ((CoreSigmaExpression) type).getParameters();
} else if (type instanceof CoreClassCallExpression) {
parameters = ((CoreClassCallExpression) type).getClassFieldParameters();
} else {
return false;
Set<Integer> indices = new HashSet<>();
if (!computeCovering(rows, row, indices)) {
return null;
}
List<Integer> result = new ArrayList<>(actualRows.size());
for (int i = 0; i < actualRows.size(); i++) {
if (indices.contains(i)) {
result.add(i);
}
return checkCoverage(patterns, parameters, result);
}
return result;
}

List<CoreDataCallExpression.ConstructorWithDataArguments> constructors = type.computeMatchedConstructorsWithDataArguments();
if (constructors == null) {
private static boolean computeCovering(List<Pair<? extends List<? extends CorePattern>, Integer>> actualRows, List<? extends CorePattern> row, Set<Integer> indices) {
if (actualRows.isEmpty()) {
return false;
}

Map<CoreDefinition, List<Pair<Integer, CorePattern>>> map = new HashMap<>();
for (Pair<Integer, CorePattern> pair : patterns) {
map.computeIfAbsent(pair.proj2.getConstructor(), k -> new ArrayList<>()).add(pair);
}

for (CoreDataCallExpression.ConstructorWithDataArguments conCall : constructors) {
List<Pair<Integer, CorePattern>> list = map.get(conCall.getConstructor());
if (list == null || !checkCoverage(list, conCall.getParameters(), result)) {
return false;
}
if (row.isEmpty()) {
indices.add(actualRows.get(0).proj2);
return true;
}

return true;
}
class MyBinding implements CoreBinding {
final CoreExpression type;

private static boolean checkCoverage(List<Pair<Integer, CorePattern>> rows, CoreParameter parameters, Set<Integer> result) {
int numberOfColumns = rows.get(0).proj2.getSubPatterns().size();
for (Pair<Integer, CorePattern> row : rows) {
if (row.proj2.getSubPatterns().size() != numberOfColumns) {
return false;
MyBinding(CoreExpression type) {
this.type = type;
}
}

CoreParameter param = parameters;
for (int i = 0; i < numberOfColumns; i++) {
if (!param.hasNext()) {
return false;
}
List<Pair<Integer, CorePattern>> column = new ArrayList<>(rows.size());
for (Pair<Integer, CorePattern> row : rows) {
column.add(new Pair<>(row.proj1, row.proj2.getSubPatterns().get(i)));
}
if (!checkCoverage(column, param.getTypeExpr(), result)) {
return false;
@Override
public CoreExpression getTypeExpr() {
return type;
}
param = param.getNext();
}

if (param.hasNext()) {
return false;
}
@Override
public CoreReferenceExpression makeReference() {
return null;
}

if (numberOfColumns == 0) {
result.add(rows.get(0).proj1);
@Override
public String getName() {
return null;
}
}
return true;
}

/**
* @return indices of rows from {@code actualRows} that cover {@code row}, or {@code null} if {@code actualRows} do not cover {@code row}
*/
public static List<Integer> computeCovering(List<? extends List<? extends CorePattern>> actualRows, List<? extends CorePattern> row) {
for (CorePattern pattern : row) {
if (pattern.isAbsurd()) {
return Collections.emptyList();
CorePattern pattern = row.get(0);
List<Pair<? extends List<? extends CorePattern>, Integer>> newRows = new ArrayList<>(actualRows.size());
List<? extends CorePattern> rowTail = row.subList(1, row.size());
if (pattern.getBinding() != null) {
boolean allVars = true;
boolean hasVars = false;
Set<CoreDefinition> constructors = new LinkedHashSet<>();
for (Pair<? extends List<? extends CorePattern>, Integer> actualRow : actualRows) {
CorePattern actualPattern = actualRow.proj1.get(0);
if (actualPattern.isAbsurd()) { // TODO: This is not needed if implemented properly; see TODO below
indices.add(actualRow.proj2);
return true;
}
if (actualPattern.getBinding() == null) {
allVars = false;
if (actualPattern.getConstructor() != null) {
constructors.add(actualPattern.getConstructor());
}
} else {
hasVars = true;
}
}
}
if (actualRows.isEmpty()) {
return null;
}

int coveringIndex = -1;
Map<CoreBinding, List<Pair<Integer, CorePattern>>> substs = new HashMap<>();
for (int i = 0; i < actualRows.size(); i++) {
Map<CoreBinding, CorePattern> subst = new HashMap<>();
if (unify(actualRows.get(i), row, null, subst)) {
if (coveringIndex == -1) {
coveringIndex = i;
if (allVars) {
for (Pair<? extends List<? extends CorePattern>, Integer> actualRow : actualRows) {
newRows.add(new Pair<>(actualRow.proj1.subList(1, actualRow.proj1.size()), actualRow.proj2));
}
for (Map.Entry<CoreBinding, CorePattern> entry : subst.entrySet()) {
substs.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).add(new Pair<>(i, entry.getValue()));
return computeCovering(newRows, rowTail, indices);
}

List<Pair<CoreDefinition, CoreParameter>> consWithParams;
CoreExpression type = pattern.getBinding().getTypeExpr();
if (type != null) type = type.unfoldType();
if (type instanceof CoreSigmaExpression sigmaExpr) {
consWithParams = Collections.singletonList(new Pair<>(null, sigmaExpr.getParameters()));
} else if (type instanceof CoreClassCallExpression classCall) {
consWithParams = Collections.singletonList(new Pair<>(null, classCall.getClassFieldParameters()));
} else {
List<CoreExpression.ConstructorWithDataArguments> constructorsWithArgs = type == null ? null : type.computeMatchedConstructorsWithDataArguments();
if (constructorsWithArgs == null) {
// TODO: We should return null here, but to do this we need to substitute previous patterns in type
consWithParams = new ArrayList<>(constructors.size());
for (CoreDefinition constructor : constructors) {
consWithParams.add(new Pair<>(constructor, constructor.getParameters()));
}
} else {
if (!hasVars) {
for (CoreExpression.ConstructorWithDataArguments cons : constructorsWithArgs) {
if (!constructors.contains(cons.getConstructor())) {
return false;
}
}
}
consWithParams = new ArrayList<>(constructorsWithArgs.size());
for (CoreExpression.ConstructorWithDataArguments cons : constructorsWithArgs) {
consWithParams.add(new Pair<>(cons.getConstructor(), cons.getParameters()));
}
}
}
}

if (coveringIndex == -1) {
return null;
}
if (substs.isEmpty()) {
return Collections.singletonList(coveringIndex);
}
for (Pair<CoreDefinition, CoreParameter> pair : consWithParams) {
List<CorePattern> newRow = new ArrayList<>();
for (CoreParameter param = pair.proj2; param.hasNext(); param = param.getNext()) {
newRow.add(new ArendPattern(new MyBinding(param.getTypeExpr()), null, Collections.emptyList(), null, null));
}
newRow.addAll(rowTail);

newRows.clear();
for (Pair<? extends List<? extends CorePattern>, Integer> actualRow : actualRows) {
CorePattern actualPattern = actualRow.proj1.get(0);
if (actualPattern.getBinding() != null) {
List<CorePattern> newActualRow = new ArrayList<>();
for (CoreParameter param = pair.proj2; param.hasNext(); param = param.getNext()) {
newActualRow.add(new ArendPattern(new MyBinding(null), null, Collections.emptyList(), null, null));
}
newActualRow.addAll(actualRow.proj1.subList(1, actualRow.proj1.size()));
newRows.add(new Pair<>(newActualRow, actualRow.proj2));
} else if (actualPattern.getConstructor() == pair.proj1) {
List<CorePattern> newActualRow = new ArrayList<>();
newActualRow.addAll(actualPattern.getSubPatterns());
newActualRow.addAll(actualRow.proj1.subList(1, actualRow.proj1.size()));
newRows.add(new Pair<>(newActualRow, actualRow.proj2));
}
}

Set<Integer> coveringIndices = new HashSet<>();
for (Map.Entry<CoreBinding, List<Pair<Integer, CorePattern>>> entry : substs.entrySet()) {
if (!checkCoverage(entry.getValue(), entry.getKey().getTypeExpr(), coveringIndices)) {
return null;
if (!computeCovering(newRows, newRow, indices)) {
return false;
}
}
}

List<Integer> coveringList = new ArrayList<>();
for (int i = 0; i < actualRows.size(); i++) {
if (coveringIndices.contains(i)) {
coveringList.add(i);
return true;
} else {
List<CorePattern> newRow = new ArrayList<>();
newRow.addAll(pattern.getSubPatterns());
newRow.addAll(rowTail);

for (Pair<? extends List<? extends CorePattern>, Integer> actualRow : actualRows) {
CorePattern actualPattern = actualRow.proj1.get(0);
if (actualPattern.getBinding() != null) {
List<CorePattern> newActualRow = new ArrayList<>();
for (CorePattern ignored : pattern.getSubPatterns()) {
newActualRow.add(new ArendPattern(new MyBinding(null), null, Collections.emptyList(), null, null));
}
newActualRow.addAll(actualRow.proj1.subList(1, actualRow.proj1.size()));
newRows.add(new Pair<>(newActualRow, actualRow.proj2));
} else if (actualPattern.getConstructor() == pattern.getConstructor()) {
List<CorePattern> newActualRow = new ArrayList<>();
newActualRow.addAll(actualPattern.getSubPatterns());
newActualRow.addAll(actualRow.proj1.subList(1, actualRow.proj1.size()));
newRows.add(new Pair<>(newActualRow, actualRow.proj2));
}
}

return computeCovering(newRows, newRow, indices);
}
return coveringList;
}

public static List<CorePattern> subst(Collection<? extends CorePattern> patterns, Map<? extends CoreBinding, ? extends CorePattern> map) {
Expand Down

0 comments on commit 8dffb0b

Please sign in to comment.