[wip][quantization] incorporate nunchaku #12207

Draft: wants to merge 12 commits into main

Conversation

@sayakpaul (Member) commented Aug 21, 2025

What does this PR do?

> [!CAUTION]
> Doesn't work yet.

Test code:

```python
from diffusers import DiffusionPipeline, AutoModel, NunchakuConfig
import torch

ckpt_id = "black-forest-labs/FLUX.1-dev"
model = AutoModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=NunchakuConfig(),
)
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=torch.bfloat16
)
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=3.5,
    generator=torch.manual_seed(0),
).images[0]
image.save("nunchaku.png")
```

diffusers-cli env:

- 🤗 Diffusers version: 0.36.0.dev0
- Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.8.0.dev20250626+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.4
- Transformers version: 4.53.2
- Accelerate version: 1.10.0.dev0
- PEFT version: 0.17.0
- Bitsandbytes version: 0.46.0
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@lmxyy I am going to outline the stage we're currently at in this integration, as that will help us better understand the blockers.

Cc: @SunMarc

@sayakpaul (Member, Author)

Let me outline the stage we're currently at, as this will help us understand the current blockers:

This is for quantizing a pre-trained non-quantized model checkpoint as opposed to trying to directly load a quantized checkpoint.
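To make that distinction concrete, here is a minimal sketch of the two entry points; the `NunchakuConfig` API is the WIP one from this PR, and the pre-quantized path is the one explored further down the thread:

```python
import torch
from diffusers import AutoModel, NunchakuConfig

# Path 1 (what this PR targets so far): quantize a non-quantized bf16
# checkpoint on the fly while loading it.
model = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=NunchakuConfig(),
)

# Path 2 (not covered here): directly load a checkpoint that was already
# quantized by nunchaku, e.g. from nunchaku-tech/nunchaku-flux.1-dev.
```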

If you have suggestions, please LMK.

@sayakpaul (Member, Author)

Discussed with @SunMarc internally. We will also first try to support loading pre-quantized checkpoints from https://huggingface.co/nunchaku-tech/nunchaku and see how it goes.
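As a starting point, a quick way to inspect what such a pre-quantized checkpoint actually contains (the repo and filename are the ones used in the snippet further down):

```python
from huggingface_hub import hf_hub_download
from safetensors import safe_open

path = hf_hub_download(
    repo_id="nunchaku-tech/nunchaku-flux.1-dev",
    filename="svdq-int4_r32-flux.1-dev.safetensors",
)
# Print a handful of keys with their shapes to see how the quantized
# checkpoint is laid out (qweight vs. weight, fused projections, etc.).
with safe_open(path, framework="pt", device="cpu") as f:
    for key in sorted(f.keys())[:20]:
        print(key, f.get_slice(key).get_shape())
```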

@sayakpaul (Member, Author)

Tried loading pre-quantized checkpoints a bit. The current issues are:

- The pre-quantized checkpoint (example) has `mlp_fc*` keys which aren't present in our implementation for Flux. This needs to be accounted for.
- It uses horizontal fusion for attention in the checkpoints -- something we don't support in our implementation yet. This will also need to be accounted for (a rough remapping sketch follows the list).
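A minimal sketch of the kind of key remapping both points imply. The diffusers-side names (`ff.net.0.proj`, `ff.net.2`, `to_q`/`to_k`/`to_v`) and the fused key name (`qkv_proj`) are assumptions for illustration, not the verified mapping:

```python
import torch


def remap_nunchaku_keys(state_dict: dict) -> dict:
    """Hypothetical remapping of nunchaku checkpoint keys to diffusers Flux keys."""
    remapped = {}
    for key, value in state_dict.items():
        # Assumption: mlp_fc1/mlp_fc2 correspond to the two feed-forward
        # projections in diffusers' Flux blocks.
        key = key.replace("mlp_fc1", "ff.net.0.proj").replace("mlp_fc2", "ff.net.2")
        if "qkv_proj" in key:
            # Horizontally fused attention: split the fused QKV tensor back
            # into the separate to_q/to_k/to_v parameters diffusers expects.
            # NB: for packed int4 qweights the split would additionally have
            # to respect the packing layout; this only covers plain tensors.
            q, k, v = torch.chunk(value, 3, dim=0)
            remapped[key.replace("qkv_proj", "to_q")] = q
            remapped[key.replace("qkv_proj", "to_k")] = k
            remapped[key.replace("qkv_proj", "to_v")] = v
        else:
            remapped[key] = value
    return remapped
```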
Code:
```python
from diffusers import DiffusionPipeline, FluxTransformer2DModel, NunchakuConfig
from nunchaku.models.linear import SVDQW4A4Linear
from safetensors import safe_open
from huggingface_hub import hf_hub_download
import torch


def modules_without_qweight(safetensors_path: str):
    """Collect modules that keep a plain `.weight`, i.e., were not quantized.

    Quantized linears in the nunchaku checkpoint store `qweight` instead, so
    anything still holding `.weight` must be excluded from conversion.
    """
    no_qweight = set()
    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            if key.endswith(".weight"):
                # module name is everything except the last piece after "."
                module_name = ".".join(key.split(".")[:-1])
                no_qweight.add(module_name)
    return sorted(no_qweight)


ckpt_id = "black-forest-labs/FLUX.1-dev"
state_dict_path = hf_hub_download(
    repo_id="nunchaku-tech/nunchaku-flux.1-dev",
    filename="svdq-int4_r32-flux.1-dev.safetensors",
)
modules_to_not_convert = modules_without_qweight(state_dict_path)
# print(f"{modules_to_not_convert=}")

model = FluxTransformer2DModel.from_single_file(
    state_dict_path,
    config=ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=NunchakuConfig(
        weight_dtype="int4",
        weight_group_size=64,
        activation_dtype="int4",
        activation_group_size=64,
        modules_to_not_convert=modules_to_not_convert,
    ),
).to("cuda")
# Sanity-check that quantized nunchaku linears were actually materialized.
has_svd = any(isinstance(module, SVDQW4A4Linear) for _, module in model.named_modules())
assert has_svd

pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=3.5,
    generator=torch.manual_seed(0),
).images[0]
image.save("nunchaku.png")
```

Cc: @SunMarc

@lmxyy (Contributor) commented Aug 22, 2025

The conversion logic can be found here: https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395
