Commit 4eb7a37: Auto quantization (#313)
Parent: 92cd2b2

2 files changed (+14, -50 lines)

optimum_benchmark/backends/pytorch/backend.py

Lines changed: 14 additions & 46 deletions
@@ -8,15 +8,12 @@
 from datasets import Dataset
 from safetensors.torch import save_file
 from transformers import (
-    AwqConfig,
-    BitsAndBytesConfig,
-    GPTQConfig,
-    TorchAoConfig,
     Trainer,
     TrainerCallback,
     TrainerState,
     TrainingArguments,
 )
+from transformers.quantizers import AutoQuantizationConfig
 
 from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
 from ..base import Backend
@@ -286,8 +283,6 @@ def create_no_weights_model(self) -> None:
 
     def process_quantization_config(self) -> None:
         if self.is_gptq_quantized:
-            self.logger.info("\t+ Processing GPTQ config")
-
             try:
                 import exllamav2_kernels  # noqa: F401
             except ImportError:
@@ -299,12 +294,7 @@ def process_quantization_config(self) -> None:
                     "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                 )
 
-            self.quantization_config = GPTQConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
         elif self.is_awq_quantized:
-            self.logger.info("\t+ Processing AWQ config")
-
             try:
                 import exlv2_ext  # noqa: F401
             except ImportError:
@@ -316,55 +306,30 @@ def process_quantization_config(self) -> None:
                     "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                 )
 
-            self.quantization_config = AwqConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_bnb_quantized:
-            self.logger.info("\t+ Processing BitsAndBytes config")
-            self.quantization_config = BitsAndBytesConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_torchao_quantized:
-            self.logger.info("\t+ Processing TorchAO config")
-            self.quantization_config = TorchAoConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        else:
-            raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")
+        self.logger.info("\t+ Processing AutoQuantization config")
+        self.quantization_config = AutoQuantizationConfig.from_dict(
+            dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
+        )
 
     @property
     def is_quantized(self) -> bool:
         return self.config.quantization_scheme is not None or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) is not None
-        )
-
-    @property
-    def is_bnb_quantized(self) -> bool:
-        return self.config.quantization_scheme == "bnb" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "bnb"
+            and self.pretrained_config.quantization_config.get("quant_method") is not None
         )
 
     @property
     def is_gptq_quantized(self) -> bool:
         return self.config.quantization_scheme == "gptq" or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
+            and self.pretrained_config.quantization_config.get("quant_method") == "gptq"
         )
 
     @property
     def is_awq_quantized(self) -> bool:
         return self.config.quantization_scheme == "awq" or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
-        )
-
-    @property
-    def is_torchao_quantized(self) -> bool:
-        return self.config.quantization_scheme == "torchao" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao"
+            and self.pretrained_config.quantization_config.get("quant_method") == "awq"
         )
 
     @property
@@ -376,11 +341,11 @@ def is_exllamav2(self) -> bool:
                 (
                     hasattr(self.pretrained_config, "quantization_config")
                     and hasattr(self.pretrained_config.quantization_config, "exllama_config")
-                    and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2
+                    and self.pretrained_config.quantization_config.exllama_config.get("version") == 2
                 )
                 or (
                     "exllama_config" in self.config.quantization_config
-                    and self.config.quantization_config["exllama_config"].get("version", None) == 2
+                    and self.config.quantization_config["exllama_config"].get("version") == 2
                 )
             )
         )
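The hunks above collapse four scheme-specific branches (GPTQ, AWQ, BitsAndBytes, TorchAO) into one call that lets transformers pick the concrete config class from the `quant_method` key of the merged dict, with backend-config values overriding whatever the model's own config declares. A minimal sketch of that dispatch, assuming a transformers version that exposes `AutoQuantizationConfig`; the config dicts are invented for illustration:

    from transformers.quantizers import AutoQuantizationConfig

    # Quantization settings baked into the model's config.json (invented values):
    model_cfg = {"quant_method": "gptq", "bits": 4}
    # Overrides from the benchmark's backend config (invented values):
    user_cfg = {"bits": 8}

    # dict(a, **b) merges b over a, so user settings win on conflicts:
    merged = dict(model_cfg, **user_cfg)  # {"quant_method": "gptq", "bits": 8}

    # Dispatches on "quant_method" to the matching config class:
    quantization_config = AutoQuantizationConfig.from_dict(merged)
    print(type(quantization_config).__name__)  # GPTQConfig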
@@ -390,7 +355,10 @@ def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}
 
         if self.config.torch_dtype is not None:
-            kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            if hasattr(torch, self.config.torch_dtype):
+                kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            else:
+                kwargs["torch_dtype"] = self.config.torch_dtype
 
         if self.is_quantized:
             kwargs["quantization_config"] = self.quantization_config
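The new `hasattr` guard matters because `torch_dtype` can legitimately be a string such as "auto" (see TORCH_DTYPES in config.py below) that is not an attribute of the torch module: the old `getattr(torch, ...)` would raise AttributeError, while `from_pretrained` accepts the literal string. A standalone sketch of the fallback:

    import torch

    def resolve_torch_dtype(name: str):
        # "float16" -> torch.float16; names with no torch attribute
        # (e.g. "auto") pass through unchanged for from_pretrained to handle.
        if hasattr(torch, name):
            return getattr(torch, name)
        return name

    assert resolve_torch_dtype("float16") is torch.float16
    assert resolve_torch_dtype("auto") == "auto"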

optimum_benchmark/backends/pytorch/config.py

Lines changed: 0 additions & 4 deletions
@@ -5,7 +5,6 @@
 from ...system_utils import is_rocm_system
 from ..config import BackendConfig
 
-DEVICE_MAPS = ["auto", "sequential"]
 AMP_DTYPES = ["bfloat16", "float16"]
 TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
 
@@ -60,9 +59,6 @@ def __post_init__(self):
                 "Please remove it from the `model_kwargs` and set it in the backend config directly."
             )
 
-        if self.device_map is not None and self.device_map not in DEVICE_MAPS:
-            raise ValueError(f"`device_map` must be one of {DEVICE_MAPS}. Got {self.device_map} instead.")
-
         if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
             raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")
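With the DEVICE_MAPS whitelist removed, `device_map` is forwarded to transformers unvalidated, so values beyond "auto" and "sequential" are no longer rejected by the backend. A hypothetical usage sketch (model id and placements are illustrative, and assume accelerate is installed):

    from transformers import AutoModelForCausalLM

    # "balanced" was valid for from_pretrained but failed the old whitelist;
    # explicit per-module placement like {"transformer": 0, "lm_head": "cpu"}
    # is likewise accepted now.
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="balanced")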
