-
Notifications
You must be signed in to change notification settings - Fork 8
/
tensor_type.py
48 lines (31 loc) · 892 Bytes
/
tensor_type.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Jaxtyping-like typing system for PyTorch
# https://github.com/patrick-kidger/jaxtyping/
from typing import Tuple
class TensorType:
    """A tensor annotation: a dtype name plus a tuple of dimension specs.

    Instances are normally produced by subscripting an annotation factory,
    e.g. ``Float32[10, 20]``. Two annotations compare equal when both their
    dtype and their dims are equal, so they can be compared directly and
    used as dict keys / set members.

    Args:
        dtype: Name of the element type, e.g. ``"float32"``.
        *dims: Dimension specs, stored verbatim as a tuple.
    """

    def __init__(self, dtype, *dims):
        self.dtype: str = dtype
        self.dims: Tuple[object, ...] = dims

    def __repr__(self):
        # e.g. "float32[10, 20]"; str() is used per dim, so string dims
        # print without quotes ("int32[batch, 3]").
        return f"{self.dtype}[{', '.join(str(d) for d in self.dims)}]"

    def __eq__(self, other):
        # Value semantics: equal dtype and equal dims. For unrelated types,
        # defer to the other operand instead of claiming inequality.
        if not isinstance(other, TensorType):
            return NotImplemented
        return (self.dtype, self.dims) == (other.dtype, other.dims)

    def __hash__(self):
        # Must stay consistent with __eq__; requires every dim to be
        # hashable (ints and strings are).
        return hash((self.dtype, self.dims))
class _TypeAnnotation:
    """Factory that turns ``Dtype[dims]`` subscripts into TensorType objects.

    Args:
        dtype: Element-type name passed through to every TensorType built.
    """

    def __init__(self, dtype):
        self.dtype: str = dtype

    def __getitem__(self, dims):
        """Return ``TensorType(self.dtype, *dims)``.

        ``Float32[10, 20]`` passes the tuple ``(10, 20)``, while
        ``Float32[10]`` passes a bare ``10``; the latter is wrapped into a
        1-tuple so both forms unpack the same way.
        """
        if not isinstance(dims, tuple):
            dims = (dims,)
        return TensorType(self.dtype, *dims)

    def __repr__(self):
        # Without this, printing a factory shows only an object address.
        return f"{type(self).__name__}({self.dtype!r})"
# Ready-made annotation factories. Subscript one with dimension specs to get
# a TensorType, e.g. ``Float32[10, 20]`` or ``Int32[10]``.
Float32 = _TypeAnnotation(dtype="float32")
Int32 = _TypeAnnotation(dtype="int32")
if __name__ == "__main__":
    # Demo: build a few annotations, then show each one's repr, dtype and dims.
    x, y, z = Float32[10, 20], Int32[10, 20], Int32[10]
    for annotation in (x, y, z):
        print(annotation)
        print(annotation.dtype)
        print(annotation.dims)