from datasets import Dataset
from safetensors.torch import save_file
from transformers import (
-    AwqConfig,
-    BitsAndBytesConfig,
-    GPTQConfig,
-    TorchAoConfig,
    Trainer,
    TrainerCallback,
    TrainerState,
    TrainingArguments,
)
+from transformers.quantizers import AutoQuantizationConfig

from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
from ..base import Backend
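
The four scheme-specific config imports collapse into the single AutoQuantizationConfig entry point. A minimal sketch of the behavior this relies on, assuming a transformers release that ships AutoQuantizationConfig (the dict contents are illustrative):

from transformers.quantizers import AutoQuantizationConfig

# AutoQuantizationConfig.from_dict inspects the "quant_method" key and
# dispatches to the matching concrete class (GPTQConfig, AwqConfig, ...),
# so the backend no longer needs one branch per scheme.
config = AutoQuantizationConfig.from_dict({"quant_method": "gptq", "bits": 4})
print(type(config).__name__)  # expected: GPTQConfig
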
@@ -286,8 +283,6 @@ def create_no_weights_model(self) -> None:

    def process_quantization_config(self) -> None:
        if self.is_gptq_quantized:
-            self.logger.info("\t+ Processing GPTQ config")
-
            try:
                import exllamav2_kernels  # noqa: F401
            except ImportError:
@@ -299,12 +294,7 @@ def process_quantization_config(self) -> None:
                    "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                )

-            self.quantization_config = GPTQConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
        elif self.is_awq_quantized:
-            self.logger.info("\t+ Processing AWQ config")
-
            try:
                import exlv2_ext  # noqa: F401
            except ImportError:
@@ -316,55 +306,30 @@ def process_quantization_config(self) -> None:
                    "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                )

-            self.quantization_config = AwqConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_bnb_quantized:
-            self.logger.info("\t+ Processing BitsAndBytes config")
-            self.quantization_config = BitsAndBytesConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_torchao_quantized:
-            self.logger.info("\t+ Processing TorchAO config")
-            self.quantization_config = TorchAoConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        else:
-            raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")
+        self.logger.info("\t+ Processing AutoQuantization config")
+        self.quantization_config = AutoQuantizationConfig.from_dict(
+            dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
+        )
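
The merge passed to from_dict keeps the checkpoint's own quantization_config as the base and lets the benchmark config override individual keys, since keyword arguments win in dict(base, **overrides). A small self-contained illustration with hypothetical values:

# Keyword arguments win in dict(base, **overrides), so benchmark-config keys
# override the checkpoint's values while untouched keys are preserved.
# Both dicts are hypothetical stand-ins for the real attributes.
checkpoint_config = {"quant_method": "gptq", "bits": 4, "desc_act": False}
user_overrides = {"desc_act": True, "exllama_config": {"version": 2}}

merged = dict(checkpoint_config, **user_overrides)
assert merged == {
    "quant_method": "gptq",
    "bits": 4,
    "desc_act": True,
    "exllama_config": {"version": 2},
}
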
    @property
    def is_quantized(self) -> bool:
        return self.config.quantization_scheme is not None or (
            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) is not None
-        )
-
-    @property
-    def is_bnb_quantized(self) -> bool:
-        return self.config.quantization_scheme == "bnb" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "bnb"
+            and self.pretrained_config.quantization_config.get("quant_method") is not None
        )

    @property
    def is_gptq_quantized(self) -> bool:
        return self.config.quantization_scheme == "gptq" or (
            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
+            and self.pretrained_config.quantization_config.get("quant_method") == "gptq"
        )

    @property
    def is_awq_quantized(self) -> bool:
        return self.config.quantization_scheme == "awq" or (
            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
-        )
-
-    @property
-    def is_torchao_quantized(self) -> bool:
-        return self.config.quantization_scheme == "torchao" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao"
+            and self.pretrained_config.quantization_config.get("quant_method") == "awq"
        )
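
These flags key off the quant_method field that quantized checkpoints record in their quantization_config. A hedged sketch of the detection path, using a stand-in for self.pretrained_config:

from types import SimpleNamespace

# Stand-in for self.pretrained_config of a GPTQ checkpoint: quantized models
# record their scheme under quantization_config["quant_method"].
pretrained_config = SimpleNamespace(
    quantization_config={"quant_method": "gptq", "bits": 4}
)

is_gptq = (
    hasattr(pretrained_config, "quantization_config")
    and pretrained_config.quantization_config.get("quant_method") == "gptq"
)
assert is_gptq
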

    @property
@@ -376,11 +341,11 @@ def is_exllamav2(self) -> bool:
                (
                    hasattr(self.pretrained_config, "quantization_config")
                    and hasattr(self.pretrained_config.quantization_config, "exllama_config")
-                    and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2
+                    and self.pretrained_config.quantization_config.exllama_config.get("version") == 2
                )
                or (
                    "exllama_config" in self.config.quantization_config
                    and self.config.quantization_config["exllama_config"].get("version") == 2
                )
            )
        )
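
For illustration, a hypothetical user-side quantization_config that satisfies the second branch of is_exllamav2:

# Hypothetical benchmark-side quantization_config that triggers the second
# branch of is_exllamav2: the user explicitly requests ExLlamaV2 kernels.
quantization_config = {
    "quant_method": "gptq",
    "bits": 4,
    "exllama_config": {"version": 2},
}

assert (
    "exllama_config" in quantization_config
    and quantization_config["exllama_config"].get("version") == 2
)
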
@@ -390,7 +355,10 @@ def automodel_kwargs(self) -> Dict[str, Any]:
        kwargs = {}

        if self.config.torch_dtype is not None:
-            kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            if hasattr(torch, self.config.torch_dtype):
+                kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            else:
+                kwargs["torch_dtype"] = self.config.torch_dtype

        if self.is_quantized:
            kwargs["quantization_config"] = self.quantization_config