Commit

[PROGRESS]
Kye committed Mar 25, 2024
1 parent 1a5548d commit 4fe2fd3
Showing 1 changed file with 57 additions and 11 deletions.
68 changes: 57 additions & 11 deletions simba/main.py
@@ -47,23 +47,69 @@ def __init__(
        self.dim = dim
        self.heads = heads

        # SiLU activation applied to the mixed real and imaginary parts
        self.act = nn.SiLU()

        # Learnable weights for the real (Wr) and imaginary (Wi) parts
        self.Wr = nn.Parameter(
            torch.randn(in_channels, out_channels)
        )
        self.Wi = nn.Parameter(
            torch.randn(in_channels, out_channels)
        )

    def forward(self, x: Tensor):
        b, c, h, w = x.shape

        # FFT of the input (complex-valued result)
        fast_fouried = torch.fft.fft(x)

        # Split into the real (Xr) and imaginary (Xi) parts
        xr = fast_fouried.real
        xi = fast_fouried.imag

        # Einstein matrix multiplication of Xr, Xi with Wr, Wi over the
        # channel dimension
        xr_hat = torch.einsum("bchw,co->bohw", xr, self.Wr)
        xi_hat = torch.einsum("bchw,co->bohw", xi, self.Wi)

        # SiLU on both parts
        acted_xr_hat = self.act(xr_hat)
        acted_xi_hat = self.act(xi_hat)

        # Recombine the parts into one complex tensor for the IFFT
        emmed = torch.complex(acted_xr_hat, acted_xi_hat)

        # IFFT back from the frequency domain
        iffted = torch.fft.ifft(emmed)

        return iffted

x = torch.randn(1, 3, 64, 64)
einfft = EinFFT(3, 64, 64)

out = einfft(x)
print(out.shape)
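
For reference, below is a minimal, self-contained sketch of the direction this EinFFT block appears to be heading: FFT of the input, a complex channel-mixing product of the real and imaginary parts with Wr and Wi, SiLU on each part, and an inverse FFT back to the spatial domain. The module name EinFFTSketch, the use of fft2/ifft2 over the spatial dimensions, the einsum pattern, the 0.02 weight scaling, and returning only the real part are illustrative assumptions, not taken from this commit.

import torch
from torch import nn, Tensor


class EinFFTSketch(nn.Module):
    """Sketch of frequency-domain channel mixing: FFT -> complex mix -> SiLU -> IFFT."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.act = nn.SiLU()
        # Real and imaginary parts of a complex channel-mixing weight
        # (assumption: shapes and 0.02 init scale chosen for illustration)
        self.Wr = nn.Parameter(torch.randn(in_channels, out_channels) * 0.02)
        self.Wi = nn.Parameter(torch.randn(in_channels, out_channels) * 0.02)

    def forward(self, x: Tensor) -> Tensor:
        # 2-D FFT over the spatial dimensions (h, w); result is complex
        x_f = torch.fft.fft2(x, dim=(-2, -1))
        xr, xi = x_f.real, x_f.imag

        # Complex matrix product over the channel dimension, written with
        # real tensors: (xr + i*xi) @ (Wr + i*Wi)
        yr = torch.einsum("bchw,co->bohw", xr, self.Wr) - torch.einsum(
            "bchw,co->bohw", xi, self.Wi
        )
        yi = torch.einsum("bchw,co->bohw", xr, self.Wi) + torch.einsum(
            "bchw,co->bohw", xi, self.Wr
        )

        # Nonlinearity on each part, then recombine into one complex tensor
        y = torch.complex(self.act(yr), self.act(yi))

        # Inverse FFT back to the spatial domain; keep the real part
        return torch.fft.ifft2(y, dim=(-2, -1)).real


if __name__ == "__main__":
    x = torch.randn(1, 3, 64, 64)
    block = EinFFTSketch(in_channels=3, out_channels=64)
    print(block(x).shape)  # torch.Size([1, 64, 64, 64])

Expressing the complex product with two real weight matrices keeps the parameters real while still mixing phase in the frequency domain; splitting via torch.view_as_real would be an equivalent way to obtain the two parts, as the in-line "use torch split instead" notes in the commit suggest.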



