Skip to content

Commit

Permalink
add types in all nn/init.py classes (tinygrad#7002)
Browse files Browse the repository at this point in the history
* add types in batchnorm class

* fix lint error in batchnorm types

* add types to conv1d function

* add types to convtranspose1d func and conv2d, convtranspose2d classes

* add types to all remaining classes

* change conv1d padding type to also accept str

* less is more; only keep non-obvious types

* mkdocs need types
  • Loading branch information
bhavyagada authored Oct 12, 2024
1 parent 2bb6b95 commit f79e05c
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions tinygrad/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import math
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, List
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, make_pair
from tinygrad.nn import optim, state, datasets # noqa: F401
Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, mome
if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)

def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]:
shape_mask = [1, -1, *([1]*(x.ndim-2))]
shape_mask: List[int] = [1, -1, *([1]*(x.ndim-2))]
if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
# This requires two full memory accesses to x
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
Expand All @@ -49,7 +50,7 @@ def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]:
batch_var = (y*y).mean(axis=reduce_axes)
return batch_mean, batch_var

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
batch_mean, batch_var = self.calc_stats(x)
# NOTE: wow, this is done all throughout training in most PyTorch models
if self.track_running_stats and Tensor.training:
Expand All @@ -59,7 +60,7 @@ def __call__(self, x:Tensor):
return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
BatchNorm2d = BatchNorm3d = BatchNorm

def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:Union[int, str]=0, dilation=1, groups=1, bias=True) -> Conv2d:
"""
Applies a 1D convolution over an input signal composed of several input planes.
Expand Down Expand Up @@ -93,22 +94,24 @@ class Conv2d:
print(t.numpy())
```
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding:Union[int, str]=0,
dilation=1, groups=1, bias=True):
self.kernel_size: Tuple[int, ...] = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
if isinstance(padding, str):
if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
self.padding = [p for d,k in zip(make_pair(dilation,len(self.kernel_size)), self.kernel_size[::-1]) for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)]
self.padding: Union[int, List[int]] = [p for d,k in zip(make_pair(dilation,len(self.kernel_size)), self.kernel_size[::-1]) for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)] #noqa:E501
else: self.padding = padding
self.stride, self.dilation, self.groups = stride, dilation, groups
scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
self.bias: Optional[Tensor] = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)

def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
groups=1, bias=True) -> ConvTranspose2d:
"""
Applies a 1D transposed convolution operator over an input signal composed of several input planes.
Expand Down Expand Up @@ -142,13 +145,14 @@ class ConvTranspose2d(Conv2d):
print(t.numpy())
```
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding=0, output_padding=0,
dilation=1, groups=1, bias=True):
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
self.output_padding = output_padding

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
dilation=self.dilation, groups=self.groups)

Expand All @@ -168,12 +172,12 @@ class Linear:
print(t.numpy())
```
"""
def __init__(self, in_features, out_features, bias=True):
def __init__(self, in_features:int, out_features:int, bias=True):
bound = 1 / math.sqrt(in_features)
self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
return x.linear(self.weight.transpose(), self.bias)

class GroupNorm:
Expand All @@ -193,12 +197,12 @@ class GroupNorm:
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
# reshape for layernorm to work as group norm
# subtract mean and divide stddev
x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
Expand All @@ -224,12 +228,12 @@ class InstanceNorm:
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
def __init__(self, num_features:int, eps=1e-5, affine=True):
self.num_features, self.eps = num_features, eps
self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
if self.weight is None or self.bias is None: return x
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
Expand All @@ -251,12 +255,12 @@ class LayerNorm:
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
x = x.layernorm(eps=self.eps, axis=self.axis)
if not self.elementwise_affine: return x
Expand All @@ -278,7 +282,7 @@ class LayerNorm2d(LayerNorm):
print(t.mean().item(), t.std().item())
```
"""
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
def __call__(self, x: Tensor) -> Tensor: return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

class RMSNorm:
"""
Expand All @@ -296,9 +300,9 @@ class RMSNorm:
print(norm(t).numpy())
```
"""
def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)

def _norm(self, x:Tensor): return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()

def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight

Expand Down

0 comments on commit f79e05c

Please sign in to comment.