Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add typings #7

Merged
merged 2 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pgproto.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import codecs

class CodecContext:
def get_text_codec(self) -> codecs.CodecInfo: ...
def is_encoding_utf8(self) -> bool: ...

class ReadBuffer: ...
class WriteBuffer: ...
176 changes: 117 additions & 59 deletions types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,29 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import builtins
import sys
import typing
import typing_extensions


__all__ = (
'BitString', 'Point', 'Path', 'Polygon',
'Box', 'Line', 'LineSegment', 'Circle',
)

_BS = typing.TypeVar('_BS', bound='BitString')
bryanforbes marked this conversation as resolved.
Show resolved Hide resolved
_P = typing.TypeVar('_P', bound='Point')
_BitOrder = typing_extensions.Literal['big', 'little']


class BitString:
"""Immutable representation of PostgreSQL `bit` and `varbit` types."""

__slots__ = '_bytes', '_bitlength'

def __init__(self, bitstring=None):
def __init__(self,
bitstring: typing.Optional[builtins.bytes] = None) -> None:
if not bitstring:
self._bytes = bytes()
self._bitlength = 0
Expand All @@ -28,7 +39,7 @@ def __init__(self, bitstring=None):
bit_pos = 0

for i, bit in enumerate(bitstring):
if bit == ' ':
if bit == ' ': # type: ignore
continue
bit = int(bit)
if bit != 0 and bit != 1:
Expand All @@ -53,14 +64,15 @@ def __init__(self, bitstring=None):
self._bitlength = bitlen

@classmethod
def frombytes(cls, bytes_=None, bitlength=None):
if bitlength is None and bytes_ is None:
bytes_ = bytes()
bitlength = 0

elif bitlength is None:
bitlength = len(bytes_) * 8

def frombytes(cls: typing.Type[_BS],
bytes_: typing.Optional[builtins.bytes] = None,
bitlength: typing.Optional[int] = None) -> _BS:
if bitlength is None:
if bytes_ is None:
bytes_ = bytes()
bitlength = 0
else:
bitlength = len(bytes_) * 8
else:
if bytes_ is None:
bytes_ = bytes(bitlength // 8 + 1)
Expand All @@ -87,10 +99,10 @@ def frombytes(cls, bytes_=None, bitlength=None):
return result

@property
def bytes(self):
def bytes(self) -> builtins.bytes:
return self._bytes

def as_string(self):
def as_string(self) -> str:
s = ''

for i in range(self._bitlength):
Expand All @@ -100,7 +112,8 @@ def as_string(self):

return s.strip()

def to_int(self, bitorder='big', *, signed=False):
def to_int(self, bitorder: _BitOrder = 'big',
*, signed: bool = False) -> int:
"""Interpret the BitString as a Python int.
Acts similarly to int.from_bytes.

Expand Down Expand Up @@ -135,7 +148,8 @@ def to_int(self, bitorder='big', *, signed=False):
return x

@classmethod
def from_int(cls, x, length, bitorder='big', *, signed=False):
def from_int(cls: typing.Type[_BS], x: int, length: int,
bitorder: _BitOrder = 'big', *, signed: bool = False) -> _BS:
"""Represent the Python int x as a BitString.
Acts similarly to int.to_bytes.

Expand Down Expand Up @@ -187,27 +201,27 @@ def from_int(cls, x, length, bitorder='big', *, signed=False):
bytes_ = x.to_bytes((length + 7) // 8, byteorder='big')
return cls.frombytes(bytes_, length)

def __repr__(self):
def __repr__(self) -> str:
return '<BitString {}>'.format(self.as_string())

__str__ = __repr__

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, BitString):
return NotImplemented

return (self._bytes == other._bytes and
self._bitlength == other._bitlength)

def __hash__(self):
def __hash__(self) -> int:
return hash((self._bytes, self._bitlength))

def _getitem(self, i):
def _getitem(self, i: int) -> int:
byte = self._bytes[i // 8]
shift = 8 - i % 8 - 1
return (byte >> shift) & 0x1

def __getitem__(self, i):
def __getitem__(self, i: int) -> int:
if isinstance(i, slice):
raise NotImplementedError('BitString does not support slices')

Expand All @@ -216,100 +230,134 @@ def __getitem__(self, i):

return self._getitem(i)

def __len__(self):
def __len__(self) -> int:
return self._bitlength


class Point(tuple):
if typing.TYPE_CHECKING or sys.version_info >= (3, 6):
_PointBase = typing.Tuple[float, float]
_BoxBase = typing.Tuple['Point', 'Point']
_LineBase = typing.Tuple[float, float, float]
_LineSegmentBase = typing.Tuple['Point', 'Point']
_CircleBase = typing.Tuple['Point', float]
else:
# In Python 3.5, subclassing from typing.Tuple does not make the
# subclass act like a tuple in certain situations (like starred
# expressions)
_PointBase = tuple
_BoxBase = tuple
_LineBase = tuple
_LineSegmentBase = tuple
_CircleBase = tuple


class Point(_PointBase):
"""Immutable representation of PostgreSQL `point` type."""

__slots__ = ()

def __new__(cls, x, y):
return super().__new__(cls, (float(x), float(y)))

def __repr__(self):
def __new__(cls,
x: typing.Union[typing.SupportsFloat,
'builtins._SupportsIndex',
typing.Text,
builtins.bytes,
builtins.bytearray],
y: typing.Union[typing.SupportsFloat,
'builtins._SupportsIndex',
typing.Text,
builtins.bytes,
builtins.bytearray]) -> 'Point':
return super().__new__(cls,
typing.cast(typing.Any, (float(x), float(y))))

def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)

@property
def x(self):
def x(self) -> float:
return self[0]

@property
def y(self):
def y(self) -> float:
return self[1]


class Box(tuple):
class Box(_BoxBase):
"""Immutable representation of PostgreSQL `box` type."""

__slots__ = ()

def __new__(cls, high, low):
return super().__new__(cls, (Point(*high), Point(*low)))
def __new__(cls, high: typing.Sequence[float],
low: typing.Sequence[float]) -> 'Box':
return super().__new__(cls,
typing.cast(typing.Any, (Point(*high),
Point(*low))))

def __repr__(self):
def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)

