Skip to content

Commit 7f5d486

Browse files
authored
Add torchao to optimum as a pytorch backend configuration (#297)
1 parent 66c837b commit 7f5d486

File tree

6 files changed

+28
-10
lines changed

docker/cpu/Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,11 +33,11 @@ ARG TORCH_VERSION=""
3333
ARG TORCH_RELEASE_TYPE=stable
3434

3535
RUN if [ -n "${TORCH_VERSION}" ]; then \
36-
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ; \
36+
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/cpu ; \
3737
elif [ "${TORCH_RELEASE_TYPE}" = "stable" ]; then \
38-
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ; \
38+
pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/cpu ; \
3939
elif [ "${TORCH_RELEASE_TYPE}" = "nightly" ]; then \
40-
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu ; \
40+
pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/cpu ; \
4141
else \
4242
echo "Error: Invalid TORCH_RELEASE_TYPE. Must be 'stable', 'nightly', or specify a TORCH_VERSION." && exit 1 ; \
4343
fi

docker/cuda-ort/Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,11 +32,11 @@ ARG TORCH_CUDA=cu118
3232
ARG TORCH_VERSION=stable
3333

3434
RUN if [ "${TORCH_VERSION}" = "stable" ]; then \
35-
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
35+
pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
3636
elif [ "${TORCH_VERSION}" = "nightly" ]; then \
37-
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
37+
pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
3838
else \
39-
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
39+
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
4040
fi
4141

4242
# Install torch-ort and onnxruntime-training

docker/cuda/Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,11 +32,11 @@ ARG TORCH_CUDA=cu124
3232
ARG TORCH_RELEASE_TYPE=stable
3333

3434
RUN if [ -n "${TORCH_VERSION}" ]; then \
35-
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
35+
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
3636
elif [ "${TORCH_RELEASE_TYPE}" = "stable" ]; then \
37-
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
37+
pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
3838
elif [ "${TORCH_RELEASE_TYPE}" = "nightly" ]; then \
39-
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
39+
pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
4040
else \
4141
echo "Error: Invalid TORCH_RELEASE_TYPE. Must be 'stable', 'nightly', or specify a TORCH_VERSION." && exit 1 ; \
4242
fi

examples/pytorch_llama.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,11 @@
2929
"quantization_scheme": "gptq",
3030
"quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
3131
},
32+
"torchao-int4wo-128": {
33+
"torch_dtype": "bfloat16",
34+
"quantization_scheme": "torchao",
35+
"quantization_config": {"quant_type": "int4_weight_only", "group_size": 128},
36+
}
3237
}
3338

3439

optimum_benchmark/backends/pytorch/backend.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
TrainerState,
1717
TrainingArguments,
1818
)
19+
from transformers import TorchAoConfig
1920

2021
from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
2122
from ..base import Backend
@@ -323,6 +324,11 @@ def process_quantization_config(self) -> None:
323324
self.quantization_config = BitsAndBytesConfig(
324325
**dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
325326
)
327+
elif self.is_torchao_quantized:
328+
self.logger.info("\t+ Processing TorchAO config")
329+
self.quantization_config = TorchAoConfig(
330+
**dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
331+
)
326332
else:
327333
raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")
328334

@@ -366,6 +372,13 @@ def is_awq_quantized(self) -> bool:
366372
and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
367373
)
368374

375+
@property
376+
def is_torchao_quantized(self) -> bool:
377+
return self.config.quantization_scheme == "torchao" or (
378+
hasattr(self.pretrained_config, "quantization_config")
379+
and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao"
380+
)
381+
369382
@property
370383
def is_exllamav2(self) -> bool:
371384
return (

optimum_benchmark/backends/pytorch/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
AMP_DTYPES = ["bfloat16", "float16"]
1010
TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
1111

12-
QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}}
12+
QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}, "torchao": {}}
1313

1414

1515
@dataclass

0 commit comments

Comments
 (0)