Skip to content

Commit

Permalink
Revise SepCheck.checkType
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Jan 24, 2025
1 parent 0b9acb3 commit 718e03d
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 57 deletions.
147 changes: 120 additions & 27 deletions compiler/src/dotty/tools/dotc/cc/SepCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,28 @@ import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
import CaptureSet.{Refs, emptySet, HiddenSet}
import config.Printers.capt
import StdNames.nme
import util.{SimpleIdentitySet, EqHashMap}
import util.{SimpleIdentitySet, EqHashMap, SrcPos}

object SepChecker:

/** Enumerates kinds of captures encountered so far */
enum Captures:
case None
case Explicit // one or more explicitly declared captures
case Hidden // exacttly one hidden captures
case NeedsCheck // one hidden capture and one other capture (hidden or declared)

def add(that: Captures): Captures =
if this == None then that
else if that == None then this
else if this == Explicit && that == Explicit then Explicit
else NeedsCheck
end Captures

class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
import tpd.*
import checker.*
import SepChecker.*

/** The set of capabilities that are hidden by a polymorphic result type
* of some previous definition.
Expand Down Expand Up @@ -59,6 +76,10 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
recur(refs)
end hidden

private def containsHidden(using Context): Boolean =
refs.exists: ref =>
!hiddenByElem(ref, _ => emptySet).isEmpty

/** Deduct the footprint of `sym` and `sym*` from `refs` */
private def deductSym(sym: Symbol)(using Context) =
val ref = sym.termRef
Expand Down Expand Up @@ -183,6 +204,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
for (arg, idx) <- indexedArgs do
if arg.needsSepCheck then
val ac = formalCaptures(arg)
checkType(arg.formalType, arg.srcPos, NoSymbol, " the argument's adapted type")
val hiddenInArg = ac.hidden.footprint
//println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
Expand All @@ -209,32 +231,103 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
if !overlap.isEmpty then
sepUseError(tree, usedFootprint, overlap)

def checkType(tpt: Tree, sym: Symbol)(using Context) =
def checkSep(hidden: Refs, footprint: Refs, descr: String) =
val overlap = hidden.overlapWith(footprint)
if !overlap.isEmpty then
report.error(
em"""Separation failure: ${tpt.nuType} captures a root element hiding ${CaptureSet(hidden)}
|and also $descr ${CaptureSet(footprint)}.
|The two sets overlap at ${CaptureSet(overlap)}""",
tpt.srcPos)

val capts = CaptureSet.ofTypeDeeply(tpt.nuType,
union = (xs, ys) => ctx ?=> CaptureSet(xs.elems ++ ys.elems))
.elems
// Note: Can't use captures(...) or deepCaptureSet here because these would map
// e.g (Object^{<cap hiding x}, Object^{x}) to {<cap hiding x>} and we need
// {<cap hiding x>, x} instead.
val explicitFootprint = capts.footprint
var hiddenFootprint: Refs = emptySet
//println(i"checking tp ${tpt.tpe} with $capts fp $explicitFootprint")
for ref <- capts do
val hidden0 = hiddenByElem(ref, _.hidden).footprint
val hiddenByRef = hiddenByElem(ref, _.hidden).footprint.deductSym(sym)
if !hiddenByRef.isEmpty then
checkSep(hiddenByRef, explicitFootprint, "refers to")
checkSep(hiddenByRef, hiddenFootprint, "captures another root element hiding")
hiddenFootprint ++= hiddenByRef
def checkType(tpt: Tree, sym: Symbol)(using Context): Unit =
checkType(tpt.nuType, tpt.srcPos, sym, "")

/** Check that all parts of type `tpe` are separated.
* @param tpe the type to check
* @param pos position for error reporting
* @param sym if `tpe` is the (result-) type of a val or def, the symbol of
* this definition, otherwise NoSymbol. If `sym` exists we
* deduct its associated direct and reach capabilities everywhere
* from the capture sets we check.
* @param what a string describing what kind of type it is
*/
def checkType(tpe: Type, pos: SrcPos, sym: Symbol, what: String)(using Context): Unit =

def checkParts(parts: List[Type]): Unit =
var footprint: Refs = emptySet
var hiddenSet: Refs = emptySet
var checked = 0
for part <- parts do