@property
def high(self):
def high(self) -> Point:
return self[0]

@property
def low(self):
def low(self) -> Point:
return self[1]


class Line(tuple):
class Line(_LineBase):
"""Immutable representation of PostgreSQL `line` type."""

__slots__ = ()

def __new__(cls, A, B, C):
return super().__new__(cls, (A, B, C))
def __new__(cls, A: float, B: float, C: float) -> 'Line':
return super().__new__(cls, typing.cast(typing.Any, (A, B, C)))

@property
def A(self):
def A(self) -> float:
return self[0]

@property
def B(self):
def B(self) -> float:
return self[1]

@property
def C(self):
def C(self) -> float:
return self[2]


class LineSegment(tuple):
class LineSegment(_LineSegmentBase):
"""Immutable representation of PostgreSQL `lseg` type."""

__slots__ = ()

def __new__(cls, p1, p2):
return super().__new__(cls, (Point(*p1), Point(*p2)))
def __new__(cls, p1: typing.Sequence[float],
p2: typing.Sequence[float]) -> 'LineSegment':
return super().__new__(cls,
typing.cast(typing.Any, (Point(*p1),
Point(*p2))))

def __repr__(self):
def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)

@property
def p1(self):
def p1(self) -> Point:
return self[0]

@property
def p2(self):
def p2(self) -> Point:
return self[1]


Expand All @@ -318,34 +366,44 @@ class Path:

__slots__ = '_is_closed', 'points'

def __init__(self, *points, is_closed=False):
def __init__(self, *points: typing.Sequence[float],
is_closed: bool = False) -> None:
self.points = tuple(Point(*p) for p in points)
self._is_closed = is_closed

@property
def is_closed(self):
def is_closed(self) -> bool:
return self._is_closed

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Path):
return NotImplemented

return (self.points == other.points and
self._is_closed == other._is_closed)

def __hash__(self):
def __hash__(self) -> int:
return hash((self.points, self.is_closed))

def __iter__(self):
def __iter__(self) -> typing.Iterator[Point]:
return iter(self.points)

def __len__(self):
def __len__(self) -> int:
return len(self.points)

def __getitem__(self, i):
@typing.overload
def __getitem__(self, i: int) -> Point:
...

@typing.overload
def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]:
...

def __getitem__(self, i: typing.Union[int, slice]) \
-> typing.Union[Point, typing.Tuple[Point, ...]]:
return self.points[i]

def __contains__(self, point):
def __contains__(self, point: object) -> bool:
return point in self.points


Expand All @@ -354,23 +412,23 @@ class Polygon(Path):

__slots__ = ()

def __init__(self, *points):
def __init__(self, *points: typing.Sequence[float]) -> None:
# polygon is always closed
super().__init__(*points, is_closed=True)


class Circle(tuple):
class Circle(_CircleBase):
"""Immutable representation of PostgreSQL `circle` type."""

__slots__ = ()

def __new__(cls, center, radius):
return super().__new__(cls, (center, radius))
def __new__(cls, center: Point, radius: float) -> 'Circle':
return super().__new__(cls, typing.cast(typing.Any, (center, radius)))

@property
def center(self):
def center(self) -> Point:
return self[0]

@property
def radius(self):
def radius(self) -> float:
return self[1]