diff --git a/engine/src/main/scala/org/virtuslab/inkuire/engine/impl/service/SubstitutionMatchService.scala b/engine/src/main/scala/org/virtuslab/inkuire/engine/impl/service/SubstitutionMatchService.scala index 3954be4..d75185a 100644 --- a/engine/src/main/scala/org/virtuslab/inkuire/engine/impl/service/SubstitutionMatchService.scala +++ b/engine/src/main/scala/org/virtuslab/inkuire/engine/impl/service/SubstitutionMatchService.scala @@ -21,7 +21,8 @@ class SubstitutionMatchService(val inkuireDb: InkuireDb) extends BaseMatchServic ) .flatMap { okTypes => State.get[TypingState].map { typingState => - if (okTypes) checkBindings(typingState.variableBindings) + if (okTypes) + checkBindings(typingState.variableBindings) else false } } @@ -43,7 +44,7 @@ class SubstitutionMatchService(val inkuireDb: InkuireDb) extends BaseMatchServic val actualSignatures = resolveResult.signatures.foldLeft(resolveResult.signatures) { case (acc, against) => acc.filter { sgn => - sgn == against || !(sgn.canSubstituteFor(against) && !against.canSubstituteFor(sgn)) + sgn == against || !sgn.canSubstituteFor(against) || against.canSubstituteFor(sgn) } } val actualSignaturesSize = actualSignatures.headOption.map(_.typesWithVariances.size) @@ -59,20 +60,26 @@ class SubstitutionMatchService(val inkuireDb: InkuireDb) extends BaseMatchServic } private def checkBindings(bindings: VariableBindings): Boolean = { - bindings.bindings.values.forall { types => + val bindingsCorrect = bindings.bindings.values.forall { types => types .sliding(2, 1) .forall { case (a: Type) :: (b: Type) :: Nil => - (ancestryGraph.getAllParentsITIDs(a).contains(b.itid.get) || - ancestryGraph.getAllParentsITIDs(b).contains(a.itid.get)) && - a.params.size == b.params.size && - a.params.map(_.typ).zip(b.params.map(_.typ)).forall { + val exactlyEqual = a == b + // Disable this check, since it can't handle transitive relations + // val relatedToEachOther = + // ancestryGraph.getAllParentsITIDs(a).contains(b.itid.get) || + // ancestryGraph.getAllParentsITIDs(b).contains(a.itid.get) + val sameSize = a.params.size == b.params.size + val sameTypes = a.params.map(_.typ).zip(b.params.map(_.typ)).forall { case (a: Type, b: Type) => a.itid == b.itid case _ => false } + exactlyEqual || (sameSize && sameTypes) case _ => true } - } && !TypeVariablesGraph(bindings).hasCyclicDependency + } + val noCycles = !TypeVariablesGraph(bindings).hasCyclicDependency + bindingsCorrect && noCycles } } diff --git a/engine/src/main/scala/org/virtuslab/inkuire/engine/impl/service/TopLevelMatchQualityService.scala b/engine/src/main/scala/org/virtuslab/inkuire/engine/impl/service/TopLevelMatchQualityService.scala index 311d1ba..63fcfa8 100644 --- a/engine/src/main/scala/org/virtuslab/inkuire/engine/impl/service/TopLevelMatchQualityService.scala +++ b/engine/src/main/scala/org/virtuslab/inkuire/engine/impl/service/TopLevelMatchQualityService.scala @@ -5,11 +5,18 @@ import org.virtuslab.inkuire.engine.impl.model._ class TopLevelMatchQualityService(val db: InkuireDb) extends BaseMatchQualityService with MatchingOps { - def matchQualityMetric(AnnotatedSignature: AnnotatedSignature, matching: Signature): Int = - variancesMatchQualityMetric( - AnnotatedSignature.signature.typesWithVariances, + def matchQualityMetric(annotatedSignature: AnnotatedSignature, matching: Signature): Int = { + val mq = variancesMatchQualityMetric( + annotatedSignature.signature.typesWithVariances, matching.typesWithVariances ) + val multi = byFunctionNameMetricMultiplier(annotatedSignature.name) + (mq * multi).toInt + } + + def byFunctionNameMetricMultiplier(name: String): Double = + if (List('$', '`').exists(name.contains(_))) 5.0 + else 1.0 + (name.length / 10.0) def variancesMatchQualityMetric(typVariances: Seq[Variance], suprVariances: Seq[Variance]): Int = typVariances.zip(suprVariances).map { case (v1, v2) => varianceMatchQualityMetric(v1, v2) }.sum diff --git a/engine/src/test/scala/org/virtuslab/inkuire/engine/BaseEndToEndEngineTest.scala b/engine/src/test/scala/org/virtuslab/inkuire/engine/BaseEndToEndEngineTest.scala index 0e2554d..b50e07f 100644 --- a/engine/src/test/scala/org/virtuslab/inkuire/engine/BaseEndToEndEngineTest.scala +++ b/engine/src/test/scala/org/virtuslab/inkuire/engine/BaseEndToEndEngineTest.scala @@ -23,7 +23,8 @@ trait BaseEndToEndEngineTest { */ def testFunctionFound(signature: String, funName: String)(implicit loc: munit.Location): Unit = { test(s"$funName : $signature") { - assert(testService().query(signature).exists(_.name == funName)) + val sigs = testService().query(signature) + assert(sigs.exists(_.name == funName)) } } } diff --git a/engine/src/test/scala/org/virtuslab/inkuire/engine/EndToEndEngineTest.scala b/engine/src/test/scala/org/virtuslab/inkuire/engine/EndToEndEngineTest.scala index 86b44f3..79c46ee 100644 --- a/engine/src/test/scala/org/virtuslab/inkuire/engine/EndToEndEngineTest.scala +++ b/engine/src/test/scala/org/virtuslab/inkuire/engine/EndToEndEngineTest.scala @@ -27,7 +27,7 @@ class EndToEndEngineTest extends munit.FunSuite with BaseEndToEndEngineTest { testFunctionFound( "IArray[Float] => (Float => Boolean) => Boolean", "IArray.forall" - ) // TODO(kπ) IMHO should be just `forall` (generation bug) + ) testFunctionFound("List[A] => B => (B => A => B) => B", "foldLeft") @@ -37,12 +37,15 @@ class EndToEndEngineTest extends munit.FunSuite with BaseEndToEndEngineTest { testFunctionFound("F[A] => B => ((B, A) => B) => B", "foldLeft") - // TODO(kπ) this is a bug in constraint checking - // testFunctionFound("List[A] => A => (A => A => A) => A", "foldLeft") + testFunctionFound("List[A] => A => (A => A => A) => A", "foldLeft") - // testFunctionFound("List[A] => A => ((A, A) => A) => A", "foldLeft") + testFunctionFound("List[A] => A => ((A, A) => A) => A", "foldLeft") - // testFunctionFound("F[A] => A => (A => A => A) => A", "foldLeft") + testFunctionFound("F[A] => A => (A => A => A) => A", "foldLeft") - // testFunctionFound("F[A] => A => ((A, A) => A) => A", "foldLeft") + testFunctionFound("F[A] => A => ((A, A) => A) => A", "foldLeft") + + testFunctionFound("List[A] => A => Boolean", "contains") + + testFunctionFound("Seq[A] => A => Boolean", "contains") }