/** Report an error if `current` and `next` overlap.
* @param current the footprint or hidden set seen so far
* @param next the footprint or hidden set of the next part
* @param mapRefs a function over the capture set elements of the next part
* that returns the references of the same kind as `current`
* (i.e. the part's footprint or hidden set)
* @param prevRel a verbal description of current ("references or "hides")
* @param nextRel a verbal descriiption of next
*/
def checkSep(current: Refs, next: Refs, mapRefs: Refs => Refs, prevRel: String, nextRel: String): Unit =
val globalOverlap = current.overlapWith(next)
if !globalOverlap.isEmpty then
val (prevStr, prevRefs, overlap) = parts.iterator.take(checked)
.map: prev =>
val prevRefs = mapRefs(prev.deepCaptureSet.elems).footprint.deductSym(sym)
(i", $prev , ", prevRefs, prevRefs.overlapWith(next))
.dropWhile(_._3.isEmpty)
.nextOption
.getOrElse(("", current, globalOverlap))
report.error(
em"""Separation failure in$what type $tpe.
|One part, $part , $nextRel ${CaptureSet(next)}.
|A previous part$prevStr $prevRel ${CaptureSet(prevRefs)}.
|The two sets overlap at ${CaptureSet(overlap)}.""",
pos)

val partRefs = part.deepCaptureSet.elems
val partFootprint = partRefs.footprint.deductSym(sym)
val partHidden = partRefs.hidden.footprint.deductSym(sym) -- partFootprint

checkSep(footprint, partHidden, identity, "references", "hides")
checkSep(hiddenSet, partHidden, _.hidden, "also hides", "hides")
checkSep(hiddenSet, partFootprint, _.hidden, "hides", "references")

footprint ++= partFootprint
hiddenSet ++= partHidden
checked += 1
end for
end checkParts

object traverse extends TypeAccumulator[Captures]:

/** A stack of part lists to check. We maintain this since immediately
* checking parts when traversing the type would check innermost to oputermost.
* But we want to check outermost parts first since this prioritized errors
* that are more obvious.
*/
var toCheck: List[List[Type]] = Nil

private val seen = util.HashSet[Symbol]()

def apply(c: Captures, t: Type) =
if variance < 0 then c
else
val t1 = t.dealias
t1 match
case t @ AppliedType(tycon, args) =>
val c1 = foldOver(Captures.None, t)
if c1 == Captures.NeedsCheck then
toCheck = (tycon :: args) :: toCheck
c.add(c1)
case t @ CapturingType(parent, cs) =>
val c1 = this(c, parent)
if cs.elems.containsHidden then c1.add(Captures.Hidden)
else if !cs.elems.isEmpty then c1.add(Captures.Explicit)
else c1
case t: TypeRef if t.symbol.isAbstractOrParamType =>
if seen.contains(t.symbol) then c
else
seen += t.symbol
apply(apply(c, t.prefix), t.info.bounds.hi)
case t =>
foldOver(c, t)

if !tpe.hasAnnotation(defn.UntrackedCapturesAnnot) then
traverse(Captures.None, tpe)
traverse.toCheck.foreach(checkParts)
end checkType

