Skip to content

Commit

Permalink
Method to iterate sub/supertypes of any type.
Browse files Browse the repository at this point in the history
This will simplify the generation of type taxonomies, which, in turn,
needs to be modified so that it includes top and bottom types as
described in issue #94.
  • Loading branch information
nsbgn committed Jun 16, 2022
1 parent 668e5bc commit 42761b0
Showing 1 changed file with 58 additions and 1 deletion.
59 changes: 58 additions & 1 deletion transformation_algebra/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum, auto
from abc import ABC, abstractmethod
from functools import reduce
from itertools import chain
from itertools import chain, count
from inspect import signature
from typing import Optional, Iterator, Iterable, Callable

Expand Down Expand Up @@ -252,6 +252,20 @@ class TypeInstance(Type):
def normalized(self) -> bool:
return not (isinstance(self, TypeVariable) and self.unification)

@abstractmethod
def subtypes(self, recursive: bool = False) -> Iterator[TypeInstance]:
"""
Strict subtypes of this type.
"""
return NotImplemented

@abstractmethod
def supertypes(self, recursive: bool = False) -> Iterator[TypeInstance]:
"""
Strict supertypes of this type.
"""
return NotImplemented

def __str__(self):
return self.text(with_constraints=True)

Expand Down Expand Up @@ -623,6 +637,43 @@ def __hash__(self) -> int:
def basic(self) -> bool:
return self._operator.arity == 0

def subtypes(self, recursive: bool = False) -> Iterator[TypeInstance]:
if recursive:
raise NotImplementedError
op = self._operator
if op is Top:
raise NotImplementedError
elif op is Bottom:
pass
elif op.arity == 0:
if op.children:
yield from (c() for c in op.children)
else:
yield Bottom()
else:
for i, v, p in zip(count(), op.variance, self.params):
for q in (p.subtypes() if Variance.CO else p.supertypes()):
yield op(*(q if i == j else p
for j, p in enumerate(self.params)))

def supertypes(self, recursive: bool = False) -> Iterator[TypeInstance]:
if recursive:
raise NotImplementedError
op = self._operator
if op is Bottom:
raise NotImplementedError
elif op is Top:
pass
elif op.arity == 0:
if op.parent:
yield op.parent()
else:
yield Top()
else:
for i, v, p in zip(count(), op.variance, self.params):
for q in (p.supertypes() if Variance.CO else p.subtypes()):
yield op(*(q if i == j else p
for j, p in enumerate(self.params)))

class TypeVariable(TypeInstance):
"""
Expand All @@ -637,6 +688,12 @@ def __init__(self, wildcard: bool = False, origin=None):
self._constraints: set[Constraint] = set()
self.origin = origin

def subtypes(self, recursive: bool = False) -> Iterator[TypeInstance]:
raise RuntimeError

def supertypes(self, recursive: bool = False) -> Iterator[TypeInstance]:
raise RuntimeError

def check_constraints(self) -> None:
for c in list(self._constraints):
if c.fulfill():
Expand Down

0 comments on commit 42761b0

Please sign in to comment.