This repository was archived by the owner on May 11, 2025. It is now read-only.

Exllama kernels support #313

Merged
merged 12 commits on Jan 21, 2024
pseudo-native integration of exllama layers
IlyasMoutawwakil committed Jan 21, 2024
commit c7a281fd8d9303048bed875d7434caf244cfd874
70 changes: 31 additions & 39 deletions awq/models/base.py
@@ -14,9 +14,8 @@
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.exllama import WQLinear_Exllama
from awq.modules.exllamav2 import WQLinear_ExllamaV2
from awq.utils.exllama_utils import exllama_post_init, exllamav2_post_init
from awq.modules.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
get_named_linears,
set_op_by_name,
@@ -281,7 +280,14 @@ def from_quantized(
)

# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, quant_config, quant_config.version)
self._load_quantized_modules(
self,
model,
quant_config,
quant_config.version,
use_exllama=use_exllama,
use_exllama_v2=use_exllama_v2,
)

model.tie_weights()

@@ -296,16 +302,16 @@
dtype=torch_dtype,
)

if use_exllama or use_exllama_v2:
# exllama kernels are used after weights loading because shapes are not the same
# normally this should be avoidable by overriding module.load_state_dict with unpack/pack logic
# but accelerate's load_checkpoint_and_dispatch doesn't use that and sets the weights directly,
# raising an error on shapes mismatch

start = time.time()
self._load_exllama_modules(self, model, use_exllama, use_exllama_v2)
end = time.time()
print(f"Replacing layers with Exllama took {end - start:.2f}s")
# Post Init creates q4 matrix handles
import time

start = time.time()
if use_exllama:
model = exllama_post_init(model)
elif use_exllama_v2:
model = exllamav2_post_init(model)
end = time.time()
print(f"Post Init (with pack/unpack) took {end-start:.2f}s")

# Dispatch to devices
if fuse_layers:
@@ -369,9 +375,14 @@ def _load_config(

return model_weights_path, config, quant_config

def _load_quantized_modules(self, model, quant_config, version):
def _load_quantized_modules(
self, model, quant_config, version, use_exllama, use_exllama_v2
):
# Real quantization of weights
assert quant_config.zero_point, "We only support zero_point quantization now."
assert not (
version == "GEMV" and (use_exllama or use_exllama_v2)
), "Exllama kernels only support GEMM version."

# Get blocks of model
layers = self.get_model_layers(model)
@@ -392,7 +403,11 @@ def _load_quantized_modules(self, model, quant_config, version):

# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
if version == "GEMM":
if use_exllama:
q_linear_module = WQLinear_Exllama
elif use_exllama_v2:
q_linear_module = WQLinear_ExllamaV2
elif version == "GEMM":
q_linear_module = WQLinear_GEMM
elif version == "GEMV":
q_linear_module = WQLinear_GEMV
@@ -406,29 +421,6 @@
torch.cuda.empty_cache()
gc.collect()

def _load_exllama_modules(self, model, use_exllama, use_exllama_v2):
if use_exllama:
exllama_module = WQLinear_Exllama
elif use_exllama_v2:
exllama_module = WQLinear_ExllamaV2

gemm_modules = {
name: module
for name, module in model.named_modules()
if isinstance(module, WQLinear_GEMM)
}
for name, gemm_module in gemm_modules.items():
new_submodule = exllama_module.from_wqlinear_gemm(gemm_module)
set_op_by_name(model, name, new_submodule)

if use_exllama:
model = exllama_post_init(model)
elif use_exllama_v2:
model = exllamav2_post_init(model)

torch.cuda.empty_cache()
gc.collect()

@staticmethod
def _scale_activations(self, layer):
scale_dict = self.get_act_for_scaling(layer)
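For context, a minimal usage sketch of the option this commit wires into `from_quantized`. It assumes the `use_exllama` / `use_exllama_v2` names used in the call sites above are exposed as keyword arguments on the public `AutoAWQForCausalLM.from_quantized` entry point; the checkpoint path is a placeholder.

```python
# Hedged sketch, not part of this diff. Assumes use_exllama / use_exllama_v2
# are keyword arguments of from_quantized, as the call sites above suggest.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "path/to/awq-gemm-checkpoint",  # placeholder: a GEMM-version AWQ model
    fuse_layers=False,
    use_exllama_v2=True,  # build WQLinear_ExllamaV2 layers instead of WQLinear_GEMM
)
# exllamav2_post_init() has already run inside from_quantized, so the
# extension's weight handles exist and the model is ready for inference.
```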
81 changes: 44 additions & 37 deletions awq/modules/exllama.py
@@ -1,11 +1,14 @@
import math
import torch
import torch.nn as nn
from awq.utils.exllama_utils import unpack, pack, awq_reverse_reorder, none_tensor
from awq.utils.exllama_utils import unpack_reorder_pack

import exllama_kernels # with CUDA kernels (AutoAWQ_kernels)


# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")


class WQLinear_Exllama(nn.Module):
QUANT_TYPE: str = "exllama"

@@ -15,39 +18,42 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for Exllama kernels")

self.q4 = None

self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features

##################################################################################
## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ##
self.register_buffer(
"qweight",
torch.zeros(
((in_features // 32 * w_bit), out_features),
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(in_features / group_size),
out_features // 32 * w_bit,
),
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ##
##################################################################################

self.register_buffer(
"scales",
torch.zeros(
(math.ceil(in_features / group_size), out_features),
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)

if bias:
self.register_buffer(
"bias",
@@ -64,6 +70,9 @@ def post_init(self):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None

self.qweight, self.qzeros = unpack_reorder_pack(
self.qweight, self.qzeros, self.w_bit
)
self.q4 = exllama_kernels.make_q4(
self.qweight,
self.qzeros,
@@ -73,38 +82,28 @@
)

@classmethod
def from_wqlinear_gemm(cls, q_linear):
exllama_linear = WQLinear_Exllama(
w_bit=q_linear.w_bit,
group_size=q_linear.group_size,
in_features=q_linear.in_features,
out_features=q_linear.out_features,
dev=q_linear.qweight.device,
bias=q_linear.bias,
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear

# Create a new instance of the WQLinear class from ExllamaLinear with the same parameters
bits = q_linear.w_bit
qzeros = q_linear.qzeros
qweight = q_linear.qweight

# Unpack the qweight and qzeros tensors
iweight, izeros = unpack(qweight, qzeros, bits)
# Reverse reorder the iweight and izeros tensors
iweight, izeros = awq_reverse_reorder(iweight, izeros, bits)
# Subtract 1 from the izeros tensor
izeros = torch.bitwise_and(izeros - 1, (2**bits) - 1)
# Pack the qweight and qzeros tensors
qweight, qzeros = pack(iweight, izeros, bits)

# Copy the packed tensors to the ExllamaLinear instance
exllama_linear.scales.copy_(q_linear.scales)
exllama_linear.qweight.copy_(qweight)
exllama_linear.qzeros.copy_(qzeros)

return exllama_linear
raise NotImplementedError("Only inference is supported for Exllama kernels")

def forward(self, x):
assert self.q4 is not None, (
"module.post_init() must be called before module.forward(). "
"Use exllama_post_init() on the whole model."
)

input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)

@@ -127,3 +126,11 @@ def forward(self, x):
out.add_(self.bias)

return out.view(out_shape)


def exllama_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_Exllama):
submodule.post_init()

return model
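
The helper `unpack_reorder_pack` imported at the top of `awq/modules/exllama.py` is not shown in this commit. A hedged sketch of what it presumably does, reconstructed from the deleted `from_wqlinear_gemm` body above and assuming the `unpack`, `awq_reverse_reorder`, and `pack` helpers are still exported by `awq.utils.exllama_utils`:

```python
# Hedged reconstruction of awq.utils.exllama_utils.unpack_reorder_pack,
# inferred from the removed from_wqlinear_gemm logic; not the actual implementation.
import torch

from awq.utils.exllama_utils import awq_reverse_reorder, pack, unpack  # assumed helpers


def unpack_reorder_pack(qweight, qzeros, bits):
    # Unpack the int32-packed AWQ qweight/qzeros into per-element integer tensors
    iweight, izeros = unpack(qweight, qzeros, bits)
    # Undo AWQ's interleaved column ordering
    iweight, izeros = awq_reverse_reorder(iweight, izeros, bits)
    # Exllama expects zero points shifted down by one (wrapping within the bit width)
    izeros = torch.bitwise_and(izeros - 1, (2**bits) - 1)
    # Re-pack into the int32 layout consumed by exllama_kernels.make_q4
    qweight, qzeros = pack(iweight, izeros, bits)
    return qweight, qzeros
```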