tutel/jit_kernels/sparse.py, torch.float16: bug in the calculation. The CUDA result is inconsistent with the CPU result, and an array access goes out of bounds #196
Labels: invalid (This doesn't seem right)
Code:

```python
import numpy as np
import torch
from tutel.jit_kernels import sparse as jit_kernel

print(torch.__version__)

def moe_dispatch_bwd_gate():
    samples = 2
    capacity = 2
    hidden = 2
    num_experts = 1
    indices = [0, 0]
    locations = [0, 0]
    input = [0.4946, -0.0043, 0.5386, -0.8354]
    dispatch = [0.7085, 0.8257, -0.1455, -0.1788]
    # int32
    indices_t = np.asarray(indices, dtype=np.int32)
    locations_t = np.asarray(locations, dtype=np.int32)
    # float / half
    input_t = np.asarray(input, dtype=np.float16)
    dispatch_t = np.asarray(dispatch, dtype=np.float16)
    indices_gpu = torch.from_numpy(indices_t).cuda()
    locations_gpu = torch.from_numpy(locations_t).cuda()
    input_gpu = torch.from_numpy(input_t).cuda()
    dispatch_gpu = torch.from_numpy(dispatch_t).cuda()
    print("cuda:")
    print("indices_gpu:", indices_gpu)
    print("locations_gpu:", locations_gpu)
    print("input_gpu:", input_gpu)
    print("dispatch_gpu:", dispatch_gpu)
    # call gpu func (float16)
    grad_gates = torch.zeros([samples], dtype=input_gpu.dtype, device=input_gpu.device)
    moe_dispatch_bwd_gate = jit_kernel.create_backward_gate(input_gpu.dtype, input_gpu.is_cuda)
    moe_dispatch_bwd_gate(grad_gates, indices_gpu, locations_gpu, input_gpu, dispatch_gpu, extra=[samples, hidden, capacity])
    print("grad_gates:", grad_gates)
    # call cpu func (float32)
    input_t = np.asarray(input, dtype=np.float32)
    dispatch_t = np.asarray(dispatch, dtype=np.float32)
    indices_cpu = torch.from_numpy(indices_t)
    locations_cpu = torch.from_numpy(locations_t)
    input_cpu = torch.from_numpy(input_t)
    print("cpu:")
    # print("input_cpu:", input_cpu)
    dispatch_cpu = torch.from_numpy(dispatch_t)
    grad_gates_cpu = torch.zeros([samples], dtype=input_cpu.dtype, device=input_cpu.device)
    moe_dispatch_bwd_gate = jit_kernel.create_backward_gate(input_cpu.dtype, input_cpu.is_cuda)
    moe_dispatch_bwd_gate(grad_gates_cpu, indices_cpu, locations_cpu, input_cpu, dispatch_cpu, extra=[samples, hidden, capacity])
    print("grad_gates_cpu:", grad_gates_cpu)

if __name__ == '__main__':
    moe_dispatch_bwd_gate()
```
Problem: the CUDA result is inconsistent with the CPU result:
cuda: [0.4180, 0.0000]
cpu: [0.3469, -0.3082]
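To see where the CPU numbers come from, here is a minimal NumPy sketch of what the backward-gate computation appears to do, inferred from the CPU output: each sample's gate gradient is the dot product of its input row with the dispatched row it was routed to. The helper name `reference_bwd_gate` and the slot formula `indices * capacity + locations` are assumptions for illustration, not tutel's API; in this example the slot is 0 for both samples either way.

```python
import numpy as np

# Inputs copied from the reproduction script (float64 for exact decimals).
samples, capacity, hidden = 2, 2, 2
indices = np.array([0, 0], dtype=np.int32)
locations = np.array([0, 0], dtype=np.int32)
inp = np.array([0.4946, -0.0043, 0.5386, -0.8354]).reshape(samples, hidden)
dispatch = np.array([0.7085, 0.8257, -0.1455, -0.1788]).reshape(-1, hidden)

def reference_bwd_gate(indices, locations, inp, dispatch, capacity):
    """Hypothetical reference: grad[s] = dot(dispatch[slot(s)], inp[s])."""
    grad = np.zeros(len(indices))
    for s in range(len(indices)):
        slot = int(indices[s]) * capacity + int(locations[s])  # assumed row of the dispatched buffer
        grad[s] = np.dot(dispatch[slot], inp[s])
    return grad

grad = reference_bwd_gate(indices, locations, inp, dispatch, capacity)
print(np.round(grad, 4))  # agrees with the CPU result [0.3469, -0.3082]
```

Note that only dispatched row 0 is used here, because both samples have location 0; row 1 ([-0.1455, -0.1788]) never enters the CPU result.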
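CUDA calculation, step by step:
When index=0, the kernel computes the gradient of the first gate.
dispatched_input and reshaped_input are of type half2, which packs two float16 values into 4 bytes (the same size as a float), so each element read through these pointers fetches two consecutive halves.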
Therefore, when i=0, the dispatched_input subscript index * (hidden) + i = 0 and the reshaped_input subscript index * (hidden) + i = 0, so the kernel reads the first two halves of each buffer and accumulates their product into grad_gates1_s_rf.
Values read: dispatch = [0.7085, 0.8257], input = [0.4946, -0.0043]
i=0 result: grad_gates1_s_rf = 0.7085 * 0.4946 + 0.8257 * (-0.0043) = 0.34687359
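The half2 reading in the step above can be checked on the host with NumPy. Reinterpreting the float16 buffer as 4-byte words shows that the 4 halves form only 2 half2 elements, and grouping into pairs shows that element 0 holds halves 0 and 1. This is a host-side sketch of the memory layout under that assumption, not the kernel itself; the variable names are illustrative.

```python
import numpy as np

# The float16 buffers from the reproduction script.
halves = np.array([0.4946, -0.0043, 0.5386, -0.8354], dtype=np.float16)
dispatch_h = np.array([0.7085, 0.8257, -0.1455, -0.1788], dtype=np.float16)

# A half2 is 4 bytes, like a float32, so viewing the 8-byte buffer as
# float32 words gives the number of half2 elements behind the pointer.
words = halves.view(np.float32)
print(words.itemsize, len(words))  # 4 2

# Half2 element k holds halves 2k and 2k+1.
input_pairs = halves.reshape(-1, 2)
dispatch_pairs = dispatch_h.reshape(-1, 2)

# i = 0 reads element 0 of each buffer; accumulate in float32.
i0 = float(np.dot(dispatch_pairs[0].astype(np.float32),
                  input_pairs[0].astype(np.float32)))
print(round(i0, 4))
```

The printed i=0 term differs from 0.34687359 only by float16 rounding of the inputs.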
When i=1, both subscripts are 1, so the kernel reads the last two halves of each buffer and adds their product to the same (first) gate gradient.
Values read: dispatch = [-0.1455, -0.1788], input = [0.5386, -0.8354]
i=1 result: grad_gates1_s_rf += 0.5386 * (-0.1455) + (-0.8354) * (-0.1788) = 0.07100322
Final: grad_gates1_s_rf = 0.34687359 + 0.07100322 = 0.41787681, consistent with the CUDA value 0.4180 in float16. The CPU value for the first gate, 0.3469, equals the i=0 term alone.
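The accumulation described above can be replayed in plain Python. This sketch assumes, as the analysis does, that the loop runs i over `hidden` half2 elements and that both subscripts are index * hidden + i (the dispatch offset is 0-based here as well). It uses float64 on the decimal values from the report, so it reproduces 0.41787681 rather than the float16 result; `half2_accumulate` is a hypothetical helper.

```python
import numpy as np

# Decimal values from the report, grouped into half2 pairs.
inp = np.array([0.4946, -0.0043, 0.5386, -0.8354]).reshape(-1, 2)
dispatch = np.array([0.7085, 0.8257, -0.1455, -0.1788]).reshape(-1, 2)

def half2_accumulate(index, hidden):
    """Replay the loop as analysed: i covers `hidden` half2 elements,
    and each step multiplies one pair from each buffer."""
    acc = 0.0
    for i in range(hidden):
        d = dispatch[index * hidden + i]
        x = inp[index * hidden + i]
        acc += d[0] * x[0] + d[1] * x[1]
    return acc

# index = 0, hidden = 2: reads pairs 0 and 1 of both buffers.
g0 = half2_accumulate(index=0, hidden=2)
print(round(g0, 8))  # 0.41787681
```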
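When index=1, the kernel computes the gradient of the second gate. The input subscript now starts at index * (hidden) = 2 in half2 units (float16 positions 4 and 5), but input holds only 4 halves, i.e. 2 half2 elements, so the access is out of bounds. The value at the illegal address may be 0, which would explain why the second gradient comes out as 0.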
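The out-of-bounds read for index=1 can be made explicit by listing the half2 offsets each sample reads. The second function, `grad_pairs`, is a hypothetical variant and not tutel code: it counts the loop bound and strides in half2 pairs (2 halves per row, so 1 pair) and assumes the dispatched row is slot 0 for both samples. Under those assumptions every offset stays in range and the results line up with the CPU values.

```python
import numpy as np

# Decimal values from the report, grouped into half2 pairs (2 per buffer).
inp = np.array([0.4946, -0.0043, 0.5386, -0.8354]).reshape(-1, 2)
dispatch = np.array([0.7085, 0.8257, -0.1455, -0.1788]).reshape(-1, 2)
n_pairs = len(inp)

def input_offsets(index, hidden):
    """half2 offsets into the input read for sample `index`."""
    return [index * hidden + i for i in range(hidden)]

# hidden = 2, as passed via extra=[samples, hidden, capacity].
for index in range(2):
    offs = input_offsets(index, hidden=2)
    print(index, offs, [o < n_pairs for o in offs])
# 0 [0, 1] [True, True]
# 1 [2, 3] [False, False]

def grad_pairs(index, slot, hidden_pairs):
    """Hypothetical: loop bound and strides counted in half2 pairs."""
    acc = 0.0
    for i in range(hidden_pairs):
        d = dispatch[slot * hidden_pairs + i]
        x = inp[index * hidden_pairs + i]
        acc += float(np.dot(d, x))
    return acc

grads = [grad_pairs(s, slot=0, hidden_pairs=1) for s in range(2)]
print([round(g, 4) for g in grads])  # [0.3469, -0.3082]
```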