[wip][quantization] incorporate nunchaku
#12207
base: main
Conversation
Let me outline the stage we're currently at, as this will help us understand the current blockers: this is for quantizing a pre-trained, non-quantized model checkpoint, as opposed to directly loading an already-quantized checkpoint.
If you have suggestions, please let me know.
Discussed with @SunMarc internally. We will also first try to support pre-quantized checkpoints from https://huggingface.co/nunchaku-tech/nunchaku and see how it goes.
Tried a bit to load pre-quantized checkpoints. The current issues are:
Code:

```python
from diffusers import DiffusionPipeline, FluxTransformer2DModel, NunchakuConfig
from nunchaku.models.linear import SVDQW4A4Linear
from safetensors import safe_open
from huggingface_hub import hf_hub_download
import torch


def modules_without_qweight(safetensors_path: str):
    no_qweight = set()
    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            if key.endswith(".weight"):
                # module name is everything except the last piece after "."
                module_name = ".".join(key.split(".")[:-1])
                no_qweight.add(module_name)
    return sorted(no_qweight)


ckpt_id = "black-forest-labs/FLUX.1-dev"
state_dict_path = hf_hub_download(
    repo_id="nunchaku-tech/nunchaku-flux.1-dev",
    filename="svdq-int4_r32-flux.1-dev.safetensors",
)
modules_to_not_convert = modules_without_qweight(state_dict_path)
# print(f"{modules_to_not_convert=}")

model = FluxTransformer2DModel.from_single_file(
    state_dict_path,
    config=ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=NunchakuConfig(
        weight_dtype="int4",
        weight_group_size=64,
        activation_dtype="int4",
        activation_group_size=64,
        modules_to_not_convert=modules_to_not_convert,
    ),
).to("cuda")

has_svd = any(isinstance(module, SVDQW4A4Linear) for _, module in model.named_modules())
assert has_svd

pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=3.5,
    generator=torch.manual_seed(0),
).images[0]
image.save("nunchaku.png")
```

Cc: @SunMarc
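As a debugging aid while working out the checkpoint layout, the `.weight` scan above can be generalized: group every tensor key by its trailing suffix and see which modules carry which parameter layouts. This is a standalone sketch; the `wscales` key name in the example is illustrative, not verified against an actual nunchaku checkpoint.

```python
from collections import defaultdict


def group_keys_by_suffix(keys):
    """Map each trailing key suffix (e.g. 'weight', 'qweight') to the sorted
    list of module names that carry a tensor with that suffix."""
    groups = defaultdict(set)
    for key in keys:
        # module name is everything before the last ".", suffix is the rest
        module_name, _, suffix = key.rpartition(".")
        groups[suffix].add(module_name)
    return {suffix: sorted(modules) for suffix, modules in groups.items()}


# Example with made-up keys mimicking a mixed (partially quantized) checkpoint:
sample = [
    "transformer_blocks.0.attn.to_q.qweight",
    "transformer_blocks.0.attn.to_q.wscales",  # illustrative scale key
    "norm_out.linear.weight",
]
groups = group_keys_by_suffix(sample)
print(groups["weight"])  # -> ['norm_out.linear'] (modules kept in plain floats)
```

Feeding it `f.keys()` from `safe_open` instead of the hand-written sample gives a quick overview of a real file's layout.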
What does this PR do?
Caution
Doesn't work yet.
Test code:
diffusers-cli env
@lmxyy I am going to outline the stage we're currently at in this integration, as that will help us better understand the blockers.
Cc: @SunMarc
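For context on how `modules_to_not_convert` is usually consumed: quantization backends in diffusers/transformers typically skip a module when any listed pattern is a substring of its fully qualified name. A minimal sketch of that check (a hypothetical helper for illustration, not the PR's actual logic):

```python
def should_skip(module_name: str, modules_to_not_convert) -> bool:
    """Return True if any skip pattern occurs in the module's full name."""
    return any(pattern in module_name for pattern in modules_to_not_convert)


skip_list = ["proj_out", "norm_out"]
print(should_skip("transformer_blocks.0.proj_out", skip_list))   # True
print(should_skip("transformer_blocks.0.attn.to_q", skip_list))  # False
```

This is why passing full module names (as `modules_without_qweight` does above) works: each name trivially matches itself as a substring.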