diff --git a/transforge/bag.py b/transforge/bag.py new file mode 100644 index 0000000..4b4fdcb --- /dev/null +++ b/transforge/bag.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from transforge.type import Type +from typing import Iterator, Iterable +from collections.abc import MutableSet + +class TypeUnion(MutableSet[Type]): + + def __init__(self, xs: Iterable[Type] = ()) -> None: + self.data: set[Type] = set() + for x in xs: + self.add(x) + + def __contains__(self, value: object) -> bool: + return value in self.data + + def __iter__(self) -> Iterator[Type]: + return iter(self.data) + + def __len__(self) -> int: + return len(self.data) + + def __repr__(self) -> str: + return repr(self.data) + + def is_subtype(self, other: Type | "TypeUnion") -> bool: + if isinstance(other, Type): + return bool(self) and all(x.is_subtype(other) for x in self.data) + else: + assert isinstance(other, TypeUnion) + return bool(self) and all(all(x.is_subtype(y) for x in self.data) + for y in other.data) + + def add(self, new: Type) -> None: + to_remove = set() + for t in self.data: + if new.is_subtype(t): + to_remove.add(t) + elif t.is_subtype(new): + return + self.data -= to_remove + self.data.add(new) + + def discard(self, item: Type) -> None: + self.data.discard(item) + + +# TODO: sort on type depth, so that the most specific types are checked first +class Bag(object): + """A bag of types is a conjunction of disjunctions (cq intersection of + unions) of concrete types. A bag is unordered and contains no duplicates, + nor even supertypes of types that are already in the bag. That is, if you + have a type A with subtypes B and C, then adding A, and either B or C will + result in a bag containing only either B or C.""" + + def __init__(self) -> None: + self.content: list[TypeUnion] = [] + + def add(self, *new_types: Type): + """Add a type to the bag. If multiple types are given, the union of + these types is added.""" + + # Any disjuncts that are covered already by other types in the bag can + # be dropped + new = TypeUnion(nt for nt in new_types + if any(t.is_subtype(nt) for t in self.content)) + + if not new: + return + + # Conversely, any types in the bag that are obsoleted by the new type + # can also be removed + self.content = [c for c in self.content + if not new.is_subtype(c)] + + self.content.append(new) diff --git a/transforge/query.py b/transforge/query.py index 7a95d0c..d33e8ce 100644 --- a/transforge/query.py +++ b/transforge/query.py @@ -19,6 +19,7 @@ from transforge.lang import Language from transforge.graph import ( TransformationGraph, CyclicTransformationGraphError) +from transforge.bag import Bag def union(prefix: str, subjects: Iterable[Node]) -> Iterator[str]: @@ -232,30 +233,21 @@ def sparql(*elems: str | Iterable[str]) -> str: return result def types(self) -> Iterator[str]: - """ - Conditions for matching on the bag of types used in a query. - """ - - # Only the types that *definitely* occur - types: set[URIRef] = set() - for type_choice in self.type.values(): - if len(type_choice) == 1: - assert isinstance(type_choice[0], URIRef) - types.add(type_choice[0]) - - for type in types: - yield f"?workflow :contains/rdfs:subClassOf* {type.n3()}." - # yield f"{next(self.generator).n3()} {next(self.generator).n3()} - # {type.n3()}." - - # Also include union types. TODO this is temporary until #79 is - # resolved; see also: - # https://github.com/quangis/transformation-algebra/issues/77#issuecomment-1215064807 - for type_choice in self.type.values(): - if len([t for t in type_choice if t not in types]) >= 2: - yield from union( - f"{next(self.generator).n3()} {next(self.generator).n3()}", - type_choice) + """Conditions for matching on the bag of types used in a query.""" + # See also issues 79 and 77 + + bag = Bag() + for tus in self.type.values(): + bag.add(*(self.lang.parse_type_uri(tu) for tu in tus)) + + for ts in bag.content: + if len(ts) == 1: + yield ( + f"?workflow :contains/rdfs:subClassOf* " + f"{self.lang.uri(next(iter(ts))).n3()}.") + elif len(ts) > 1: + yield from union("?workflow :contains/rdfs:subClassOf* ", + (self.lang.uri(t) for t in ts)) def operators(self) -> Iterator[str]: """