Fix the openfold training. (microsoft#4657)
This PR removes the bias tensors that were created as placeholders, which caused a crash in openfold's training pipeline.

---------

Co-authored-by: Conglong Li <[email protected]>
cctry and conglongli authored Nov 9, 2023
1 parent 3b1cf1f commit 1d1a20c
Showing 4 changed files with 49 additions and 44 deletions.
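
For context, the net effect of this change: callers no longer get zero-filled placeholder biases allocated on their behalf, and gradients for mask-style biases are computed only when actually requested. A minimal usage sketch (a hedged illustration, assuming a CUDA build of DeepSpeed with the evoformer_attn extension compiled and CUTLASS available; the small sizes are made up):

import torch
from deepspeed.ops.deepspeed4science.evoformer_attn import DS4Sci_EvoformerAttention

# Illustrative sizes: batch, MSA rows (N), sequence length, heads, head dim (<= 64).
batch, N, seq_len, heads, dim = 1, 4, 64, 4, 32
kw = dict(dtype=torch.float16, device="cuda")

Q = torch.randn(batch, N, seq_len, heads, dim, requires_grad=True, **kw)
K = torch.randn(batch, N, seq_len, heads, dim, requires_grad=True, **kw)
V = torch.randn(batch, N, seq_len, heads, dim, requires_grad=True, **kw)
# bias1 is a mask-style bias (no gradient needed); bias2 is a learned pair bias.
bias1 = torch.randn(batch, N, 1, 1, seq_len, requires_grad=False, **kw)
bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, requires_grad=True, **kw)

out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
out.sum().backward()
print(bias1.grad is None, bias2.grad is not None)  # True True: bias1's gradient is skipped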
12 changes: 10 additions & 2 deletions csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
@@ -243,10 +243,18 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
     }
 
     CUTLASS_DEVICE
-    bool set_prologue_done(bool value) { prologue_done_ = value; }
+    bool set_prologue_done(bool value)
+    {
+        prologue_done_ = value;
+        return true;
+    }
 
     CUTLASS_DEVICE
-    bool set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; }
+    bool set_zero_outside_bounds(bool value)
+    {
+        zero_outside_bounds_ = value;
+        return true;
+    }
 
     template <bool kLoadA = true, bool kLoadB = true>
     CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
28 changes: 16 additions & 12 deletions deepspeed/ops/deepspeed4science/evoformer_attn.py
@@ -30,7 +30,7 @@ def _attention(Q, K, V, bias1, bias2):
     return O, lse
 
 
-def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
+def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2, bias1_grad, bias2_grad):
     assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value"
     dQ = torch.empty_like(Q, dtype=Q.dtype)
     dK = torch.empty_like(K, dtype=K.dtype)
@@ -44,8 +44,14 @@ def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
     if kernel_ is None:
         kernel_ = EvoformerAttnBuilder().load()
     delta = torch.empty_like(lse)
-    dB1 = torch.zeros_like(bias1, dtype=torch.float32)
-    dB2 = torch.zeros_like(bias2, dtype=torch.float32)
+    if bias1_grad:
+        dB1 = torch.zeros_like(bias1, dtype=torch.float32)
+    else:
+        dB1 = torch.tensor([], dtype=torch.float32, device=bias1.device)
+    if bias2_grad:
+        dB2 = torch.zeros_like(bias2, dtype=torch.float32)
+    else:
+        dB2 = torch.tensor([], dtype=torch.float32, device=bias2.device)
     kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2)
     return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype)
 
@@ -69,10 +75,12 @@ def forward(ctx, q, k, v, bias1=None, bias2=None):
     @staticmethod
     def backward(ctx, grad_output):
         q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
-        dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2)
-        if bias1.numel() == 0:
+        is_b1_grad = bias1.numel() != 0 and ctx.needs_input_grad[3]
+        is_b2_grad = bias2.numel() != 0 and ctx.needs_input_grad[4]
+        dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2, is_b1_grad, is_b2_grad)
+        if not is_b1_grad:
             dB1 = None
-        if bias2.numel() == 0:
+        if not is_b2_grad:
             dB2 = None
         return dQ, dK, dV, dB1, dB2
 
@@ -90,13 +98,9 @@ def DS4Sci_EvoformerAttention(Q, K, V, biases):
     bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2])
 
     if biases[0] is not None:
-        assert biases[0].shape == bias_1_shape(Q)
-    else:
-        biases[0] = Q.new_zeros(bias_1_shape(Q))
+        assert biases[0].shape == bias_1_shape(Q), "bias1 shape is incorrect"
 
     if biases[1] is not None:
-        assert biases[1].shape == bias_2_shape(Q)
-    else:
-        biases[1] = Q.new_zeros(bias_2_shape(Q))
+        assert biases[1].shape == bias_2_shape(Q), "bias2 shape is incorrect"
 
     return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1])
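
The gating in backward above relies on ctx.needs_input_grad, which torch.autograd.Function fills with one positional flag per argument of forward, so index 3 is bias1 and index 4 is bias2. A self-contained toy sketch of the same mechanism (not the fused kernel; names are made up for illustration):

import torch

class ScaleByWeight(torch.autograd.Function):
    """Toy Function using the same needs_input_grad pattern as the patch above."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        # needs_input_grad has one positional flag per forward() input;
        # skip the work (and return None) when a gradient is not needed.
        dx = grad_output * w if ctx.needs_input_grad[0] else None
        dw = grad_output * x if ctx.needs_input_grad[1] else None
        return dx, dw

x = torch.randn(4, requires_grad=True)
w = torch.randn(4)  # requires_grad=False, like a mask-style bias
ScaleByWeight.apply(x, w).sum().backward()
print(x.grad is not None, w.grad is None)  # True True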
4 changes: 2 additions & 2 deletions tests/benchmarks/DS4Sci_EvoformerAttention_bench.py
@@ -3,7 +3,7 @@
 
 # DeepSpeed Team
 """
-This script is to test the correctness of the DS4Sci_EvoformerAttention op.
+This script is to test the performance of the DS4Sci_EvoformerAttention op.
 To run the script,
 1. Clone the CUTLASS repo. E.g. git clone https://github.com/NVIDIA/cutlass.git
 2. Specify the CUTLASS_PATH environment variable. E.g. export CUTLASS_PATH=$(pwd)/cutlass
@@ -83,7 +83,7 @@ def benchmark():
     Q = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
     K = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
     V = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
-    bias1 = torch.randn(batch, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=True)
+    bias1 = torch.randn(batch, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=False)
     bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, dtype=dtype, device="cuda", requires_grad=True)
     # warm up
     DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
49 changes: 21 additions & 28 deletions tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py
@@ -69,42 +69,35 @@ def test_DS4Sci_EvoformerAttention(dtype, tensor_shape):
                     dtype=dtype,
                     device=get_accelerator().device_name(),
                     requires_grad=True)
-    bias1 = torch.randn(batch,
-                        n,
-                        1,
-                        1,
-                        seq_len,
-                        dtype=dtype,
-                        device=get_accelerator().device_name(),
-                        requires_grad=True)
-    bias2 = torch.randn(batch,
-                        1,
-                        heads,
-                        seq_len,
-                        seq_len,
-                        dtype=dtype,
-                        device=get_accelerator().device_name(),
-                        requires_grad=True)
+    mask = torch.randint(0, 2, (batch, n, 1, 1, seq_len), dtype=dtype, device=get_accelerator().device_name())
+    mask_bias = 1e9 * (mask - 1)
+    bias = torch.randn(batch,
+                       1,
+                       heads,
+                       seq_len,
+                       seq_len,
+                       dtype=dtype,
+                       device=get_accelerator().device_name(),
+                       requires_grad=True)
     dummy_out = torch.rand_like(Q, dtype=dtype, device=get_accelerator().device_name())
-    ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
+    ref_out = attention_reference(Q, K, V, [mask_bias, bias], 1 / (dim**0.5))
     ref_out.backward(dummy_out)
     ref_dv, V.grad = V.grad.clone(), None
     ref_dk, K.grad = K.grad.clone(), None
     ref_dq, Q.grad = Q.grad.clone(), None
-    ref_db1, bias1.grad = bias1.grad.clone(), None
-    ref_db2, bias2.grad = bias2.grad.clone(), None
+    ref_db, bias.grad = bias.grad.clone(), None
 
-    out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
+    out = DS4Sci_EvoformerAttention(Q, K, V, [mask_bias, bias])
     out.backward(dummy_out)
     dv, v_grad = V.grad.clone(), None
     dk, k_grad = K.grad.clone(), None
     dq, q_grad = Q.grad.clone(), None
-    db1, bias1.grad = bias1.grad.clone(), None
-    db2, bias2.grad = bias2.grad.clone(), None
+    db, bias.grad = bias.grad.clone(), None
 
-    assert torch.allclose(ref_out, out, atol=2e-2, rtol=0), f"\n{ref_out} \n {out}"
-    assert torch.allclose(ref_dv, dv, atol=2e-2, rtol=0), f"\n{ref_dv} \n {dv}"
-    assert torch.allclose(ref_dk, dk, atol=2e-2, rtol=0), f"\n{ref_dk} \n {dk}"
-    assert torch.allclose(ref_dq, dq, atol=2e-2, rtol=0), f"\n{ref_dq} \n {dq}"
-    assert torch.allclose(ref_db1, db1, atol=2e-2, rtol=1e-2), f"{ref_db1} \n {db1}"
-    assert torch.allclose(ref_db2, db2, atol=2e-2, rtol=1e-2), f"{ref_db2} \n {db2}"
+    eps = 1e-2 if dtype == torch.float16 else 5e-2
+
+    assert torch.max(torch.abs(ref_out - out)).item() < eps, f"out eps: {torch.max(torch.abs(ref_out - out))}"
+    assert torch.max(torch.abs(ref_dv - dv)) < eps, f"dv eps: {torch.max(torch.abs(ref_dv - dv))}"
+    assert torch.max(torch.abs(ref_dk - dk)) < eps, f"dk eps: {torch.max(torch.abs(ref_dk - dk))}"
+    assert torch.max(torch.abs(ref_dq - dq)) < eps, f"dq eps: {torch.max(torch.abs(ref_dq - dq))}"
+    assert torch.max(torch.abs(ref_db - db)) < 2 * eps, f"db eps: {torch.max(torch.abs(ref_db - db))}"

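The rewritten test drives bias1 as an attention mask rather than an independent random tensor, matching how openfold uses it: adding -1e9 to masked positions before the softmax pushes their attention weights to zero, so that bias needs no gradient of its own. A tiny standalone illustration of the trick (values made up):

import torch

scores = torch.tensor([[0.5, 1.0, 0.2]])  # raw attention logits
mask = torch.tensor([[1.0, 1.0, 0.0]])    # 0 marks a masked position
mask_bias = 1e9 * (mask - 1)              # 0 where kept, -1e9 where masked
probs = torch.softmax(scores + mask_bias, dim=-1)
print(probs)  # the masked position receives ~0 attention weight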