Commit: [CLEANUP]
Kye committed Mar 25, 2024
1 parent 3fe1fa9 commit e142190
Showing 1 changed file with 65 additions and 75 deletions.
140 changes: 65 additions & 75 deletions simba/main.py
@@ -1,39 +1,29 @@
import torch
import torch
from torch import nn, Tensor
from zeta.nn import MambaBlock


class EMMImage(nn.Module):
class EinFFT(nn.Module):
"""
EMM (Element-wise Multiplication Module) is a PyTorch module that performs element-wise multiplication
between two tensors.
EinFFT module performs the EinFFT operation on the input tensor.
Args:
None
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
dim (int): Dimension of the input tensor.
heads (int, optional): Number of attention heads. Defaults to 8.
Returns:
Tensor: The result of element-wise multiplication between the input tensor `x` and the weight tensor.
Attributes:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
dim (int): Dimension of the input tensor.
heads (int): Number of attention heads.
act (nn.SiLU): Activation function (SiLU).
Wr (nn.Parameter): Learnable weight parameter for real part.
Wi (nn.Parameter): Learnable weight parameter for imaginary part.
"""

def __init__(
self,
):
super().__init__()

def forward(self, x: Tensor, weight: Tensor) -> Tensor:
x_b, x_h, x_w, x_c = x.shape

# Weight shape
c_b, c_d, c_d = weight.shape

# Something

# Multiply
return x * weight


class EinFFT(nn.Module):
def __init__(
self,
in_channels: int,
@@ -46,45 +36,49 @@ def __init__(
self.out_channels = out_channels
self.dim = dim
self.heads = heads

# silu
self.act = nn.SiLU()

# Weights for Wr and Wi
self.Wr = nn.Parameter(
torch.randn(in_channels, out_channels)
)
self.Wi = nn.Parameter(
torch.randn(in_channels, out_channels)
)

# IFFT

def forward(self, x: Tensor):
self.Wr = nn.Parameter(torch.randn(in_channels, out_channels))
self.Wi = nn.Parameter(torch.randn(in_channels, out_channels))

def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the EinFFT module.
Args:
x (Tensor): Input tensor of shape (batch_size, in_channels, height, width).
Returns:
Tensor: Output tensor of shape (batch_size, in_channels, height, width).
"""
b, c, h, w = x.shape
# Get Xr and X1

# Get Xr and X1
fast_fouried = torch.fft.fft(x)
print(fast_fouried.shape)

# Get Wr Wi use pytorch split instead
xr = fast_fouried.real
xi = fast_fouried.imag

# Einstein Matrix Multiplication with XR, Xi, Wr, Wi use torch split instead
# matmul = torch.matmul(xr, self.Wr) + torch.matmul(xi, self.Wi)
matmul = torch.matmul(xr, xi)
# matmul = torch.matmul(self.Wr, self.Wi)
print(matmul.shape)

# Xr, Xi hat, use torch split instead
xr_hat = matmul.real
xi_hat = matmul.imag
xr_hat = matmul # .real
xi_hat = matmul # .imag

# Silu
acted_xr_hat = self.act(xr_hat)
acted_xi_hat = self.act(xi_hat)

# Emm with the weights use torch split instead
# emmed = torch.matmul(
# acted_xr_hat,
@@ -94,24 +88,22 @@ def forward(self, x: Tensor):
# self.Wi
# )
emmed = torch.matmul(acted_xr_hat, acted_xi_hat)

# Split up into Xr and Xi again for the ifft use torch split instead
xr_hat = emmed.real
xi_hat = emmed.imag
xr_hat = emmed # .real
xi_hat = emmed # .imag

# IFFT
iffted = torch.fft.ifft(xr_hat + xi_hat)

return iffted



x = torch.randn(1, 3, 64, 64)
einfft = EinFFT(3, 64, 64)

out = einfft(x)
print(out.shape)
print(out)
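
The commented-out matmuls above suggest the intended operation is spectral channel mixing: FFT the input, mix channels with the learned Wr/Wi weights, apply SiLU, and invert the transform. Below is a minimal, runnable sketch of that idea, assuming image-shaped (B, C, H, W) input; the name SpectralChannelMix, the 0.02 init scale, and the use of fft2/ifft2 over the spatial dimensions are illustrative choices, not the repository's implementation. (For context: torch.fft.fft transforms only the last dimension, and torch.fft.ifft returns a complex tensor, which is why the smoke test above prints a complex result.)

import torch
from torch import nn, Tensor


class SpectralChannelMix(nn.Module):
    # Hypothetical sketch, not the EinFFT above: 2-D FFT over (H, W),
    # per-frequency channel mixing with a learned complex matrix, SiLU, inverse FFT.
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.act = nn.SiLU()
        # Real and imaginary parts of a complex (in_channels, out_channels) mixing matrix
        self.Wr = nn.Parameter(torch.randn(in_channels, out_channels) * 0.02)
        self.Wi = nn.Parameter(torch.randn(in_channels, out_channels) * 0.02)

    def forward(self, x: Tensor) -> Tensor:
        # x: (B, C, H, W), real-valued
        x_f = torch.fft.fft2(x, dim=(-2, -1))           # complex spectrum, (B, C, H, W)
        w = torch.complex(self.Wr, self.Wi)             # (C_in, C_out), complex
        mixed = torch.einsum("bchw,cd->bdhw", x_f, w)   # channel mixing at every frequency
        # Nonlinearity applied to real and imaginary parts separately
        mixed = torch.complex(self.act(mixed.real), self.act(mixed.imag))
        out = torch.fft.ifft2(mixed, dim=(-2, -1))
        return out.real                                 # return a real tensor (imaginary part discarded)


# Quick shape check, mirroring the smoke test above
y = SpectralChannelMix(3, 3)(torch.randn(1, 3, 64, 64))
print(y.shape)  # torch.Size([1, 3, 64, 64])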




class Simba(nn.Module):
def __init__(
@@ -130,44 +122,42 @@ def __init__(
self.d_state = d_state
self.d_conv = d_conv
self.dropout = nn.Dropout(dropout)



# Mamba Block
self.mamba = MambaBlock(
dim = self.dim,
depth = 1,
d_state = self.d_state,
d_conv = self.d_conv,
dim=self.dim,
depth=1,
d_state=self.d_state,
d_conv=self.d_conv,
)



def forward(self, x: Tensor) -> Tensor:
b, s, d = x.shape

residual = x

# Layernorm
normed = nn.LayerNorm(d)(x)

# Mamba
mamba = self.mamba(normed)

# Dropout
droped = self.dropout(mamba)

out = residual + droped

# Phase 2
residual_new = out

# Layernorm
normed_new = nn.LayerNorm(d)(out)

# einfft
fasted = normed_new

# Dropout
out = self.dropout(fasted)

# residual
return out + residual_new
return out + residual_new
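
For comparison, a self-contained sketch of the same residual structure with the normalization layers registered in __init__: constructing nn.LayerNorm(d) inside forward, as above, creates a fresh, untrained module on every call (and will fail on GPU inputs because its parameters live on the CPU). The MambaBlock call mirrors the one in this diff; the SimbaSketch name and the channel_mix placeholder (an MLP standing in for the EinFFT-style channel mixer marked "# einfft" above) are illustrative assumptions.

import torch
from torch import nn, Tensor
from zeta.nn import MambaBlock


class SimbaSketch(nn.Module):
    # Illustrative sketch of the block above: pre-norm residual token mixing (Mamba)
    # followed by pre-norm residual channel mixing.
    def __init__(self, dim: int, d_state: int, d_conv: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)  # registered once, so its parameters are trained
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.mamba = MambaBlock(
            dim=dim,
            depth=1,
            d_state=d_state,
            d_conv=d_conv,
        )
        # Placeholder channel mixer; in the block above this slot is meant for the EinFFT-style layer.
        self.channel_mix = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        # x: (batch, seq_len, dim)
        x = x + self.dropout(self.mamba(self.norm1(x)))         # token mixing
        x = x + self.dropout(self.channel_mix(self.norm2(x)))   # channel mixing
        return x


# Quick shape check (the d_state / d_conv values here are arbitrary examples)
block = SimbaSketch(dim=64, d_state=16, d_conv=4)
print(block(torch.randn(1, 32, 64)).shape)  # torch.Size([1, 32, 64])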
