Skip to content

Commit

Permalink
Fix a bug with pattern matching with idp
Browse files Browse the repository at this point in the history
  • Loading branch information
valis committed Mar 2, 2024
1 parent 13f65cb commit e9abd43
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 40 deletions.
30 changes: 12 additions & 18 deletions meta/src/main/java/org/arend/lib/meta/SimpCoeMeta.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ protected Spec(List<ConcreteLetClause> letClauses, ConcreteExpression concreteAr
this.isForward = isForward;
}

abstract ConcreteExpression make(ConcreteFactory factory, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight);
abstract ConcreteExpression make(ConcreteFactory factory, CoreExpression transportTypeArg, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight);

ConcreteExpression makeArg(ConcreteExpression arg, ConcreteFactory factory) {
return isForward ? factory.app(factory.ref(ext.inv.getRef()), true, arg) : arg;
Expand Down Expand Up @@ -95,7 +95,7 @@ protected ErrorSpec() {
}

@Override
ConcreteExpression make(ConcreteFactory factory, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
ConcreteExpression make(ConcreteFactory factory, CoreExpression transportTypeArg, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
return null;
}
}
Expand All @@ -118,7 +118,7 @@ private EqualitySpec(CoreParameter lamParam, CoreFunCallExpression equality, Exp
}

@Override
public ConcreteExpression make(ConcreteFactory factory, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
public ConcreteExpression make(ConcreteFactory factory, CoreExpression transportTypeArg, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
SimpCoeMeta meta = isForward ? ext.simpCoeFMeta : ext.simpCoeMeta;
return factory.app(factory.ref((isLeftConst ? meta.transport_path_pmap_right : meta.transport_path_pmap).getRef()), true, Arrays.asList(factory.core(leftFunc.computeTyped()), factory.core(rightFunc.computeTyped()), transportPathArg, factory.core(transportValueArg.computeTyped()), factory.core(eqRight.computeTyped()), arg == null ? argument : factory.core(arg)));
}
Expand All @@ -145,12 +145,12 @@ void excessiveArgsError(List<ConcreteArgument> excessiveArgs, ExpressionTypechec
}

@Override
public ConcreteExpression make(ConcreteFactory factory, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
public ConcreteExpression make(ConcreteFactory factory, CoreExpression transportTypeArg, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
List<ConcreteCaseArgument> caseArgs = new ArrayList<>(4);
List<ConcretePattern> casePatterns = new ArrayList<>(4);
ArendRef rightRef = factory.local("r");
List<ConcreteExpression> rightRefs = Collections.singletonList(factory.ref(rightRef));
caseArgs.add(factory.caseArg(transportRightArg, rightRef, null));
caseArgs.add(factory.caseArg(transportRightArg, rightRef, transportTypeArg == null ? null : factory.core(transportTypeArg.computeTyped())));

ArendRef pathRef = factory.local("q");
List<PathExpression> pathRefs = Collections.singletonList(new PathExpression(factory.ref(pathRef)));
Expand Down Expand Up @@ -208,7 +208,7 @@ private class PiArgsSpec extends Spec {
}

@Override
public ConcreteExpression make(ConcreteFactory factory, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
public ConcreteExpression make(ConcreteFactory factory, CoreExpression transportTypeArg, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
ConcreteExpression concreteValueArg = makeConcreteValueArg(transportValueArg, factory);
ArendRef jRef = factory.local("q");
ArendRef transportRef = factory.local("z");
Expand Down Expand Up @@ -252,7 +252,7 @@ ConcreteExpression proj(ConcreteExpression expr, int i, ConcreteFactory factory)
}

@Override
ConcreteExpression make(ConcreteFactory factory, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
ConcreteExpression make(ConcreteFactory factory, CoreExpression transportTypeArg, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
ArendRef jLamRef1 = factory.local("a''");
ArendRef jLamRef2 = factory.local("q");
ArendRef jPiRef = factory.local("s'");
Expand Down Expand Up @@ -310,7 +310,7 @@ private ConcreteExpression proj(ConcreteExpression expr, ConcreteFactory factory
}

@Override
public ConcreteExpression make(ConcreteFactory factory, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
public ConcreteExpression make(ConcreteFactory factory, CoreExpression transportTypeArg, ConcreteExpression transportLeftArg, ConcreteExpression transportRightArg, ConcreteExpression transportPathArg, CoreExpression transportValueArg, CoreExpression eqRight) {
ArendRef jRef = factory.local("q");
ConcreteExpression concreteValueArg = makeConcreteValueArg(transportValueArg, factory);
ConcreteExpression jTypeLeft = proj(factory.app(factory.ref(ext.transport.getRef()), true, Arrays.asList(factory.core(transportLam.computeTyped()), factory.ref(jRef), concreteValueArg)), factory);
Expand All @@ -328,12 +328,7 @@ public ConcreteExpression make(ConcreteFactory factory, ConcreteExpression trans

private Spec getSpec(CoreExpression arg, ExpressionTypechecker typechecker, ConcreteSourceNode marker, ConcreteFactory factory, List<CoreExpression> args, CoreClassField field, int proj, ConcreteExpression concreteArg, TypedExpression simpCoeArg, List<ConcreteArgument> excessiveArgs, boolean isForward) {
arg = arg.normalize(NormalizationMode.WHNF);
if (!(arg instanceof CoreLamExpression)) {
return null;
}

CoreLamExpression lam = (CoreLamExpression) arg;
if (lam.getParameters().getNext().hasNext()) {
if (!(arg instanceof CoreLamExpression lam) || lam.getParameters().getNext().hasNext()) {
return null;
}

Expand Down Expand Up @@ -400,8 +395,7 @@ private Spec getSpec(CoreExpression arg, ExpressionTypechecker typechecker, Conc
bindings.add(param.getBinding());
}

if (classFields != null && concreteArg instanceof ConcreteClassExtExpression) {
ConcreteClassExtExpression classExt = (ConcreteClassExtExpression) concreteArg;
if (classFields != null && concreteArg instanceof ConcreteClassExtExpression classExt) {
ConcreteExpression baseExpr = classExt.getBaseClassExpression();
CoreDefinition def = baseExpr instanceof ConcreteReferenceExpression ? ext.definitionProvider.getCoreDefinition(((ConcreteReferenceExpression) baseExpr).getReferent()) : null;
CoreClassDefinition classDef = ((CoreClassCallExpression) body).getDefinition();
Expand Down Expand Up @@ -525,7 +519,7 @@ private Spec getSpec(CoreExpression arg, ExpressionTypechecker typechecker, Conc
if (spec instanceof ErrorSpec) return null;
if (spec != null) {
spec.excessiveArgsError(excessiveArgs, typechecker);
return typechecker.typecheck(spec.make(factory, factory.core(transportArgs.get(2).computeTyped()), factory.core(transportArgs.get(3).computeTyped()), factory.core(transportArgs.get(4).computeTyped()), transportArgs.get(5), equality.getDefCallArguments().get(2)), contextData.getExpectedType());
return typechecker.typecheck(spec.make(factory, transportArgs.get(0), factory.core(transportArgs.get(2).computeTyped()), factory.core(transportArgs.get(3).computeTyped()), factory.core(transportArgs.get(4).computeTyped()), transportArgs.get(5), equality.getDefCallArguments().get(2)), contextData.getExpectedType());
}
} else {
if (leftExpr instanceof CoreFunCallExpression && ((CoreFunCallExpression) leftExpr).getDefinition() == ext.prelude.getCoerce()) {
Expand All @@ -537,7 +531,7 @@ private Spec getSpec(CoreExpression arg, ExpressionTypechecker typechecker, Conc
if (spec != null) {
spec.excessiveArgsError(excessiveArgs, typechecker);
ArendRef iRef = factory.local("i");
return typechecker.typecheck(spec.make(factory, factory.ref(ext.prelude.getLeft().getRef()), factory.ref(ext.prelude.getRight().getRef()), factory.app(factory.ref(ext.prelude.getPathConRef()), true, Collections.singletonList(factory.lam(Collections.singletonList(factory.param(iRef)), factory.ref(iRef)))), coeArgs.get(1), equality.getDefCallArguments().get(2)), contextData.getExpectedType());
return typechecker.typecheck(spec.make(factory, null, factory.ref(ext.prelude.getLeft().getRef()), factory.ref(ext.prelude.getRight().getRef()), factory.app(factory.ref(ext.prelude.getPathConRef()), true, Collections.singletonList(factory.lam(Collections.singletonList(factory.param(iRef)), factory.ref(iRef)))), coeArgs.get(1), equality.getDefCallArguments().get(2)), contextData.getExpectedType());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.arend.ext.core.definition.CoreClassField;
import org.arend.ext.core.definition.CoreConstructor;
import org.arend.ext.core.expr.*;
import org.arend.ext.core.ops.CMP;
import org.arend.ext.core.ops.NormalizationMode;
import org.arend.ext.error.ErrorReporter;
import org.arend.ext.reference.ArendRef;
Expand Down Expand Up @@ -546,7 +547,7 @@ private boolean typeToRule(TypedExpression typed, CoreBinding binding, boolean a
CoreExpression type = binding != null ? binding.getTypeExpr() : typed.getType();
CoreFunCallExpression eq = Utils.toEquality(type, null, null);
if (eq == null) {
CoreExpression typeNorm = type.normalize(NormalizationMode.WHNF).getUnderlyingExpression();
CoreExpression typeNorm = type.normalize(NormalizationMode.WHNF);
if (!(typeNorm instanceof CoreClassCallExpression classCall)) {
return false;
}
Expand All @@ -560,6 +561,10 @@ private boolean typeToRule(TypedExpression typed, CoreBinding binding, boolean a
(!isRDiv || typeToRule(typechecker.typecheck(factory.app(factory.ref(meta.rdiv.getPersonalFields().get(0).getRef()), false, args), null), null, true, rules));
}

if (!typechecker.compare(eq.getDefCallArguments().get(0), getValuesType(), CMP.EQ, refExpr, false, true, false)) {
return false;
}

List<Integer> lhs = new ArrayList<>();
List<Integer> rhs = new ArrayList<>();
ConcreteExpression lhsTerm = computeTerm(eq.getDefCallArguments().get(1), lhs);
Expand Down
4 changes: 2 additions & 2 deletions src/Algebra/Domain/Bezout.ard
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@
: ∃ (r : E) (\Pi (j : Fin n) -> LDiv (d j) (r - a j)) \elim n
| 0 => inP (0, \case __)
| suc n =>
\have | (inP (q,f)) => chinese (tail a) (tail d) \lam i j i/=j => p (suc i) (suc j) (fsuc/= i/=j)
| (inP (r,p1,p2)) => chinese2 (a 0) q (d 0) (BigProd (tail d)) $ IsCoprime_BigProd-right $ later \lam j => p 0 (suc j) (\case __)
\have | (inP (q,f)) => chinese (taild a) (taild d) \lam i j i/=j => p (suc i) (suc j) (fsuc/= i/=j)
| (inP (r,p1,p2)) => chinese2 (a 0) q (d 0) (BigProd (taild d)) $ IsCoprime_BigProd-right $ later \lam j => p 0 (suc j) (\case __)
\in inP (r, \case \elim __ \with {
| 0 => p1
| suc j => transport (LDiv _) equation (LDiv_+ (LDiv.trans (LDiv_BigProd j) p2) (f j))
Expand Down
2 changes: 1 addition & 1 deletion src/Algebra/Linear/Solver.ard
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
(h : DArray (\lam j => interpretEq (p j))) : interpret t1 < interpret t2
=> \case <_*-cancel-left (natCoef (c.1 0)) (interpret t1) (interpret t2) (\case or.toOr c.4 \with {
| byLeft q => <_+-invert-right (certToLeq (toContr p t1 t2) c) $ solveContrProblem.aux p (\lam j => c.1 (suc j)) (hasNegative-correct p _ q) h
| byRight q => <_+-invert-left (certToLess (toContr p t1 t2) c q) (solveContrProblem.aux_<= p (tail c.1) h)
| byRight q => <_+-invert-left (certToLess (toContr p t1 t2) c q) (solveContrProblem.aux_<= p (taild c.1) h)
}) \with {
| byLeft r => r.2
| byRight r => absurd $ <-irreflexive $ <-transitive-left r.1 natCoef>=0
Expand Down
2 changes: 1 addition & 1 deletion src/Algebra/Monoid.ard
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@
\lemma BigSum_EPerm {l l' : Array E} (e : EPerm l l') : BigSum l = BigSum l' \elim l, l', e
| nil, nil, eperm-nil => idp
| x :: l1, y :: l2, eperm-:: p e => pmap2 (+) p (BigSum_EPerm e)
| x :: (x' :: l1), y :: (y' :: l2), eperm-swap p q idp => equation
| x :: (x' :: l1), y :: (y' :: l2), eperm-swap p q r => equation {using (pmap BigSum r)}
| l, l', eperm-trans e1 e2 => BigSum_EPerm e1 *> BigSum_EPerm e2

\lemma BigSum_Perm {n : Nat} {l l' : Array E n} (p : Perm l l') : BigSum l = BigSum l'
Expand Down
36 changes: 21 additions & 15 deletions src/Data/Array.ard
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,17 @@
\func arrayExt {A : \Type} {n : Nat} {l l' : Array A n} (p : \Pi (j : Fin n) -> l j = l' j) : l = l'
=> path (\lam i => \new Array A n (\lam j => p j i))

\func array-unext {A : \Type} {n : Nat} {l l' : Array A n} (p : l = {Array A} l') (j : Fin n) : l j = l' j \elim p
| idp => idp
\func tail {A : \Type} (l : Array A) : Array A \elim l
| nil => nil
| _ :: a => a

\func tail {n : Nat} {A : Fin (suc n) -> \Type} (a : DArray A) : DArray (\lam j => A (suc j)) \elim a
\func taild {n : Nat} {A : Fin (suc n) -> \Type} (a : DArray A) : DArray (\lam j => A (suc j)) \elim a
| _ :: a => a

\func array-unext {A : \Type} {n : Nat} {l l' : Array A n} (p : l = {Array A} l') (j : Fin n) : l j = l' j \elim n, l, l', j
| suc n, a :: l, a1 :: l', 0 => unhead p
| suc n, a :: l, a1 :: l', suc j => array-unext (pmap tail p) j

\func len=0 {A : \Type} {l : Array A} (p : l.len = 0) : l = nil \elim l
| nil => idp

Expand Down Expand Up @@ -851,7 +856,7 @@
\func EPerm_map {A B : \Type} (f : A -> B) {l l' : Array A} (e : EPerm l l') : EPerm (map f l) (map f l') \elim l, l', e
| nil, nil, eperm-nil => eperm-nil
| x :: l, y :: l', eperm-:: p e => eperm-:: (pmap f p) (EPerm_map f e)
| x :: (x' :: l), y :: (y' :: l'), eperm-swap idp idp idp => eperm-swap idp idp idp
| x :: (x' :: l), y :: (y' :: l'), eperm-swap idp idp q => eperm-swap idp idp $ pmap (map f) q
| l, l', eperm-trans e1 e2 => eperm-trans (EPerm_map f e1) (EPerm_map f e2)
\where
\func conv {A B : \Set} (f : A -> B) (inj : isInj f) {l l' : Array A} (e : EPerm (map f l) (map f l')) : EPerm l l'
Expand All @@ -866,7 +871,7 @@
| yes _ => eperm-:: idp (EPerm_keep q)
| no _ => EPerm_keep q
}
| x :: (x' :: l), _ :: (_ :: _), eperm-swap idp idp idp => mcases \with {
| x :: (x' :: l), _ :: (_ :: l'), eperm-swap idp idp r => rewrite r $ mcases \with {
| yes p, yes q => rewrite (dec_yes_reduce q, dec_yes_reduce p) (eperm-swap idp idp idp)
| yes p, no q => rewrite (dec_yes_reduce p, dec_no_reduce q) eperm-refl
| no p, yes q => rewrite (dec_no_reduce p, dec_yes_reduce q) eperm-refl
Expand All @@ -880,7 +885,7 @@
| yes p => EPerm_remove q
| no n => eperm-:: idp (EPerm_remove q)
}
| x :: (x' :: l), _ :: (_ :: _), eperm-swap idp idp idp => mcases \with {
| x :: (x' :: l), _ :: (_ :: l'), eperm-swap idp idp r => rewrite r $ mcases \with {
| yes p, yes q => rewrite (dec_yes_reduce p, dec_yes_reduce q) eperm-refl
| yes p, no q => rewrite (dec_yes_reduce p, dec_no_reduce q) eperm-refl
| no p, yes q => rewrite (dec_no_reduce p, dec_yes_reduce q) eperm-refl
Expand All @@ -891,7 +896,7 @@
\func EPerm_nub {A : DecSet} {l l' : Array A} (p : EPerm l l') : EPerm (nub l) (nub l') \elim l, l', p
| nil, nil, eperm-nil => eperm-nil
| x :: l, _ :: l', eperm-:: idp q => eperm-:: idp $ EPerm_remove (EPerm_nub q)
| x :: (x' :: l), _ :: (_ :: _), eperm-swap idp idp idp => mcases contradiction \with {
| x :: (x' :: l), _ :: (_ :: l'), eperm-swap idp idp r => rewrite r $ mcases contradiction \with {
| yes p, yes _ => eperm-:: p (eperm-= remove-swap)
| no q, no _ => eperm-swap idp idp remove-swap
}
Expand All @@ -910,7 +915,7 @@
| false => EPerm_filter f e
| true => eperm-:: idp (EPerm_filter f e)
}
| x :: (x' :: l), _ :: (_ :: l'), eperm-swap idp idp idp => cases (f x, f x') eperm-refl \with {
| x :: (x' :: l), _ :: (_ :: l'), eperm-swap idp idp r => rewrite r $ cases (f x, f x') eperm-refl \with {
| true, true => eperm-swap idp idp idp
}
| l, l', eperm-trans e1 e2 => eperm-trans (EPerm_filter f e1) (EPerm_filter f e2)
Expand Down Expand Up @@ -940,11 +945,12 @@
| 0 => inv p
| suc j => f j
})
| x :: (x' :: (l : Array)), y :: (y' :: _), eperm-swap p1 p2 idp => (transposition1 (0 : Fin (suc l.len)), \case \elim __ \with {
| 0 => rewrite transposition1.transposition1-left (inv p2)
| 1 => rewrite transposition1.transposition1-right (inv p1)
| suc (suc j) => rewrite (transposition1.transposition1_/= (later $ \case __) (later $ \case __)) idp
})
| x :: (x' :: (l : Array)), y :: (y' :: l'), eperm-swap p1 p2 q => transport (\lam (l'' : Array A) => \Sigma (e : Equiv {Fin (suc (suc l''.len))} {Fin (suc (suc l.len))}) (\Pi (j : Fin (suc (suc l''.len))) -> (y :: y' :: l'') j = (x :: x' :: l) (e j))) q
(transposition1 (0 : Fin (suc l.len)), \case \elim __ \with {
| 0 => rewrite transposition1.transposition1-left (inv p2)
| 1 => rewrite transposition1.transposition1-right (inv p1)
| suc (suc j) => rewrite (transposition1.transposition1_/= (later $ \case __) (later $ \case __)) idp
})
| l, l', eperm-trans p1 p2 =>
\have | (e1,f1) => eperm_equiv p1
| (e2,f2) => eperm_equiv p2
Expand All @@ -957,7 +963,7 @@
\lemma EPerm_len {A : \Type} {l l' : Array A} (e : EPerm l l') : l.len = l'.len \elim l, l', e
| nil, nil, eperm-nil => idp
| x :: l, y :: l', eperm-:: p e => pmap suc (EPerm_len e)
| x :: (x' :: l), y :: (y' :: l'), eperm-swap _ _ idp => idp
| x :: (x' :: l), y :: (y' :: l'), eperm-swap _ _ q => pmap (\lam x => suc (suc (DArray.len {x}))) q
| l, l', eperm-trans e1 e2 => EPerm_len e1 *> EPerm_len e2

\func EPermDec {A : DecSet} {l l' : Array A} : Or (EPerm l l') (Not (EPerm l l'))
Expand All @@ -972,7 +978,7 @@
| yes p => pmap suc (EPerm_count e a)
| no n => EPerm_count e a
}
| x :: (x' :: l), _ :: (_ :: l'), eperm-swap idp idp idp => mcases \with {
| x :: (x' :: l), _ :: (_ :: l'), eperm-swap idp idp r => rewrite r $ mcases \with {
| yes p, yes p' => rewrite (decideEq=_reduce p, decideEq=_reduce p') idp
| yes p, no p' => rewrite (decideEq=_reduce p, decideEq/=_reduce p') idp
| no p, yes p' => rewrite (decideEq/=_reduce p, decideEq=_reduce p') idp
Expand Down
2 changes: 1 addition & 1 deletion src/Logic.ard
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@
\lemma aux2 {n : Nat} {l : Array \Prop (suc n)} (p : DArray {n} (\lam i => l i -> l (suc i))) (q : l n -> l 0) : TFAE l \elim n
| 0 => \lam (0) (0) a => a
| suc n =>
\have t => aux2 {n} {tail l} (\lam i => p (suc i)) \lam a => p 0 $ q $ transport l (nat_fin_= $ pmap suc (mod_< id<suc) *> inv (mod_< id<suc)) a
\have t => aux2 {n} {taild l} (\lam i => p (suc i)) \lam a => p 0 $ q $ transport l (nat_fin_= $ pmap suc (mod_< id<suc) *> inv (mod_< id<suc)) a
\in \case \elim __, \elim __ \with {
| 0, j => aux p (later idp)
| suc i, 0 => \lam a => q $ aux p (later $ pmap suc (<=_exists $ <_suc_<= $ fin_< i) *> inv (mod_< id<suc)) a
Expand Down
Loading

0 comments on commit e9abd43

Please sign in to comment.