From 8dffb0b32a81d285b42da7224423ff391fbb5171 Mon Sep 17 00:00:00 2001 From: valis Date: Thu, 29 Feb 2024 19:46:39 +0300 Subject: [PATCH] Fix a bug with patterns --- .../org/arend/lib/pattern/PatternUtils.java | 256 ++++++++++-------- 1 file changed, 144 insertions(+), 112 deletions(-) diff --git a/meta/src/main/java/org/arend/lib/pattern/PatternUtils.java b/meta/src/main/java/org/arend/lib/pattern/PatternUtils.java index 0a8dcde5..fc81c23a 100644 --- a/meta/src/main/java/org/arend/lib/pattern/PatternUtils.java +++ b/meta/src/main/java/org/arend/lib/pattern/PatternUtils.java @@ -9,12 +9,9 @@ import org.arend.ext.core.body.CorePattern; import org.arend.ext.core.context.CoreBinding; import org.arend.ext.core.context.CoreParameter; -import org.arend.ext.core.definition.CoreConstructor; import org.arend.ext.core.definition.CoreDefinition; -import org.arend.ext.core.definition.CoreFunctionDefinition; import org.arend.ext.core.expr.*; import org.arend.ext.core.level.LevelSubstitution; -import org.arend.ext.core.ops.NormalizationMode; import org.arend.ext.reference.ArendRef; import org.arend.ext.typechecking.ExpressionTypechecker; import org.arend.ext.typechecking.TypedExpression; @@ -266,148 +263,183 @@ public static CoreExpression eval(CoreElimBody body, LevelSubstitution levelSubs return null; } - /** - * Checks coverage for a list of patterns with {@code type} as their type. + * @return indices of rows from {@code actualRows} that cover {@code row}, or {@code null} if {@code actualRows} do not cover {@code row} */ - private static boolean checkCoverage(List> patterns, CoreExpression type, Set result) { - for (Pair pair : patterns) { - if (pair.proj2.isAbsurd() || pair.proj2.getBinding() != null || pair.proj2.getConstructor() instanceof CoreFunctionDefinition) { - result.add(pair.proj1); - return true; + public static List computeCovering(List> actualRows, List row) { + for (CorePattern pattern : row) { + if (pattern.isAbsurd()) { + return Collections.emptyList(); } } - - type = type.normalize(NormalizationMode.WHNF); - - if (patterns.isEmpty()) { - if (!(type instanceof CoreDataCallExpression)) { - return false; - } - List constructors = ((CoreDataCallExpression) type).computeMatchedConstructors(); - return constructors != null && constructors.isEmpty(); + if (actualRows.isEmpty()) { + return null; } - boolean isTuple = patterns.get(0).proj2.getConstructor() == null; - for (Pair pair : patterns) { - if (isTuple != (pair.proj2.getConstructor() == null)) { - return false; - } + List, Integer>> rows = new ArrayList<>(actualRows.size()); + for (int i = 0; i < actualRows.size(); i++) { + rows.add(new Pair<>(actualRows.get(i), i)); } - - if (isTuple) { - CoreParameter parameters; - if (type instanceof CoreSigmaExpression) { - parameters = ((CoreSigmaExpression) type).getParameters(); - } else if (type instanceof CoreClassCallExpression) { - parameters = ((CoreClassCallExpression) type).getClassFieldParameters(); - } else { - return false; + Set indices = new HashSet<>(); + if (!computeCovering(rows, row, indices)) { + return null; + } + List result = new ArrayList<>(actualRows.size()); + for (int i = 0; i < actualRows.size(); i++) { + if (indices.contains(i)) { + result.add(i); } - return checkCoverage(patterns, parameters, result); } + return result; + } - List constructors = type.computeMatchedConstructorsWithDataArguments(); - if (constructors == null) { + private static boolean computeCovering(List, Integer>> actualRows, List row, Set indices) { + if (actualRows.isEmpty()) { return false; } - - Map>> map = new HashMap<>(); - for (Pair pair : patterns) { - map.computeIfAbsent(pair.proj2.getConstructor(), k -> new ArrayList<>()).add(pair); - } - - for (CoreDataCallExpression.ConstructorWithDataArguments conCall : constructors) { - List> list = map.get(conCall.getConstructor()); - if (list == null || !checkCoverage(list, conCall.getParameters(), result)) { - return false; - } + if (row.isEmpty()) { + indices.add(actualRows.get(0).proj2); + return true; } - return true; - } + class MyBinding implements CoreBinding { + final CoreExpression type; - private static boolean checkCoverage(List> rows, CoreParameter parameters, Set result) { - int numberOfColumns = rows.get(0).proj2.getSubPatterns().size(); - for (Pair row : rows) { - if (row.proj2.getSubPatterns().size() != numberOfColumns) { - return false; + MyBinding(CoreExpression type) { + this.type = type; } - } - CoreParameter param = parameters; - for (int i = 0; i < numberOfColumns; i++) { - if (!param.hasNext()) { - return false; - } - List> column = new ArrayList<>(rows.size()); - for (Pair row : rows) { - column.add(new Pair<>(row.proj1, row.proj2.getSubPatterns().get(i))); - } - if (!checkCoverage(column, param.getTypeExpr(), result)) { - return false; + @Override + public CoreExpression getTypeExpr() { + return type; } - param = param.getNext(); - } - if (param.hasNext()) { - return false; - } + @Override + public CoreReferenceExpression makeReference() { + return null; + } - if (numberOfColumns == 0) { - result.add(rows.get(0).proj1); + @Override + public String getName() { + return null; + } } - return true; - } - /** - * @return indices of rows from {@code actualRows} that cover {@code row}, or {@code null} if {@code actualRows} do not cover {@code row} - */ - public static List computeCovering(List> actualRows, List row) { - for (CorePattern pattern : row) { - if (pattern.isAbsurd()) { - return Collections.emptyList(); + CorePattern pattern = row.get(0); + List, Integer>> newRows = new ArrayList<>(actualRows.size()); + List rowTail = row.subList(1, row.size()); + if (pattern.getBinding() != null) { + boolean allVars = true; + boolean hasVars = false; + Set constructors = new LinkedHashSet<>(); + for (Pair, Integer> actualRow : actualRows) { + CorePattern actualPattern = actualRow.proj1.get(0); + if (actualPattern.isAbsurd()) { // TODO: This is not needed if implemented properly; see TODO below + indices.add(actualRow.proj2); + return true; + } + if (actualPattern.getBinding() == null) { + allVars = false; + if (actualPattern.getConstructor() != null) { + constructors.add(actualPattern.getConstructor()); + } + } else { + hasVars = true; + } } - } - if (actualRows.isEmpty()) { - return null; - } - int coveringIndex = -1; - Map>> substs = new HashMap<>(); - for (int i = 0; i < actualRows.size(); i++) { - Map subst = new HashMap<>(); - if (unify(actualRows.get(i), row, null, subst)) { - if (coveringIndex == -1) { - coveringIndex = i; + if (allVars) { + for (Pair, Integer> actualRow : actualRows) { + newRows.add(new Pair<>(actualRow.proj1.subList(1, actualRow.proj1.size()), actualRow.proj2)); } - for (Map.Entry entry : subst.entrySet()) { - substs.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).add(new Pair<>(i, entry.getValue())); + return computeCovering(newRows, rowTail, indices); + } + + List> consWithParams; + CoreExpression type = pattern.getBinding().getTypeExpr(); + if (type != null) type = type.unfoldType(); + if (type instanceof CoreSigmaExpression sigmaExpr) { + consWithParams = Collections.singletonList(new Pair<>(null, sigmaExpr.getParameters())); + } else if (type instanceof CoreClassCallExpression classCall) { + consWithParams = Collections.singletonList(new Pair<>(null, classCall.getClassFieldParameters())); + } else { + List constructorsWithArgs = type == null ? null : type.computeMatchedConstructorsWithDataArguments(); + if (constructorsWithArgs == null) { + // TODO: We should return null here, but to do this we need to substitute previous patterns in type + consWithParams = new ArrayList<>(constructors.size()); + for (CoreDefinition constructor : constructors) { + consWithParams.add(new Pair<>(constructor, constructor.getParameters())); + } + } else { + if (!hasVars) { + for (CoreExpression.ConstructorWithDataArguments cons : constructorsWithArgs) { + if (!constructors.contains(cons.getConstructor())) { + return false; + } + } + } + consWithParams = new ArrayList<>(constructorsWithArgs.size()); + for (CoreExpression.ConstructorWithDataArguments cons : constructorsWithArgs) { + consWithParams.add(new Pair<>(cons.getConstructor(), cons.getParameters())); + } } } - } - if (coveringIndex == -1) { - return null; - } - if (substs.isEmpty()) { - return Collections.singletonList(coveringIndex); - } + for (Pair pair : consWithParams) { + List newRow = new ArrayList<>(); + for (CoreParameter param = pair.proj2; param.hasNext(); param = param.getNext()) { + newRow.add(new ArendPattern(new MyBinding(param.getTypeExpr()), null, Collections.emptyList(), null, null)); + } + newRow.addAll(rowTail); + + newRows.clear(); + for (Pair, Integer> actualRow : actualRows) { + CorePattern actualPattern = actualRow.proj1.get(0); + if (actualPattern.getBinding() != null) { + List newActualRow = new ArrayList<>(); + for (CoreParameter param = pair.proj2; param.hasNext(); param = param.getNext()) { + newActualRow.add(new ArendPattern(new MyBinding(null), null, Collections.emptyList(), null, null)); + } + newActualRow.addAll(actualRow.proj1.subList(1, actualRow.proj1.size())); + newRows.add(new Pair<>(newActualRow, actualRow.proj2)); + } else if (actualPattern.getConstructor() == pair.proj1) { + List newActualRow = new ArrayList<>(); + newActualRow.addAll(actualPattern.getSubPatterns()); + newActualRow.addAll(actualRow.proj1.subList(1, actualRow.proj1.size())); + newRows.add(new Pair<>(newActualRow, actualRow.proj2)); + } + } - Set coveringIndices = new HashSet<>(); - for (Map.Entry>> entry : substs.entrySet()) { - if (!checkCoverage(entry.getValue(), entry.getKey().getTypeExpr(), coveringIndices)) { - return null; + if (!computeCovering(newRows, newRow, indices)) { + return false; + } } - } - List coveringList = new ArrayList<>(); - for (int i = 0; i < actualRows.size(); i++) { - if (coveringIndices.contains(i)) { - coveringList.add(i); + return true; + } else { + List newRow = new ArrayList<>(); + newRow.addAll(pattern.getSubPatterns()); + newRow.addAll(rowTail); + + for (Pair, Integer> actualRow : actualRows) { + CorePattern actualPattern = actualRow.proj1.get(0); + if (actualPattern.getBinding() != null) { + List newActualRow = new ArrayList<>(); + for (CorePattern ignored : pattern.getSubPatterns()) { + newActualRow.add(new ArendPattern(new MyBinding(null), null, Collections.emptyList(), null, null)); + } + newActualRow.addAll(actualRow.proj1.subList(1, actualRow.proj1.size())); + newRows.add(new Pair<>(newActualRow, actualRow.proj2)); + } else if (actualPattern.getConstructor() == pattern.getConstructor()) { + List newActualRow = new ArrayList<>(); + newActualRow.addAll(actualPattern.getSubPatterns()); + newActualRow.addAll(actualRow.proj1.subList(1, actualRow.proj1.size())); + newRows.add(new Pair<>(newActualRow, actualRow.proj2)); + } } + + return computeCovering(newRows, newRow, indices); } - return coveringList; } public static List subst(Collection patterns, Map map) {