
[In Progress] Sharing Unit-Mamba Implementation #68

Open
norikazu99 opened this issue Aug 25, 2024 · 1 comment

Comments


norikazu99 commented Aug 25, 2024

I've been working on a unit-scaled Mamba block and wanted to share my work, as well as ask a couple of questions. I used https://github.com/johnma2006/mamba-minimal as a skeleton.
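
For completeness, the snippets below assume roughly the following imports (the exact location of scale_fwd may differ between unit-scaling versions, so treat that line as an assumption):

import math

import torch
import torch.nn.functional as F
from einops import repeat

import unit_scaling as U
import unit_scaling.functional as UF
from unit_scaling.functional import scale_elementwise
from unit_scaling.scale import scale_fwd  # import path may vary between versions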

Softplus:

def scaled_softplus(x, beta=1.0, threshold=20.0):
    # empirical 1/std of the unscaled softplus output and input-grad
    # (these match the unscaled measurement below)
    output_scale = 1/0.52103
    grad_scale = 1/0.20833444

    unit_softplus = scale_elementwise(
        F.softplus, output_scale, grad_scale, constraint='to_output_scale'
    )
    return unit_softplus(x, beta, threshold)

b, s, d = 32, 64, 128
x = torch.randn(b, s, d, requires_grad=True)

# assuming we're using default values

beta = 1.0
threshold = 20.0

x.grad = None
out = F.softplus(x, beta, threshold)
out.backward(torch.ones_like(out))
print('unscaled: ', x.grad.std().item(), out.std().item())
print('')

x.grad = None
out = scaled_softplus(x, beta, threshold)
out.backward(torch.ones_like(out))
print('scaled: ', x.grad.std().item(), out.std().item())

Output:

unscaled:  0.2083955705165863 0.5211385488510132
scaled:  0.3999684751033783 1.0002082586288452
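
The two constants in scaled_softplus come straight from this unscaled measurement; they can be re-derived with a quick check (assuming a unit-normal input and a sample large enough for the stds to settle):

x = torch.randn(1_000_000, requires_grad=True)
out = F.softplus(x)
out.backward(torch.ones_like(out))
print(out.std().item())     # ~0.521  -> output_scale = 1/0.52103
print(x.grad.std().item())  # ~0.208  -> grad_scale   = 1/0.20833444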

SSM:

# init

d_state = 4
b, s, d = 512, 64, 64
d_inner = int(d*2)
dt_rank = math.ceil(d/16)
ssm_x_shape = (b, d_inner, d_state)

# modules

A_log = U.Parameter(torch.log(
    repeat(torch.arange(1, d_state + 1), 'n -> d n', d=d_inner)
), 'weight')
D = U.Parameter(torch.ones(d_inner), 'weight')
x_proj = U.Linear(d_inner, dt_rank + d_state * 2)
dt_proj = U.Linear(dt_rank, d_inner)

# data

x = torch.randn(b, s, d_inner, requires_grad=True)

SSM part 1:

# delta, A, B, C (forward + ssm)

a = -torch.exp(A_log.float())
A = scale_fwd(a, d_state**(-1/3.6))
delta, B, C = x_proj(x).split(split_size=[dt_rank, d_state, d_state], dim=-1)
delta2 = scaled_softplus(dt_proj(delta))

print('A: ', A.std().item(), 'B: ', B.std().item(), 'C: ', C.std().item())
print('delta: ', delta.std().item(), 'delta2: ', delta2.std().item())

# deltaA, deltaB_x (selective scan part 1)

# the torch.exp here is what breaks the scaling of deltaA (see the note below);
# without it, the UF.add output is properly scaled
deltaA = torch.exp(UF.add(
    delta2.unsqueeze(3),
    A.view(1, 1, A.size(0), -1)
))
delta_x = UF.add(delta2, x)
deltaB_x = UF.add(delta_x.unsqueeze(3), B.unsqueeze(2))

print('deltaA: ', deltaA.std().item(), 'delta_x: ', delta_x.std().item(), 'deltaB_x: ', deltaB_x.std().item())

Output:

A:  1.031583309173584 B:  1.002495288848877 C:  1.0055627822875977
delta:  0.9926937222480774 delta2:  0.9872873425483704
deltaA:  28.820053100585938 delta_x:  0.992746114730835 deltaB_x:  0.9943106174468994

All outputs seem to be properly scaled apart from deltaA, due to the torch.exp operation. When torch.exp is not used, deltaA is properly scaled, since it then only goes through UF.add. How would you recommend I handle this? Thank you very much for your time.

Note: unit_scaling is imported as U and unit_scaling.functional as UF.
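
For illustration, one direction would be to wrap torch.exp with the same scale_elementwise pattern as scaled_softplus above. The constants below assume a unit-normal input (which the deltaA pre-activation is not exactly), so they would need to be re-measured empirically for the actual distribution:

def scaled_exp(x):
    # for x ~ N(0, 1), exp(x) is lognormal with std sqrt((e - 1) * e) ~= 2.16;
    # since d/dx exp(x) = exp(x), the same constant applies under the
    # ones-backward measurement used for scaled_softplus above
    output_scale = (math.e * (math.e - 1)) ** -0.5
    grad_scale = output_scale

    unit_exp = scale_elementwise(
        torch.exp, output_scale, grad_scale, constraint='to_output_scale'
    )
    return unit_exp(x)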

norikazu99 (Author) commented Aug 26, 2024

SSM part 2:

# re-randomise the inputs so the scan is tested from unit-scale tensors
x = torch.randn_like(x)
ssm_x = torch.zeros(ssm_x_shape)
deltaA = torch.randn_like(deltaA)
deltaB_x = torch.randn_like(deltaB_x)
c_ = torch.randn_like(C)
d_ = torch.randn_like(D.float())

ys = []
for i in range(s):
    ssm_x = UF.add(deltaA[:, i] * ssm_x, deltaB_x[:, i])
    y = UF.add(ssm_x, c_[:, i, :].unsqueeze(1))
    y1 = y.sum(dim=-1)
    y1 = scale_fwd(y1, y.size(-1)**-0.5)  # rescale the sum over d_state back towards unit std
    ys.append(y1)

print('y1: ', y1.std().item(), 'ssm_x: ', ssm_x.std().item())

ys = torch.stack(ys, dim=1)  # shape (b, l, d_in)
ys = UF.add(x * d_, ys)

print('ys: ', ys.std().item())

Output:

y1:  1.007655143737793 ssm_x:  1.0007102489471436
ys:  0.9854761362075806

This part seems to be properly scaled. Not using the weighted-add scaling rule and instead just using UF.add seems to work well for the forward scale. Would using the weighted-add rule for scale, described in the first unit-scaling paper, lead to better-scaled outputs?
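
For reference, for independent unit-variance inputs the weighted sum w_a*a + w_b*b has variance w_a**2 + w_b**2, so dividing by sqrt(w_a**2 + w_b**2) restores unit variance on the forward pass. A rough sketch of that forward rule only (weighted_add is a hypothetical helper, not the library's UF.add, and it ignores the separate backward scales):

def weighted_add(a, b, w_a=1.0, w_b=1.0):
    # forward rule only: independent unit-variance inputs give a
    # unit-variance output (backward scales not handled here)
    return (w_a * a + w_b * b) / math.sqrt(w_a**2 + w_b**2)

a, b = torch.randn(2**16), torch.randn(2**16)
print(weighted_add(a, b, 1.0, 0.5).std().item())  # ~1.0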

These are all the components needed for Mamba that aren't already implemented in the unit-scaling library. Thanks for making all of this possible. I will check how well the scales hold up in the backward pass before working on the full model.
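
For reference, the backward check can reuse the same pattern as the softplus test above, re-running the scan with inputs that require grad, e.g.:

deltaA = torch.randn_like(deltaA, requires_grad=True)
deltaB_x = torch.randn_like(deltaB_x, requires_grad=True)

# ... re-run the scan above to produce ys ...

ys.backward(torch.ones_like(ys))
print('deltaA grad: ', deltaA.grad.std().item())
print('deltaB_x grad: ', deltaB_x.grad.std().item())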
