diff --git a/transforge/bag.py b/transforge/bag.py index 8ed2d4e..b4230e9 100644 --- a/transforge/bag.py +++ b/transforge/bag.py @@ -5,8 +5,13 @@ from collections.abc import MutableSet class TypeUnion(MutableSet[Type]): + """A disjunction of types. Can be either the most specific types (ie for a + hierarchy with B, C subtypes of A, the disjunction B OR C OR A will be + equivalent to B OR C) or the most general types (ie the disjunction B OR C + OR A will be A).""" - def __init__(self, xs: Iterable[Type] = ()) -> None: + def __init__(self, xs: Iterable[Type] = (), specific: bool = True) -> None: + self.specific = specific self.data: set[Type] = set() for x in xs: self.add(x) @@ -34,10 +39,16 @@ def is_subtype(self, other: Type | "TypeUnion") -> bool: def add(self, new: Type) -> None: to_remove = set() for t in self.data: - if new.is_subtype(t): - to_remove.add(t) - elif t.is_subtype(new): - return + if self.specific: + if new.is_subtype(t): + to_remove.add(t) + elif t.is_subtype(new): + return + else: + if new.is_subtype(t): + return + elif t.is_subtype(new): + to_remove.add(t) self.data -= to_remove self.data.add(new) diff --git a/transforge/query.py b/transforge/query.py index 3c480d9..4d16e26 100644 --- a/transforge/query.py +++ b/transforge/query.py @@ -277,11 +277,23 @@ def io(self) -> Iterator[str]: else: path = "?workflow :output/:type/rdfs:subClassOf*" - yield from union(path, self.graph.objects(output, TF.type)) + # TODO general method for this + type_set = TypeUnion((self.lang.parse_type_uri(t) + for t in self.graph.objects(output, TF.type) + if isinstance(t, URIRef)), + specific=False) + + yield from union(path, (self.lang.uri(t) for t in type_set)) for input in self.graph.objects(self.root, TF.input): + + type_set = TypeUnion((self.lang.parse_type_uri(t) + for t in self.graph.objects(input, TF.type) + if isinstance(t, URIRef)), + specific=False) + yield from union("?workflow :input/:type/rdfs:subClassOf*", - self.graph.objects(input, TF.type)) + (self.lang.uri(t) for t in type_set)) def chronology(self) -> Iterator[str]: """ @@ -338,6 +350,10 @@ def chronology(self) -> Iterator[str]: type_set = TypeUnion(self.lang.parse_type_uri(t) for t in self.type.get(current, ()) if isinstance(t, URIRef)) + type_set = TypeUnion((self.lang.parse_type_uri(t) + for t in self.type.get(current, ()) if isinstance(t, URIRef)), + specific=False) + yield from union(f"{current.n3()} :via", self.operator.get(current, ())) yield from union(f"{current.n3()} :type/rdfs:subClassOf*",