Commit 1a82283: Merge branch 'refs/heads/dev'
turboderp committed Aug 28, 2024
2 parents 57ee846 + f1d8909
Showing 14 changed files with 130 additions and 35 deletions.
63 changes: 62 additions & 1 deletion exllamav2/exllamav2_ext/cpp/safetensors.cpp
@@ -453,4 +453,65 @@ void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target)
remaining -= chunk;
}
}
}
}

void tensor_remap
(
torch::Tensor tensor,
torch::Tensor index
)
{
TORCH_CHECK_SHAPES(tensor, 1, index, 0, 1);
TORCH_CHECK_DTYPE(tensor, kInt);
TORCH_CHECK_DTYPE(index, kInt);

int rows = tensor.size(0);
int cols = tensor.size(1);
uint32_t* temp = (uint32_t*) calloc(cols, sizeof(int));
uint32_t* a = (uint32_t*) tensor.data_ptr();
uint32_t* idx = (uint32_t*) index.data_ptr();

for (int r = 0; r < rows; ++r)
{
memcpy(temp, a, sizeof(uint32_t) * cols);
for (int c = 0; c < cols; ++c)
{
*a++ = temp[idx[c]];
}
}
free(temp);
}

void tensor_remap_4bit
(
torch::Tensor tensor,
torch::Tensor index
)
{
TORCH_CHECK_SHAPES(index, 0, tensor, 1, 8);
TORCH_CHECK_DTYPE(tensor, kInt);
TORCH_CHECK_DTYPE(index, kInt);

int rows = tensor.size(0);
int cols = index.size(0);
uint32_t* temp = (uint32_t*) calloc(cols / 8, sizeof(int));
uint32_t* a = (uint32_t*) tensor.data_ptr();
uint32_t* idx = (uint32_t*) index.data_ptr();

for (int r = 0; r < rows; ++r)
{
memcpy(temp, a, sizeof(uint32_t) * cols / 8);
for (int c = 0; c < cols;)
{
uint32_t rv = 0;
for (int b = 0; b < 8; ++b, ++c)
{
uint32_t i = idx[c];
uint32_t v = (temp[i / 8] >> ((i & 7) * 4) & 0x0f);
rv |= v << (b * 4);
}
*a++ = rv;
}
}
free(temp);
}
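Note: as a rough reference for what the two new helpers are meant to do, here is a minimal pure-PyTorch sketch of the same semantics (an illustration only, not the extension code; tensor_remap_ref and tensor_remap_4bit_ref are hypothetical names). tensor_remap permutes the columns of an int32 matrix according to index; tensor_remap_4bit applies the same column permutation to data packed eight 4-bit values per int32 word.

    import torch

    def tensor_remap_ref(t: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # out[r, c] = t[r, index[c]]; the C++ version does this in place, row by row
        return t[:, index.long()].contiguous()

    def tensor_remap_4bit_ref(t: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        rows, words = t.shape
        # unpack: 4-bit column w * 8 + b is nibble b of word w, i.e. bits b*4 .. b*4+3
        nibbles = torch.stack([(t >> (b * 4)) & 0x0F for b in range(8)], dim = -1).view(rows, -1)
        # permute the 4-bit columns, then repack eight nibbles per int32 word
        nibbles = nibbles[:, index.long()].view(rows, words, 8)
        out = torch.zeros_like(t)
        for b in range(8):
            out |= nibbles[..., b] << (b * 4)
        return out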
13 changes: 13 additions & 0 deletions exllamav2/exllamav2_ext/cpp/safetensors.h
@@ -47,4 +47,17 @@ uintptr_t safetensors_open_fb(const char* filename);
void safetensors_close_fb(uintptr_t handle);
void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target);

void tensor_remap
(
torch::Tensor tensor,
torch::Tensor index
);

void tensor_remap_4bit
(
torch::Tensor tensor,
torch::Tensor index
);


#endif
21 changes: 12 additions & 9 deletions exllamav2/exllamav2_ext/cuda/graph.cu
@@ -133,7 +133,7 @@ void Graph::attach_label(cudaStream_t stream, int label, int sublabel)
}

template <typename T>
void Graph::update_param(int label, int sublabel, int param, T value)
void Graph::update_param(int label, int sublabel, int param, T value, bool debug)
{
for (int i = 0; i < node_labels.size(); ++i)
{
@@ -145,19 +145,22 @@

node_needs_update[i] = true;

// printf("-----------------------------------------------------\n");
// printf("UPDATED:\n");
// DBGI(i);
// inspect_graph();
if (debug)
{
printf("-----------------------------------------------------\n");
printf("UPDATED: ");
DBGI(i);
inspect_graph();
}
}
}

void Graph::update_param_ptr(int label, int sublabel, int param, void* value)
void Graph::update_param_ptr(int label, int sublabel, int param, void* value, bool debug)
{
update_param<void*>(label, sublabel, param, value);
update_param<void*>(label, sublabel, param, value, debug);
}

void Graph::update_param_int(int label, int sublabel, int param, int value)
void Graph::update_param_int(int label, int sublabel, int param, int value, bool debug)
{
update_param<int>(label, sublabel, param, value);
update_param<int>(label, sublabel, param, value, debug);
}
6 changes: 3 additions & 3 deletions exllamav2/exllamav2_ext/cuda/graph.cuh
@@ -46,10 +46,10 @@ public:
void attach_label(cudaStream_t stream, int label, int sublabel);

template <typename T>
void update_param(int label, int sublabel, int param, T value);
void update_param(int label, int sublabel, int param, T value, bool debug);

void update_param_ptr(int label, int sublabel, int param, void* value);
void update_param_int(int label, int sublabel, int param, int value);
void update_param_ptr(int label, int sublabel, int param, void* value, bool debug = false);
void update_param_int(int label, int sublabel, int param, int value, bool debug = false);
};


4 changes: 2 additions & 2 deletions exllamav2/exllamav2_ext/cuda/q_mlp.cu
@@ -109,7 +109,7 @@ void QMLP::forward_
if (graph->count())
{
graph->begin_capture(stream);
forward_run_(stream, cublas_handle, (half*) x, rows, columns, loras, lora_temp, graph);
forward_run_(stream, cublas_handle, (void*) x, rows, columns, loras, lora_temp, graph);
graph->end_capture(stream);
// printf("**** record ****\n");
// DBGI2(rows, columns);
@@ -225,7 +225,7 @@ void QMLP::forward_run_

else
{
gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, graph, 0);
gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, false, NULL, 0, false, graph, 0);
if (layernorm_is_rms)
rms_norm_cuda(stream, temp_state, post_layernorm, x, norm_epsilon, rows, columns, true, false, residual_fp32, graph, KernelLabels::POST_NORM);
else
2 changes: 2 additions & 0 deletions exllamav2/exllamav2_ext/ext_bindings.cpp
@@ -55,6 +55,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("safetensors_pinned_buffer", &safetensors_pinned_buffer, "safetensors_pinned_buffer");
m.def("safetensors_free_pinned_buffer", &safetensors_free_pinned_buffer, "safetensors_free_pinned_buffer");
m.def("safetensors_read_fb", &safetensors_read_fb, "safetensors_read_fb");
m.def("tensor_remap", &tensor_remap, "tensor_remap");
m.def("tensor_remap_4bit", &tensor_remap_4bit, "tensor_remap_4bit");

// qmatrix

4 changes: 2 additions & 2 deletions exllamav2/ext.py
@@ -173,9 +173,9 @@ def find_msvc():
# gcc / cl.exe flags

if windows:
extra_cflags = ["/Ox", "/openmp"]
extra_cflags = ["/Ox"]
else:
extra_cflags = ["-Ofast", "-fopenmp"]
extra_cflags = ["-Ofast"]

if ext_debug:
extra_cflags += ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
9 changes: 8 additions & 1 deletion exllamav2/fasttensors.py
@@ -189,6 +189,8 @@ def get_tensor(self,
out_dtype = None) -> torch.Tensor:
global global_tensorcache

torch.cuda.synchronize()

if self.tensor_remap and (not_fast or not self.fast):
key = self.tensor_remap[key]

@@ -211,6 +213,8 @@
size = end - beg
numel = size // esize
shape = h["shape"]
if device != "cpu":
torch.cuda.set_stream(torch.cuda.default_stream(device))
tensor = torch.zeros(shape, dtype = dtype, device = device)
assert tensor.is_contiguous, "Non-contiguous tensor"
ext_c.safetensors_read_fb(self.handle_fb, beg + self.header_size, size, tensor)
@@ -224,7 +228,8 @@
offset = data_offsets[0] + self.header_size
length = data_offsets[1] - data_offsets[0]
assert np.prod(sh) * dts == length, f"Tensor shape doesn't match storage size: {key}"

if device != "cpu":
torch.cuda.set_stream(torch.cuda.default_stream(device))
tensor = torch.empty(sh, device = device, dtype = dtt)
ext_c.safetensors_load(self.handle, tensor, offset, length)

@@ -236,4 +241,6 @@
global_tensorcache = global_tensorcache[1:]
global_tensorcache.append((cachekey, tensor))

torch.cuda.synchronize()

return tensor
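Note: the new synchronize/set_stream calls appear to pin the extension's copies to the device's default stream rather than whatever stream happens to be current (for example a tensor-parallel stream), and to fence them against work on other streams. A minimal sketch of the pattern, assuming a CUDA device; fill_fn stands in for the ext_c.safetensors_* call and is a hypothetical name:

    import torch

    def read_on_default_stream(shape, dtype, device, fill_fn):
        torch.cuda.synchronize()                                       # wait for pending work on all streams
        if device != "cpu":
            torch.cuda.set_stream(torch.cuda.default_stream(device))   # allocate and copy on the default stream
        tensor = torch.empty(shape, dtype = dtype, device = device)
        fill_fn(tensor)                                                 # e.g. ext_c.safetensors_read_fb(...)
        torch.cuda.synchronize()                                        # make sure the copy has landed
        return tensor

The committed code calls torch.cuda.set_stream directly, as above; a torch.cuda.stream(...) context manager would scope the change, but the sketch mirrors the diff.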
11 changes: 7 additions & 4 deletions exllamav2/linear.py
@@ -8,6 +8,7 @@
from exllamav2.compat import safe_move_tensor
from exllamav2.tensor_p import BROADCAST_VC
from exllamav2.util import unpack_4bit, pack_4bit
import gc

from typing import TYPE_CHECKING

@@ -118,7 +119,7 @@ def load(self,
cfg = self.model.config

if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv)
if w is None: w = self.load_weight()
if w is None: w = self.load_weight(cpu = output_map is not None)

# Load quantized linear layer from dictionary

@@ -137,7 +138,7 @@
self.q_tensors = w

if unmap and "q_perm" in w:
perm = w["q_perm"]
perm = w["q_perm"].cpu()
del w["q_perm"]
del w["q_invperm"]
# w["q_perm"] = torch.arange(0, w["q_perm"].shape[-1], dtype = w["q_perm"].dtype, device = w["q_perm"].device)
@@ -146,8 +147,10 @@
perm = None

if output_map is not None:
w["q_weight"] = w["q_weight"][:, output_map]
w["q_scale"] = pack_4bit(unpack_4bit(w["q_scale"])[:, output_map])
ext_c.tensor_remap(w["q_weight"], output_map)
ext_c.tensor_remap_4bit(w["q_scale"], output_map)
for k in w.keys():
w[k] = safe_move_tensor(w[k], self.device())

self.q_handle = ext.make_q_matrix(w,
self.temp_dq,
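Note: instead of permuting q_weight and q_scale with advanced indexing on the target device, the weights are now loaded on CPU (cpu = output_map is not None), remapped in place by the new extension helpers, and only then moved to the device with safe_move_tensor. A small hedged check of the equivalence this relies on, assuming the `from exllamav2.ext import exllamav2_ext as ext_c` import used elsewhere in the package; the tensors here are small synthetic CPU stand-ins:

    import torch
    from exllamav2.ext import exllamav2_ext as ext_c
    from exllamav2.util import unpack_4bit, pack_4bit

    rows, cols = 4, 64
    q_weight = torch.randint(0, 2**31 - 1, (rows, cols), dtype = torch.int32)
    q_scale = torch.randint(0, 2**31 - 1, (rows, cols // 8), dtype = torch.int32)
    output_map = torch.randperm(cols).to(torch.int)

    expected_w = q_weight[:, output_map.long()]
    expected_s = pack_4bit(unpack_4bit(q_scale)[:, output_map.long()])

    ext_c.tensor_remap(q_weight, output_map)             # in-place column permutation
    ext_c.tensor_remap_4bit(q_scale, output_map)         # in-place permutation of packed 4-bit columns

    assert torch.equal(q_weight, expected_w)
    assert torch.equal(q_scale, expected_s)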
4 changes: 4 additions & 0 deletions exllamav2/model.py
@@ -989,6 +989,10 @@ def forward_chunk(self,
if self.tp_context:
self.tp_context.wait_streams()

if x is not None and x.is_cuda:
context = self.get_device_context(x.device.index)
torch.cuda.set_stream(context.stream)

# Apply logit scale

# if x is not None and self.config.logit_scale != 1:
2 changes: 1 addition & 1 deletion exllamav2/model_init.py
@@ -34,7 +34,7 @@ def print_options(args):

print_opts = []
if args.gpu_split is not None: print_opts += [f"gpu_split: {args.gpu_split}"]
if args.tensor_parallel is not None: print_opts += ["tensor_parallel"]
if args.tensor_parallel: print_opts += ["tensor_parallel"]
if args.length is not None: print_opts += [f"length: {args.length}"]
if args.rope_scale is not None: print_opts += [f"rope_scale: {args.rope_scale}"]
if args.rope_alpha is not None: print_opts += [f"rope_alpha: {args.rope_alpha}"]
16 changes: 9 additions & 7 deletions exllamav2/module.py
@@ -60,7 +60,8 @@ def device(self) -> str:
def load_multi(self,
key: str,
keys: list[str],
measure: bool = False) -> int | dict[str: torch.Tensor]:
measure: bool = False,
cpu: bool = False) -> int | dict[str: torch.Tensor]:

tensors = {}
submap = {}
@@ -85,13 +86,14 @@
if measure:
size += stfile.measure(key + "." + k)
else:
tensors[k] = stfile.get_tensor(key + "." + k, device = self.device())
tensors[k] = stfile.get_tensor(key + "." + k, device = self.device() if not cpu else "cpu")

return size if measure else tensors


def load_weight(self,
override_key: str | None = None):
override_key: str | None = None,
cpu: bool = False):

if override_key is not None:
keys = [override_key]
@@ -105,14 +107,14 @@
# EXL2

if key + ".q_weight" in self.model.config.tensor_file_map:
qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"])
qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"], cpu = cpu)
qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
return qtensors

# GPTQ

if key + ".qweight" in self.model.config.tensor_file_map:
qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"])
qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"], cpu = cpu)
if "bias" in qtensors and torch.all(qtensors["bias"].eq(0)):
del qtensors["bias"]
qtensors["scales"] = qtensors["scales"].half()
@@ -122,14 +124,14 @@

if key + ".weight" in self.model.config.tensor_file_map:
if key + ".bias" in self.model.config.tensor_file_map:
tensors = self.load_multi(key, ["weight", "bias"])
tensors = self.load_multi(key, ["weight", "bias"], cpu = cpu)
tensor = tensors["weight"].half()
bias = tensors["bias"].half()
if self.model.config.arch.orig_weights_transposed and len(tensor.shape) == 2:
tensor = tensor.T
return nn.Parameter(tensor, requires_grad = False), nn.Parameter(bias, requires_grad = False)
else:
tensors = self.load_multi(key, ["weight"])
tensors = self.load_multi(key, ["weight"], cpu = cpu)
tensor = tensors["weight"].half()
# if self.model.config.arch.orig_weights_transposed:
# tensor = tensor.T
2 changes: 1 addition & 1 deletion exllamav2/tensor_p.py
@@ -155,7 +155,7 @@ def define_split(

# Vocab split

vc_split = [s * 32 for s in integer_split(cfg.vocab_size // 32, gpu_split, 16)]
vc_split = [s * 32 for s in integer_split((cfg.vocab_size + 31) // 32, gpu_split, 16)]

def set_split(raw_split):
b = 0
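Note: rounding the vocabulary size up to the next multiple of 32 before splitting means vocabularies that are not a multiple of 32 still have all of their rows covered by the per-device split; the old expression rounded down, leaving up to 31 rows unassigned. A quick illustration of just the rounding (32017 is a hypothetical vocab size):

    vocab_size = 32017
    old_blocks = vocab_size // 32            # 1000 blocks -> 32000 rows, 17 rows left out
    new_blocks = (vocab_size + 31) // 32     # 1001 blocks -> 32032 rows, covers the whole vocab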
8 changes: 4 additions & 4 deletions exllamav2/util.py
@@ -291,19 +291,19 @@ def get_all_gpu_memory():
try:
nvidia_memory = get_nvidia_gpu_memory(visible_devices)
gpu_memory.update(nvidia_memory)
except FileNotFoundError:
except:
pass
# print("nvidia-smi not found. Skipping NVIDIA GPU check.")

try:
amd_memory = get_amd_gpu_memory()
gpu_memory.update(amd_memory)
except FileNotFoundError:
except:
pass
# print("rocm-smi not found. Skipping AMD GPU check.") # TODO: remove warning on NVidia, test on AMD
# print("rocm-smi not found. Skipping AMD GPU check.") # TODO: test on AMD

assert gpu_memory, \
"Unable to read available VRAM from nvidia-smi or rocm-smi"
"Unable to read available VRAM from either nvidia-smi or rocm-smi"

return gpu_memory