private def collectMethodTypes(tp: Type): List[TermLambda] = tp match
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,7 @@ class Definitions {
@tu lazy val UncheckedCapturesAnnot: ClassSymbol = requiredClass("scala.annotation.unchecked.uncheckedCaptures")
@tu lazy val UntrackedCapturesAnnot: ClassSymbol = requiredClass("scala.caps.untrackedCaptures")
@tu lazy val UseAnnot: ClassSymbol = requiredClass("scala.caps.use")
@tu lazy val ConsumeAnnot: ClassSymbol = requiredClass("scala.caps.consume")
@tu lazy val RefineOverrideAnnot: ClassSymbol = requiredClass("scala.caps.refineOverride")
@tu lazy val VolatileAnnot: ClassSymbol = requiredClass("scala.volatile")
@tu lazy val BeanGetterMetaAnnot: ClassSymbol = requiredClass("scala.annotation.meta.beanGetter")
Expand Down
2 changes: 2 additions & 0 deletions library/src/scala/caps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ import annotation.{experimental, compileTimeOnly, retainsCap}
*/
final class refineOverride extends annotation.StaticAnnotation

final class consume extends annotation.StaticAnnotation

object unsafe:

extension [T](x: T)
Expand Down
8 changes: 1 addition & 7 deletions tests/neg-custom-args/captures/capt-depfun.check
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt-depfun.scala:10:43 ----------------------------------
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error
| ^^^^^^^
| Found: Str^{} ->{ac, y, z} Str^{y, z}
| Required: Str^{y, z} ->{fresh} Str^{y, z}
|
| longer explanation available when compiling with `-explain`
-- Error: tests/neg-custom-args/captures/capt-depfun.scala:10:24 -------------------------------------------------------
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
| Separation failure: Str^{y, z} => Str^{y, z} captures a root element hiding {ac, y, z}
| and also refers to {y, z}.
| The two sets overlap at {y, z}
2 changes: 1 addition & 1 deletion tests/neg-custom-args/captures/capt-depfun.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ class Str
def f(y: Cap, z: Cap) =
def g(): C @retains(y, z) = ???
val ac: ((x: Cap) => Str @retains(x) => Str @retains(x)) = ???
val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error
45 changes: 27 additions & 18 deletions tests/neg-custom-args/captures/sepchecks2.check
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:8:10 ---------------------------------------------------------
8 | println(c) // error
| ^
| Separation failure: Illegal access to {c} which is hidden by the previous definition
| of value xs with type List[box () => Unit].
| This type hides capabilities {xs*, c}
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:11:33 --------------------------------------------------------
11 | foo((() => println(c)) :: Nil, c) // error
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:10:10 --------------------------------------------------------
10 | println(c) // error
| ^
| Separation failure: Illegal access to {c} which is hidden by the previous definition
| of value xs with type List[box () => Unit].
| This type hides capabilities {xs*, c}
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:13:33 --------------------------------------------------------
13 | foo((() => println(c)) :: Nil, c) // error
| ^
| Separation failure: argument of type (c : Object^)
| to method foo: (xs: List[box () => Unit], y: Object^): Nothing
Expand All @@ -19,15 +19,24 @@
| Hidden footprint of current argument : {c}
| Declared footprint of current argument: {}
| Undeclared overlap of footprints : {c}
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:12:10 --------------------------------------------------------
12 | val x1: (Object^, Object^) = (c, c) // error
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:14:10 --------------------------------------------------------
14 | val x1: (Object^, Object^) = (c, c) // error
| ^^^^^^^^^^^^^^^^^^
| Separation failure: (box Object^, box Object^) captures a root element hiding {c}
| and also captures another root element hiding {c}.
| The two sets overlap at {c}
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:13:10 --------------------------------------------------------
13 | val x2: (Object^, Object^{d}) = (d, d) // error
| Separation failure in type (box Object^, box Object^).
| One part, box Object^ , hides {c}.
| A previous part, box Object^ , also hides {c}.
| The two sets overlap at {c}.
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:15:10 --------------------------------------------------------
15 | val x2: (Object^, Object^{d}) = (d, d) // error
| ^^^^^^^^^^^^^^^^^^^^^
| Separation failure: (box Object^, box Object^{d}) captures a root element hiding {d}
| and also refers to {d}.
| The two sets overlap at {d}
| Separation failure in type (box Object^, box Object^{d}).
| One part, box Object^{d} , references {d}.
| A previous part, box Object^ , hides {d}.
| The two sets overlap at {d}.
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:27:6 ---------------------------------------------------------
27 | bar((c, c)) // error
| ^^^^^^
| Separation failure in the argument's adapted type type (box Object^, box Object^).
| One part, box Object^ , hides {c}.
| A previous part, box Object^ , also hides {c}.
| The two sets overlap at {c}.
11 changes: 10 additions & 1 deletion tests/neg-custom-args/captures/sepchecks2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ import language.future // sepchecks on

def foo(xs: List[() => Unit], y: Object^) = ???

def bar(x: (Object^, Object^)): Unit = ???

def Test(c: Object^) =
val xs: List[() => Unit] = (() => println(c)) :: Nil
println(c) // error

def Test2(c: Object^, d: Object^) =
def Test2(c: Object^, d: Object^): Unit =
foo((() => println(c)) :: Nil, c) // error
val x1: (Object^, Object^) = (c, c) // error
val x2: (Object^, Object^{d}) = (d, d) // error
Expand All @@ -17,3 +19,10 @@ def Test3(c: Object^, d: Object^) =

def Test4(c: Object^, d: Object^) =
val x: (Object^, Object^{c}) = (d, c) // ok

def Test5(c: Object^, d: Object^): Unit =
bar((c, d)) // ok

def Test6(c: Object^, d: Object^): Unit =
bar((c, c)) // error

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import scala.reflect.ClassTag
import annotation.unchecked.{uncheckedVariance, uncheckedCaptures}
import annotation.tailrec
import caps.cap
import language.`3.7` // sepchecks on
import caps.untrackedCaptures
import language.`3.8` // sepchecks on

/** A strawman architecture for new collections. It contains some
* example collection classes and methods with the intent to expose
Expand Down Expand Up @@ -68,11 +69,13 @@ object CollectionStrawMan5 {
/** Base trait for strict collections */
trait Buildable[+A] extends Iterable[A] {
protected def newBuilder: Builder[A, Repr] @uncheckedVariance
override def partition(p: A => Boolean): (Repr, Repr) = {
override def partition(p: A => Boolean): (Repr, Repr) @untrackedCaptures =
// Without untrackedCaptures this fails SepChecks.checkType.
// But this is probably an error in the hiding logic.
// TODO remove @untrackedCaptures and investigate
val l, r = newBuilder
iterator.foreach(x => (if (p(x)) l else r) += x)
(l.result, r.result)
}
// one might also override other transforms here to avoid generating
// iterators if it helps efficiency.
}
Expand Down

0 comments on commit 718e03d

Please sign in to comment.