From a864bb2934dcd2efba9ab426d3ea74ee66c2903d Mon Sep 17 00:00:00 2001 From: "Wang, Chang" Date: Thu, 11 Jul 2024 13:43:32 +0800 Subject: [PATCH] Migrate SQ and WOQ to INC 3.x API. (#1606) Signed-off-by: changwangss Co-authored-by: Ye, Xinyu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/checkgroup.yml | 34 - .../workflows/script/unitTest/env_setup.sh | 1 + .../text-generation/quantization/README.md | 10 +- .../quantization/llm_quantization_recipes.md | 191 ++--- .../quantization/run_benchmark.sh | 9 +- .../quantization/run_generation_cpu_woq.py | 84 +- .../quantization/run_generation_sq.py | 204 ++--- .../quantization/run_tuning.sh | 10 +- .../examples/finetuning/multi_modal/train.py | 2 +- .../neural_chat/models/model_utils.py | 3 +- .../transformers/llm/evaluation/models.py | 197 ----- .../llm/quantization/autograd/functions.py | 8 +- .../llm/quantization/nn/modules.py | 32 +- .../transformers/llm/quantization/sq_utils.py | 514 ++++++++++++ .../transformers/llm/quantization/utils.py | 712 ++++++++++------- .../transformers/modeling/modeling_auto.py | 742 ++++++++---------- .../transformers/utils/config.py | 357 +++++---- .../transformers/utils/utility.py | 405 ---------- tests/CI/test_quantization.py | 74 +- tests/CI/test_weight_only.py | 32 +- tests/CI/test_weight_only_gpu.py | 4 +- 21 files changed, 1809 insertions(+), 1816 deletions(-) delete mode 100644 intel_extension_for_transformers/transformers/llm/evaluation/models.py create mode 100644 intel_extension_for_transformers/transformers/llm/quantization/sq_utils.py diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index 57c7ab30a60..e1f6b0c3735 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -30,40 +30,6 @@ subprojects: - "optimize-unit-test-PR-test" - "Genreate-OptimizeUT-Report" - - id: "NeuralChat Unit Test" - paths: - - ".github/workflows/unit-test-neuralchat.yml" - - ".github/workflows/script/unitTest/run_unit_test_neuralchat.sh" - - "intel_extension_for_transformers/neural_chat/**" - - "requirements.txt" - - "setup.py" - - "intel_extension_for_transformers/transformers/llm/finetuning/**" - - "intel_extension_for_transformers/transformers/llm/quantization/**" - - "intel_extension_for_transformers/transformers/**" - - "intel_extension_for_transformers/langchain/**" - - "!intel_extension_for_transformers/neural_chat/docs/**" - - "!intel_extension_for_transformers/neural_chat/examples/**" - - "!intel_extension_for_transformers/neural_chat/assets/**" - - "!intel_extension_for_transformers/neural_chat/README.md" - checks: - - "neuralchat-unit-test-baseline" - - "neuralchat-unit-test-PR-test" - - "Generate-NeuralChat-Report" - - - id: "Engine Unit Test workflow" - paths: - - ".github/workflows/unit-test-engine.yml" - - "requirements.txt" - - "setup.py" - - intel_extension_for_transformers/transformers/** - - "intel_extension_for_transformers/transformers/runtime/**" - - "!intel_extension_for_transformers/transformers/runtime/kernels/**" - - "!intel_extension_for_transformers/transformers/runtime/third_party/**" - - "!intel_extension_for_transformers/transformers/runtime/docs/**" - checks: - - "engine-unit-test-baseline" - - "engine-unit-test-PR-test" - - "Genreate-Engine-Report" # - id: "Windows Binary Test" # paths: diff --git a/.github/workflows/script/unitTest/env_setup.sh b/.github/workflows/script/unitTest/env_setup.sh index 838e3a4d98d..4afbf606c32 100644 --- a/.github/workflows/script/unitTest/env_setup.sh +++ 
b/.github/workflows/script/unitTest/env_setup.sh @@ -13,6 +13,7 @@ until [ "$n" -ge 5 ]; do git clone https://github.com/intel/neural-compressor.git /neural-compressor cd /neural-compressor pip install -r requirements.txt + pip install -r requirements_pt.txt python setup.py install && break n=$((n + 1)) sleep 5 diff --git a/examples/huggingface/pytorch/text-generation/quantization/README.md b/examples/huggingface/pytorch/text-generation/quantization/README.md index 212d17bc22a..a511ce72294 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/README.md +++ b/examples/huggingface/pytorch/text-generation/quantization/README.md @@ -36,21 +36,18 @@ OMP_NUM_THREADS= numactl -m -C python ru --model \ --sq \ --output_dir \ # Default is "./saved_results." - --int8 \ --benchmark \ --batch_size 1 # load SQ model quantied by itrex and do benchmark. OMP_NUM_THREADS= numactl -m -C python run_generation_sq.py \ --model \ - --int8 \ --benchmark \ --batch_size 1 # load SQ model quantied configure.json and do benchmark. python run_generation_sq.py \ --model \ --output_dir \ - --int8 \ - --restore \ + --restore_sq_model_from_json \ --benchmark \ --batch_size 1 ``` @@ -68,14 +65,12 @@ python run_generation_sq.py \ --model \ --sq \ --output_dir \ # Default is "./saved_results." - --int8 \ --accuracy \ --batch_size 56 # load SQ model quantied by itrex and do benchmark. python run_generation_sq.py \ --model \ - --int8 \ --accuracy \ --batch_size 56 @@ -83,8 +78,7 @@ python run_generation_sq.py \ python run_generation_sq.py \ --model \ --output_dir \ - --int8 \ - --restore \ + --restore_sq_model_from_json \ --accuracy \ --batch_size 56 diff --git a/examples/huggingface/pytorch/text-generation/quantization/llm_quantization_recipes.md b/examples/huggingface/pytorch/text-generation/quantization/llm_quantization_recipes.md index 4fd7a3cda14..f3c697d0680 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/llm_quantization_recipes.md +++ b/examples/huggingface/pytorch/text-generation/quantization/llm_quantization_recipes.md @@ -41,9 +41,10 @@ pip install -v . 
cd examples/huggingface/pytorch/text-generation/quantization pip install -r requirements.txt pip install neural-compressor==2.6 -pip install transformers==4.38.1 - pip install torch==2.3.0+cpu --index-url https://download.pytorch.org/whl/cpu +# 4.38.1 is only limited by smoothquant +pip install transformers==4.38.1 +# ipex is only necessary for smoothquant pip install intel-extension-for-pytorch==2.3.0 ``` @@ -57,11 +58,10 @@ pip install intel-extension-for-pytorch==2.3.0 python run_generation_sq.py \ --model EleutherAI/gpt-j-6b \ --output_dir ./saved_results \ - --trust_remote_code \ - --fallback_add \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.85 ``` @@ -84,7 +84,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -98,7 +98,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -112,10 +112,10 @@ python run_generation_cpu_woq.py \ python run_generation_sq.py \ --model facebook/opt-1.3b \ --output_dir ./saved_results \ - --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.9 ``` @@ -138,7 +138,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 128 \ --accuracy @@ -151,7 +151,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -167,8 +167,9 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.5 ``` @@ -191,7 +192,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -204,7 +205,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -218,15 +219,17 @@ python run_generation_cpu_woq.py \ python run_generation_sq.py \ --model meta-llama/Llama-2-7b-hf \ --output_dir ./saved_results \ - --trust_remote_code \ - --calib_len 2048 \ - --fallback_add \ - --calib_shuffle False \ - --calib_iters 512 \ --tasks lambada_openai \ - --int8 --sq --accuracy \ + --sq \ + --accuracy \ --batch_size 1 \ - --recipes "{'smooth_quant': True, 'smooth_quant_args': {'alpha': 'auto', 'folding': False, 'default_alpha': 0.8, 'auto_alpha_args': {'alpha_min': 0.79, 'alpha_max': 0.99, 'alpha_step': 0.01, 'shared_criterion': 'mean'}}}" + --init_alpha 0.8 \ + --alpha_min 0.8 \ + --alpha_max 0.99 \ + --alpha_step 0.01 \ + --shared_criterion mean \ + --seq_len 2048 \ + --alpha auto ``` ### Weight-Only Quantization @@ -248,7 +251,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -261,7 +264,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - 
--calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -275,15 +278,17 @@ python run_generation_cpu_woq.py \ python run_generation_sq.py \ --model meta-llama/Llama-2-13b-hf \ --output_dir ./saved_results \ - --trust_remote_code \ - --calib_len 1024 \ - --fallback_add \ - --calib_iters 512 - --calib_padding \ + --seq_len 1024 \ --tasks lambada_openai \ - --int8 --sq --accuracy \ + --sq \ + --accuracy \ --batch_size 1 \ - --recipes "{'smooth_quant': True, 'smooth_quant_args': {'alpha': 'auto', 'folding': False, 'default_alpha': 0.8, 'auto_alpha_args': {'alpha_min': 0.75, 'alpha_max': 0.99, 'alpha_step': 0.01, 'shared_criterion': 'max', 'n_samples':64}}}" + --init_alpha 0.8 \ + --alpha_min 0.75 \ + --alpha_max 0.99 \ + --alpha_step 0.01 \ + --shared_criterion max \ + --alpha auto ``` ### Weight-Only Quantization @@ -305,7 +310,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -318,7 +323,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -332,10 +337,10 @@ python run_generation_cpu_woq.py \ python run_generation_sq.py \ --model meta-llama/Llama-2-70b-hf \ --output_dir ./saved_results \ - --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.8 ``` @@ -358,7 +363,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -371,7 +376,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -387,8 +392,9 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.9 ``` @@ -411,7 +417,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -424,7 +430,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -440,7 +446,7 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ + --sq --accuracy \ --batch_size 1 \ --alpha 0.95 ``` @@ -463,7 +469,7 @@ python run_generation_cpu_woq.py \ --woq_algo GPTQ \ --bits 4 \ --weight_dtype int4 \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -476,7 +482,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -492,8 +498,9 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.95 ``` @@ -516,7 +523,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act 
\ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -529,7 +536,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -545,9 +552,10 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ - --alpha 1.0 + --sq \ + --accuracy \ + --eval_batch_size 56 \ + --alpha 0.65 ``` ### Weight-Only Quantization @@ -569,7 +577,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -582,7 +590,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -598,7 +606,7 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ + --sq --accuracy \ --batch_size 1 \ --alpha 0.5 ``` @@ -623,7 +631,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -636,7 +644,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -652,11 +660,10 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --calib_iters 512 - --batch_size 1 \ - --recipes "{'smooth_quant':True,'smooth_quant_args':{'alpha':'auto','folding':False,'default_alpha':0.7,'auto_alpha_args':{'alpha_min':0.55,'alpha_max':0.8,'alpha_step':0.01,'shared_criterion':'mean','n_samples':64}}}" \ - --calib_iters 512 + --sq \ + --accuracy \ + --eval_batch_size 56 \ + --alpha 0.75 ``` ### Weight-Only Quantization @@ -677,7 +684,7 @@ python run_generation_cpu_woq.py \ --woq_algo GPTQ \ --bits 4 \ --weight_dtype int4 \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme asym \ --group_size 32 \ --accuracy @@ -690,7 +697,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -706,9 +713,10 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --recipes "{'smooth_quant':True,'smooth_quant_args':{'alpha':'auto','folding':False,'default_alpha':0.85,'auto_alpha_args':{'alpha_min':0.79,'alpha_max':0.88,'alpha_step':0.01,'shared_criterion':'mean'}}}" \ - --batch_size 1 + --sq \ + --accuracy \ + --eval_batch_size 56 \ + --alpha 0.9 ``` ### Weight-Only Quantization @@ -729,10 +737,10 @@ python run_generation_cpu_woq.py \ --woq_algo GPTQ \ --bits 4 \ --weight_dtype int4 \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ - --nsamples 256 \ + --n_samples 256 \ --accuracy # int4 AutoRound @@ -743,7 +751,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -757,10 +765,10 @@ python run_generation_cpu_woq.py \ python run_generation_sq.py \ 
--model bigscience/bloom-1b7 \ --output_dir ./saved_results \ - --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.6 ``` @@ -783,7 +791,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -796,7 +804,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -810,10 +818,10 @@ python run_generation_cpu_woq.py \ python run_generation_sq.py \ --model EleutherAI/gpt-neox-20b \ --output_dir ./saved_results \ - --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.7 ``` @@ -835,7 +843,7 @@ python run_generation_cpu_woq.py \ --woq_algo GPTQ \ --bits 4 \ --weight_dtype int4 \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme asym \ --group_size 32 \ --accuracy @@ -848,7 +856,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -864,8 +872,9 @@ python run_generation_sq.py \ --output_dir ./saved_results \ --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.75 ``` @@ -888,7 +897,7 @@ python run_generation_cpu_woq.py \ --bits 4 \ --weight_dtype int4 \ --desc_act \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 32 \ --accuracy @@ -901,7 +910,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy @@ -915,10 +924,10 @@ python run_generation_cpu_woq.py \ python run_generation_sq.py \ --model databricks/dolly-v2-12b \ --output_dir ./saved_results \ - --trust_remote_code \ --tasks lambada_openai \ - --int8 --sq --accuracy \ - --batch_size 1 \ + --sq \ + --accuracy \ + --eval_batch_size 56 \ --alpha 0.75 ``` @@ -940,7 +949,7 @@ python run_generation_cpu_woq.py \ --woq_algo GPTQ \ --bits 4 \ --weight_dtype int4 \ - --max_input_length 2048 \ + --seq_len 2048 \ --scheme sym \ --group_size 128 \ --accuracy @@ -953,7 +962,7 @@ python run_generation_cpu_woq.py \ --woq_algo AutoRound \ --bits 4 \ --weight_dtype int4 \ - --calib_iters 200 \ + --autoround_iters 200 \ --scheme asym \ --group_size 128 \ --accuracy diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh b/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh index 61cd923588b..1ed8c54b1ce 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh +++ b/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh @@ -71,9 +71,11 @@ function run_benchmark { if [[ ${mode} == "accuracy" ]]; then mode_cmd=" --accuracy " extra_cmd=$extra_cmd" --tasks ${lm_eval_tasks}" + extra_cmd=$extra_cmd" --eval_batch_size ${batch_size}" elif [[ ${mode} == "benchmark" ]]; then mode_cmd=" --benchmark " - extra_cmd=$extra_cmd" --iters ${iters}" + extra_cmd=$extra_cmd" --benchmark_iters ${iters}" + extra_cmd=$extra_cmd" --benchmark_batch_size ${batch_size}" else echo "Error: No such mode: ${mode}" exit 1 @@ -237,9 +239,6 @@ 
function run_benchmark { fi fi if [[ ${int8} == "true" ]] && [[ "$model_source" != "huggingface" ]]; then - if [[ "${script}" == "run_generation_sq.py" ]] && [[ "${topology}" != "gpt_j_mp" ]];then - extra_cmd=$extra_cmd" --int8" - fi model_name_or_path=$tuned_checkpoint fi if [[ $backend == "neuralspeed" ]]; then @@ -250,13 +249,11 @@ function run_benchmark { if [ "${script}" == "run_generation_sq.py" ];then python -u ./${script} \ --model ${model_name_or_path} \ - --batch_size ${batch_size} \ ${mode_cmd} \ ${extra_cmd} elif [ "${script}" == "run_generation_cpu_woq.py" ];then python -u ./${script} \ --model ${model_name_or_path} \ - --batch_size ${batch_size} \ ${mode_cmd} \ ${extra_cmd} else diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py index 2a7c9194c2c..cd59d9c4086 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py @@ -31,11 +31,12 @@ parser.add_argument("--use_ipex", action="store_true") # ============Benchmark configs============== parser.add_argument("--benchmark", action="store_true") -parser.add_argument("--iters", default=100, type=int, help="num iter") +parser.add_argument("--benchmark_iters", default=100, type=int, help="num iters for benchmark") +parser.add_argument("--benchmark_batch_size", default=1, type=int, help="batch size for benchmark") parser.add_argument("--num_warmup", default=10, type=int, help="num warmup") # ============Accuracy configs============== parser.add_argument("--accuracy", action="store_true") -parser.add_argument("--batch_size", default=56, type=int, help="batch size num.") +parser.add_argument("--eval_batch_size", default=56, type=int, help="batch size num for evaluation.") parser.add_argument( "--tasks", default="lambada_openai", @@ -92,7 +93,21 @@ action="store_true", help="Use layer wise to do quantization", ) -parser.add_argument("--woq_loading", action="store_true") +parser.add_argument( + "--n_samples", type=int, default=512, help="Number of calibration data samples." +) +parser.add_argument( + "--seq_len", + type=int, + default=2048, + help="Calibration dataset sequence max length, this should align with your model config", +) +parser.add_argument( + "--batch_size", + type=int, + default=8, + help="Calibration batchsize.", +) # ============GPTQ configs============== parser.add_argument( "--desc_act", @@ -116,33 +131,12 @@ default=128, help="Block size. sub weight matrix size to run GPTQ.", ) -parser.add_argument( - "--nsamples", type=int, default=512, help="Number of calibration data samples." 
-) -parser.add_argument( - "--max_input_length", - type=int, - default=2048, - help="Calibration dataset sequence max length, this should align with your model config", -) parser.add_argument( "--static_groups", action="store_true", help="Use determined group to do quantization", ) # ============AUTOROUND configs============== -parser.add_argument( - "--calib_len", - type=int, - default=2048, - help="Calibration dataset sequence max length, this should align with your model config", -) -parser.add_argument( - "--calib_iters", - type=int, - default=200, - help="Calibration inference iterations", -) parser.add_argument( "--lr", type=float, @@ -155,11 +149,17 @@ default=None, help="minmax learning rate, if None,it will beset to be the same with lr", ) +parser.add_argument("--autoround_iters", default=200, type=int, help="num iters for autoround calibration.") parser.add_argument( "--disable_quanted_input", action="store_true", help="whether to use the output of quantized block to tune the next block", ) +parser.add_argument( + "--quant_lm_head", + action="store_true", + help="whether to quant the lm head layer", +) # ============BitsAndBytes configs============== parser.add_argument("--bitsandbytes", action="store_true") @@ -231,11 +231,12 @@ bits=args.bits, zero_point=False if args.scheme == "sym" else True, group_size=args.group_size, - max_input_length=args.max_input_length, + seq_len=args.seq_len, + n_samples=args.n_samples, + batch_size=args.batch_size, compute_dtype=args.compute_dtype, scale_dtype=args.scale_dtype, weight_dtype=args.weight_dtype, - calib_iters=args.calib_iters, use_ipex=args.use_ipex, ) elif args.woq_algo == "Teq": @@ -245,11 +246,12 @@ bits=args.bits, sym=True if args.scheme == "sym" else False, group_size=args.group_size, - max_input_length=args.max_input_length, + seq_len=args.seq_len, + batch_size=args.batch_size, + n_samples=args.n_samples, compute_dtype=args.compute_dtype, scale_dtype=args.scale_dtype, weight_dtype=args.weight_dtype, - calib_iters=args.calib_iters, use_ipex=args.use_ipex, ) elif args.woq_algo == "GPTQ": @@ -261,14 +263,14 @@ damp_percent=args.damp_percent, sym=True if args.scheme == "sym" else False, blocksize=args.blocksize, - nsamples=args.nsamples, static_groups=args.static_groups, group_size=args.group_size, - max_input_length=args.max_input_length, + n_samples=args.n_samples, + seq_len=args.seq_len, + batch_size=args.batch_size, compute_dtype=args.compute_dtype, scale_dtype=args.scale_dtype, weight_dtype=args.weight_dtype, - calib_iters=args.calib_iters, layer_wise=args.layer_wise, true_sequential=args.true_sequential, use_ipex=args.use_ipex, @@ -279,16 +281,17 @@ dataset=args.dataset, bits=args.bits, sym=True if args.scheme == "sym" else False, - nsamples=args.nsamples, + n_samples=args.n_samples, group_size=args.group_size, compute_dtype=args.compute_dtype, scale_dtype=args.scale_dtype, weight_dtype=args.weight_dtype, - iters=args.calib_iters, - calib_len=args.calib_len, + iters=args.autoround_iters, + seq_len=args.seq_len, lr=args.lr, minmax_lr=args.minmax_lr, disable_quanted_input=args.disable_quanted_input, + quant_lm_head = args.quant_lm_head, use_ipex=args.use_ipex, ) else: @@ -339,6 +342,7 @@ _commit_hash=args._commit_hash, use_neural_speed=args.use_neural_speed, ) + user_model = user_model.eval() if hasattr(user_model, "eval") else user_model prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun." 
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1) @@ -346,7 +350,7 @@ # start total_time = 0.0 - num_iter = args.iters + num_iter = args.benchmark_iters num_warmup = args.num_warmup total_token_num = 0 eos_token_id = tokenizer.eos_token_id @@ -356,7 +360,7 @@ # tokenizer for chatglm2. if hasattr(tokenizer, "build_chat_input"): input_ids = tokenizer.build_chat_input(prompt)["input_ids"] - input_ids = input_ids.repeat(args.batch_size, 1) + input_ids = input_ids.repeat(args.benchmark_batch_size, 1) eos_token_id = [ tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), @@ -366,11 +370,11 @@ elif hasattr(tokenizer, "build_prompt"): build_prompt = tokenizer.build_prompt(prompt) input_ids = tokenizer( - [build_prompt] * args.batch_size, return_tensors="pt" + [build_prompt] * args.benchmark_batch_size, return_tensors="pt" ).input_ids else: input_ids = tokenizer( - [prompt] * args.batch_size, return_tensors="pt" + [prompt] * args.benchmark_batch_size, return_tensors="pt" ).input_ids gen_ids = user_model.generate( input_ids, @@ -399,11 +403,11 @@ model_args="pretrained="+args.model+",trust_remote_code="+str(args.trust_remote_code) if args.use_neural_speed: model_args += ",model_format=neural_speed" - args = LMEvalParser(model = "hf", + args = LMEvalParser(model = "hf", model_args=model_args, tasks = args.tasks, device = "cpu", - batch_size = args.batch_size) + batch_size = args.eval_batch_size) results = evaluate(args) for task_name in args.tasks.split(","): if task_name == "wikitext": diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_sq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_sq.py index 976af12a333..fd727af4d53 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_sq.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_sq.py @@ -1,21 +1,21 @@ import argparse +import json import os import re import time -import json + import torch +from optimum.intel.generation.modeling import TSModelForCausalLM from transformers import AutoConfig, AutoTokenizer -from intel_extension_for_transformers.transformers import ( - AutoModelForCausalLM, - AutoModel, -) from transformers.utils import check_min_version -from intel_extension_for_transformers.transformers.utils import str2bool -from optimum.intel.generation.modeling import TSModelForCausalLM + from intel_extension_for_transformers.transformers import ( + AutoModel, + AutoModelForCausalLM, MixedPrecisionConfig, SmoothQuantConfig, ) +from intel_extension_for_transformers.transformers.utils import str2bool parser = argparse.ArgumentParser() parser.add_argument("--model", default=None) @@ -34,7 +34,7 @@ help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)", ) parser.add_argument( - "--restore", + "--restore_sq_model_from_json", action="store_true", help="restore ipex quantized model from output_dir/best_configure.json", ) @@ -43,11 +43,12 @@ ) # ============Benchmark configs============== parser.add_argument("--benchmark", action="store_true") -parser.add_argument("--iters", default=100, type=int, help="num iter") +parser.add_argument("--benchmark_iters", default=100, type=int, help="num iter") +parser.add_argument("--benchmark_batch_size", default=1, type=int, help="batch size for benchmark") parser.add_argument("--num_warmup", default=10, type=int, help="num warmup") # ============Accuracy configs============== parser.add_argument("--accuracy", 
action="store_true") -parser.add_argument("--batch_size", default=56, type=int, help="batch size num.") +parser.add_argument("--eval_batch_size", default=56, type=int, help="batch size num.") parser.add_argument( "--tasks", default="lambada_openai", @@ -58,33 +59,32 @@ parser.add_argument("--mixed_precision", action="store_true") # ============SmoothQuant configs============== parser.add_argument("--sq", action="store_true") -parser.add_argument("--calib_iters", default=100, type=int, help="Calibration iters.") +parser.add_argument("--alpha", default=0.5, help="Smooth quant parameter.") parser.add_argument( - "--calib_padding", action="store_true", help="Calibration dataset do padding." + "--n_samples", default=100, type=int, help="Smooth quant calibration samples." ) parser.add_argument( - "--calib_shuffle", - default=True, - type=str2bool, - help="Calibration dataset do shuffle.", + "--seq_len", default=512, type=int, help="Smooth quant calibration input length." ) +parser.add_argument("--batch_size", default=1, type=int, help="batch size num.") +parser.add_argument("--padding", action="store_true") +parser.add_argument("--shuffle", action="store_true") +# sq alpha "auto" parameters +parser.add_argument("--scale_sharing", action="store_true") parser.add_argument( - "--calib_pad_val", default=1, type=int, help="Calibration dataset padding value." + "--init_alpha", default=0.5, type=float, help="Smooth quant parameter." ) parser.add_argument( - "--calib_len", - default=512, - type=int, - help="Calibration dataset max or padding max length.", + "--alpha_min", default=0.0, type=float, help="Smooth quant parameter." ) parser.add_argument( - "--recipes", type=str, help="A dictionary as a string, recipes for smoothquant." + "--alpha_max", default=1.0, type=float, help="Smooth quant parameter." ) -parser.add_argument("--alpha", default="0.5", help="Smooth quant parameter.") parser.add_argument( - "--fallback_add", action="store_true", help="Whether to fallback add ops to FP32" + "--alpha_step", default=0.1, type=float, help="Smooth quant parameter." ) - +parser.add_argument("--shared_criterion", default="max", type=str) +parser.add_argument("--do_blockwise", action="store_true") # ============AutoModel parameters============== parser.add_argument("--_commit_hash", default=None, type=str) parser.add_argument("--trust_remote_code", action="store_true") @@ -106,12 +106,7 @@ config = AutoConfig.from_pretrained( args.model, torchscript=( - True - if ( - args.sq - or (args.int8 or args.int8_bf16_mixed) - ) - else False + True if args.sq else False ), # torchscript will force `return_dict=False` to avoid jit errors use_cache=True, # to use kv cache. 
trust_remote_code=args.trust_remote_code, @@ -142,56 +137,24 @@ if args.mixed_precision: quantization_config = MixedPrecisionConfig(dtype="bfloat16") # default is bfloat16 elif args.sq: - if re.search("gptj", config.model_type) or re.search("gpt_neox", config.model_type): - op_type_dict = { - "add": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}, - } - elif re.search("mpt", config.model_type): - op_type_dict = { - "add": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}, - "": { - "weight": {"dtype": ["fp32"]}, - "activation": {"dtype": ["fp32"]}, - }, - } - elif re.search("mistral", config.model_type) or re.search( - "baichuan", config.model_type - ): - op_type_dict = {".*": {"activation": {"algorithm": "minmax"}}} - else: - op_type_dict = {} - if args.fallback_add: - op_type_dict["add"] = { - "weight": {"dtype": ["fp32"]}, - "activation": {"dtype": ["fp32"]}, - } - excluded_precisions = [] if args.int8_bf16_mixed else ["bf16"] - if args.recipes: - try: - import ast - - recipes = ast.literal_eval(args.recipes) - print("Parsed recipes dictionary:", recipes) - except ValueError as e: - print("Error parsing recipes dictionary:", e) - else: - recipes = { - "smooth_quant": True, - "smooth_quant_args": { - "alpha": args.alpha if args.alpha == "auto" else float(args.alpha) - }, - } + excluded_precisions = ["bf16"] quantization_config = SmoothQuantConfig( - tokenizer=tokenizer, # either two of one, tokenizer or calib_func - recipes=recipes, - op_type_dict=op_type_dict, # default is {} - excluded_precisions=excluded_precisions, # default is [] + tokenizer=tokenizer, + seq_len=args.seq_len, + n_samples=args.n_samples, + batch_size=args.batch_size, + excluded_precisions=excluded_precisions, + alpha=args.alpha if args.alpha == "auto" else float(args.alpha), + scale_sharing=args.scale_sharing, + init_alpha=args.init_alpha, + alpha_min=args.alpha_min, + alpha_max=args.alpha_max, + alpha_step=args.alpha_step, + shared_criterion=args.shared_criterion, + do_blockwise=args.do_blockwise, + shuffle=args.shuffle, + padding=args.padding, num_beams=generate_kwargs["num_beams"], - calib_shuffle=args.calib_shuffle, - calib_iters=args.calib_iters, - calib_padding=args.calib_padding, - calib_len=args.calib_len, - calib_pad_val=args.calib_pad_val, ) else: print("The quantization_config is None.") @@ -203,52 +166,41 @@ quantization_config=quantization_config, trust_remote_code=args.trust_remote_code, _commit_hash=args._commit_hash, - use_neural_speed=False ) # save model if args.output_dir is not None and (args.sq or args.mixed_precision): tokenizer.save_pretrained(args.output_dir) if args.sq: + quantization_config.remove_redundant_parameters() + config.quantization_config = quantization_config config.save_pretrained(args.output_dir) user_model.save(args.output_dir) + user_model = AutoModelForCausalLM.from_pretrained( + args.output_dir, + trust_remote_code=args.trust_remote_code, + _commit_hash=args._commit_hash, + ) elif args.mixed_precision: user_model.save_pretrained(args.output_dir) - args.model = args.output_dir -if args.int8 or args.int8_bf16_mixed: - print("Loading SmoothQuant model from: ", args.model) - import intel_extension_for_pytorch as ipex - from intel_extension_for_transformers.transformers.llm.evaluation.models import ( - TSModelCausalLMForITREX, +if args.restore_sq_model_from_json: + from intel_extension_for_transformers.transformers.llm.quantization.sq_utils import ( + recover_model_from_json, + ) + user_model = recover_model_from_json( + args.model, + 
os.path.join(args.output_dir, "qconfig.json"), + args.trust_remote_code, ) - if args.restore: - from intel_extension_for_transformers.transformers.utils.utility import ( - recover_model_from_json, - ) - user_model = recover_model_from_json( - args.model, - os.path.join(args.output_dir, "best_configure.json"), - args.trust_remote_code, - ) - else: - user_model = torch.jit.load(os.path.join( args.model, "best_model.pt")) - config = AutoConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) - origin_model_type = config.model_type - if origin_model_type in ["chatglm", "qwen", "baichuan"]: - config.model_type = "qwen2" - user_model = TSModelCausalLMForITREX(user_model, config=config) - user_model.config.model_type = origin_model_type elif not (args.sq or args.mixed_precision): user_model = AutoModelForCausalLM.from_pretrained( args.model, trust_remote_code=args.trust_remote_code, _commit_hash=args._commit_hash, - use_neural_speed=False ) - if args.benchmark: user_model = user_model.eval() if hasattr(user_model, "eval") else user_model prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun." @@ -257,7 +209,7 @@ # start total_time = 0.0 - num_iter = args.iters + num_iter = args.benchmark_iters num_warmup = args.num_warmup total_token_num = 0 eos_token_id = tokenizer.eos_token_id @@ -267,7 +219,7 @@ # for chatglm2 only if hasattr(tokenizer, "build_chat_input"): input_ids = tokenizer.build_chat_input(prompt)["input_ids"] - input_ids = input_ids.repeat(args.batch_size, 1) + input_ids = input_ids.repeat(args.benchmark_batch_size, 1) eos_token_id = [ tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), @@ -277,11 +229,11 @@ elif hasattr(tokenizer, "build_prompt"): build_prompt = tokenizer.build_prompt(prompt) input_ids = tokenizer( - [build_prompt] * args.batch_size, return_tensors="pt" + [build_prompt] * args.benchmark_batch_size, return_tensors="pt" ).input_ids else: input_ids = tokenizer( - [prompt] * args.batch_size, return_tensors="pt" + [prompt] * args.benchmark_batch_size, return_tensors="pt" ).input_ids gen_ids = user_model.generate( input_ids, @@ -307,18 +259,32 @@ if args.accuracy: - args.model = (peft_config.base_model_name_or_path if args.peft_model_id else args.model) + args.model = ( + peft_config.base_model_name_or_path if args.peft_model_id else args.model + ) + + from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import ( + LMEvalParser, + evaluate, + ) - from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser - args = LMEvalParser(model = "hf", - tokenizer = tokenizer, - user_model = user_model, - tasks = args.tasks, - device = "cpu", - batch_size = args.batch_size) + args = LMEvalParser( + model="hf", + tokenizer=tokenizer, + user_model=user_model, + tasks=args.tasks, + device="cpu", + batch_size=args.eval_batch_size, + ) results = evaluate(args) for task_name in args.tasks.split(","): if task_name == "wikitext": - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity,none"])) + print( + "Accuracy for %s is: %s" + % (task_name, results["results"][task_name]["word_perplexity,none"]) + ) else: - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc,none"])) + print( + "Accuracy for %s is: %s" + % (task_name, results["results"][task_name]["acc,none"]) + ) diff --git 
a/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh b/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh index 7dfa912f90e..c89f84e8f59 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh +++ b/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh @@ -241,8 +241,8 @@ function run_tuning { script="run_generation_sq.py" elif [ "${topology}" = "llama2_7b_gptq" ]; then model_name_or_path="meta-llama/Llama-2-7b-hf" - extra_cmd=$extra_cmd" --woq --bits ${bits} --compute_dtype fp32 --scheme ${scheme} --calib_iters 100" - extra_cmd=$extra_cmd" --woq_algo "GPTQ" --desc_act --blocksize 128 --max_input_length 2048 " + extra_cmd=$extra_cmd" --woq --bits ${bits} --compute_dtype fp32 --scheme ${scheme} --n_samples 100" + extra_cmd=$extra_cmd" --woq_algo "GPTQ" --desc_act --blocksize 128 --seq_len 2048 " extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}" extra_cmd=$extra_cmd" --trust_remote_code" extra_cmd=$extra_cmd" --weight_dtype ${weight_dtype}" @@ -250,7 +250,7 @@ function run_tuning { elif [ "${topology}" = "mistral_7b_autoround" ]; then model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1" extra_cmd=$extra_cmd" --woq --bits ${bits} --compute_dtype fp32 --scheme ${scheme} " - extra_cmd=$extra_cmd" --woq_algo "AutoRound" --desc_act --group_size 128 --calib_len 2048 --calib_iters 100" + extra_cmd=$extra_cmd" --woq_algo "AutoRound" --desc_act --group_size 128 --seq_len 2048 --n_samples 100" extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}" extra_cmd=$extra_cmd" --trust_remote_code" extra_cmd=$extra_cmd" --weight_dtype ${weight_dtype}" @@ -265,8 +265,8 @@ function run_tuning { script="run_generation_cpu_woq.py" elif [ "${topology}" = "mistral_7b_gptq" ]; then model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1" - extra_cmd=$extra_cmd" --woq --bits ${bits} --compute_dtype fp32 --scheme ${scheme} --calib_iters 100" - extra_cmd=$extra_cmd" --woq_algo "GPTQ" --desc_act --blocksize 128 --max_input_length 2048 --group_size 128" + extra_cmd=$extra_cmd" --woq --bits ${bits} --compute_dtype fp32 --scheme ${scheme} --n_samples 100" + extra_cmd=$extra_cmd" --woq_algo "GPTQ" --desc_act --blocksize 128 --seq_len 2048 --group_size 128" extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}" extra_cmd=$extra_cmd" --trust_remote_code" extra_cmd=$extra_cmd" --weight_dtype ${weight_dtype}" diff --git a/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/train.py b/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/train.py index 75d6d236bf1..dfefe86c6f6 100644 --- a/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/train.py +++ b/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/train.py @@ -145,7 +145,7 @@ def train(): quantization_config = BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, - llm_int8_skip_modules=["mm_projector"], + modules_to_not_convert=["mm_projector"], llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, diff --git a/intel_extension_for_transformers/neural_chat/models/model_utils.py b/intel_extension_for_transformers/neural_chat/models/model_utils.py index fd187f138e9..d4db8a2ee5a 100644 --- a/intel_extension_for_transformers/neural_chat/models/model_utils.py +++ b/intel_extension_for_transformers/neural_chat/models/model_utils.py @@ -699,7 +699,8 @@ def load_model( assert 
ipex.__version__ >= "2.1.0+cpu", "Please use Intel Extension for PyTorch >=2.1.0+cpu." if re.search("falcon", model_name, re.IGNORECASE): assert transformers.__version__ <= "4.33.3", "Please pip install transformers==4.33.3" - from intel_extension_for_transformers.transformers.llm.evaluation.models import TSModelCausalLMForITREX + from intel_extension_for_transformers.transformers.llm.quantization.sq_utils import \ + TSModelCausalLMForITREX model = TSModelCausalLMForITREX.from_pretrained( model_name, file_name="best_model.pt" diff --git a/intel_extension_for_transformers/transformers/llm/evaluation/models.py b/intel_extension_for_transformers/transformers/llm/evaluation/models.py deleted file mode 100644 index 61b301a380a..00000000000 --- a/intel_extension_for_transformers/transformers/llm/evaluation/models.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2022 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import torch -import transformers -from typing import Optional, Tuple -from transformers.modeling_outputs import CausalLMOutputWithPast -from optimum.intel.generation.modeling import TSModelForCausalLM -from intel_extension_for_transformers.transformers.utils.utility import ( - generate_dummy_past_key_values_for_inference, - generate_dummy_past_key_values_for_opt_llm, - MODEL_TYPES_REQUIRING_POSITION_IDS, - IPEX_OPT_LLM_SUPPORTED, -) - - -class TSModelCausalLMForITREX(TSModelForCausalLM): - def _reorder_cache( - self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. - - This is required to match `past_key_values` with the correct beam_idx at every generation step. - """ - - if self.config.model_type == "chatglm": - return tuple( - tuple( - past_state.index_select(1, beam_idx.to(past_state.device)) - for past_state in layer_past - ) - for layer_past in past_key_values - ) - if len(past_key_values[0]) == 4: # discrete kv_cache - for layer_past in past_key_values: - layer_past[3][layer_past[0].size(-2) - 1] = beam_idx - return past_key_values - else: - return tuple( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ) - for layer_past in past_key_values - ) - - # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - past_key_values = past_key_values or kwargs.get("past", None) - - if self.use_cache and past_key_values is not None: - if not ( - self.config.model_type == "chatglm" - and re.search("THUDM/chatglm-6b", self.config.auto_map["AutoConfig"]) - ): - input_ids = input_ids[:, -1:] - - # `past_key_values` may be in the standard format (e.g. 
in contrastive search), - # converts to bloom's format if needed - if past_key_values is not None and self.config.model_type == "bloom": - if past_key_values[0][0].shape[0] == input_ids.shape[0]: - past_key_values = self._convert_to_bloom_cache(past_key_values) - position_ids = kwargs.get("position_ids", None) - - attention_mask = kwargs.get("attention_mask", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - if self.config.model_type == "chatglm" and re.search( - "THUDM/chatglm-6b", self.config.auto_map["AutoConfig"] - ): - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - batch_size, seq_length = input_ids.shape - device = input_ids.device - if past_key_values is None: - context_lengths = [ - seq.tolist().index(self.config.bos_token_id) for seq in input_ids - ] - position_ids = ( - torch.arange(seq_length, dtype=torch.long, device=device) - .unsqueeze(0) - .repeat(batch_size, 1) - ) - for i, context_length in enumerate(context_lengths): - position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [ - torch.cat( - ( - torch.zeros( - context_length, dtype=torch.long, device=device - ), - torch.arange( - seq_length - context_length, - dtype=torch.long, - device=device, - ) - + 1, - ) - ) - for context_length in context_lengths - ] - block_position_ids = torch.stack(block_position_ids, dim=0) - position_ids = torch.stack((position_ids, block_position_ids), dim=1) - else: - context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] - position_ids = torch.tensor( - [ - [mask_position, seq_length - context_length] - for mask_position, context_length in zip( - mask_positions, context_lengths - ) - ], - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(-1) - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": self.use_cache, - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": None, - } - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - position_ids: Optional[torch.FloatTensor] = None, - **kwargs, - ) -> CausalLMOutputWithPast: - model_type = self.config.model_type.replace("_", "-") - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - input_bs, input_len = input_ids.shape - if self.use_cache and past_key_values is None: - if model_type in IPEX_OPT_LLM_SUPPORTED: - past_key_values = generate_dummy_past_key_values_for_opt_llm( - config=self.config, input_bs=input_bs, num_beams=1 - ) - else: - past_key_values = generate_dummy_past_key_values_for_inference( - config=self.config, input_bs=input_bs - ) - inputs["past_key_values"] = past_key_values - if attention_mask is None: - inputs["attention_mask"] = torch.ones_like(input_ids) - if model_type == "chatglm": - if re.search("THUDM/chatglm-6b", self.config.auto_map["AutoConfig"]): - position_ids = self.prepare_inputs_for_generation(input_ids)[ - "position_ids" - ] - - if model_type 
in MODEL_TYPES_REQUIRING_POSITION_IDS: - if position_ids is not None: - inputs["position_ids"] = position_ids - else: - inputs["position_ids"] = torch.arange(input_len).repeat(input_bs, 1) - outputs = self.model(**inputs) - - if isinstance(outputs, (list, tuple)): - logits = outputs[0] - past_key_values = outputs[1] if self.use_cache else None - else: - logits = outputs["logits"] - past_key_values = outputs["past_key_values"] if self.use_cache else None - return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) diff --git a/intel_extension_for_transformers/transformers/llm/quantization/autograd/functions.py b/intel_extension_for_transformers/transformers/llm/quantization/autograd/functions.py index 68f17dca5cc..483e27da94b 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/autograd/functions.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/autograd/functions.py @@ -39,17 +39,17 @@ class qbits_acquire_type(Enum): def qbits_woq_linear_ref_impl(activation, packw, bias, compute_type, weight_type, scale_type): - assert (activation.is_contiguous()) - assert (packw.is_contiguous()) activation = activation.to(torch.float32) n = qbits.acquire_packed_weight_info( packw, qbits_acquire_type.N.value)[0].item() k = activation.shape[1] revert_wei = torch.empty(k, n, dtype=torch.float) qbits.dequantize_packed_weight( - packw, revert_wei, False, compute_type, weight_type, scale_type) + packw, revert_wei, False, compute_type, weight_type, "fp32") + enable_act_shuffle = qbits.acquire_packed_weight_info( packw, qbits_acquire_type.ACT_SHUFFLE.value)[0] != 0 + if enable_act_shuffle: g_idx = qbits.acquire_packed_weight_info( packw, qbits_acquire_type.G_IDX.value) @@ -59,6 +59,7 @@ def qbits_woq_linear_ref_impl(activation, packw, bias, compute_type, weight_typ assert (bias.is_contiguous()) assert (bias.dtype == torch.float32) out += bias + return out @@ -117,6 +118,7 @@ def forward( False if scheme == "sym" else True, ) else: + out = qbits_woq_linear_ref_impl( A, B.data, bias, compute_dtype, weight_dtype, scale_dtype) output = out diff --git a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py index b306cb01993..51eccf739dd 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py @@ -17,7 +17,6 @@ import os import torch -from ..utils import DTYPE_BITS_MAPPING from functools import reduce from operator import mul from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING, PeftType @@ -26,6 +25,7 @@ from intel_extension_for_transformers.transformers.llm.quantization.autograd import ( matmul_kbit, ) import intel_extension_for_transformers.qbits as qbits # pylint: disable=E0611, E0401 +from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor as quant_nf4_fp4 class DropoutQBits_(torch.autograd.Function): @@ -99,6 +99,7 @@ def __init__( compute_dtype="fp32", compress_statistics=True, weight_dtype="int4_clip", + bits=4, scale_dtype="fp32", blocksize=32, scheme="sym", @@ -115,7 +116,7 @@ def __init__( self.blocksize = blocksize self.scheme = scheme self.weight_dtype = weight_dtype - self.bits = DTYPE_BITS_MAPPING[weight_dtype] + self.bits = bits self.scale_dtype = scale_dtype self.double_quant_scale_dtype = double_quant_scale_dtype self.compression_dim = compression_dim @@ -221,10 +222,14 @@ def set_weights_bias( 
g_idx = torch.empty(0, dtype=torch.int32) else: g_idx = torch.empty(0, dtype=torch.int32) - if q_config.bits == 4: + if q_config.bits == 4 and 'f' not in q_config.weight_dtype: int_weight = (int_weight - 8) * 16 // 16 gptq_zeros = (gptq_zeros - 8) * 16 // 16 + if q_config.weight_dtype in ["nf4", "fp4", "fp4_e2m1"]: + int_weight = torch.where(int_weight < 0, int_weight + 16, int_weight) + int_weight = int_weight.t_() + gptq_scales = gptq_scales.t_() if q_config.sym: gptq_zeros = torch.empty(0, dtype=torch.int8) @@ -329,7 +334,7 @@ def recover_int_weight(g_idx, int_weight): g_idx = None weight_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 6) weight_dtype = "".join(chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist()) - bits = 4 if weight_dtype in ["nf4", "int4_clip", "fp4_e2m1", "fp4_e2m1_bnb"] else 8 + bits = 4 if weight_dtype in ["nf4", "int4_clip", "fp4_e2m1"] else 8 compute_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 7) compute_dtype = "".join(chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist()) scales_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 8) @@ -352,12 +357,19 @@ def recover_int_weight(g_idx, int_weight): qbits.dequantize_packed_weight(self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype) - int_weight = self.quant_weight_w_scale( - revert_wei.t(), - scales.t(), - qzeros.to(torch.uint8).t() if qzeros is not None else None, - group_size=group_size, - ) + if weight_dtype in ["nf4", "fp4", "fp4_e2m1"]: + int_weight = quant_nf4_fp4(revert_wei.t(), + bits=bits, + group_size=group_size, + dtype=weight_dtype, + return_int=True)[0] + else: + int_weight = self.quant_weight_w_scale( + revert_wei.t(), + scales.t(), + qzeros.to(torch.uint8).t() if qzeros is not None else None, + group_size=group_size, + ) if g_idx is not None: int_weight = recover_int_weight(g_idx, int_weight.t()) diff --git a/intel_extension_for_transformers/transformers/llm/quantization/sq_utils.py b/intel_extension_for_transformers/transformers/llm/quantization/sq_utils.py new file mode 100644 index 00000000000..7d5e52c21ee --- /dev/null +++ b/intel_extension_for_transformers/transformers/llm/quantization/sq_utils.py @@ -0,0 +1,514 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Tuple + +import transformers +from datasets import load_dataset +from torch.nn.functional import pad +from torch.utils.data import DataLoader +from transformers.modeling_outputs import CausalLMOutputWithPast + +from intel_extension_for_transformers.tools.utils import is_ipex_available + +from ...utils import LazyImport, logger + +if is_ipex_available(): + import intel_extension_for_pytorch as ipex +torch = LazyImport("torch") + +IPEX_OPT_LLM_SUPPORTED_DICT = { + "2.2": ["gptj", "opt", "llama", "falcon", "chatglm", "baichuan", "gpt-neox"], + "2.3": [ + "gptj", + "opt", + "llama", + "falcon", + "chatglm", + "baichuan", + "qwen", + "bloom", + "codegen", + "gptbigcode", + "t5", + "mixtral", + "mpt", + ], +} +if is_ipex_available() and ipex.__version__ == "2.2.0+cpu": + logger.info( + "ipex.llm.optimize by 2.2.0 version supported model family: {}".format( + ",".join(IPEX_OPT_LLM_SUPPORTED_DICT["2.2"]) + ) + ) + logger.info( + "The recommended transformers version is 4.35.2 if you used IPEX 2.2.0 version." + ) + IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.2"] +elif is_ipex_available() and ipex.__version__ == "2.3.0+cpu": + logger.info( + "ipex.llm.optimize by 2.3.0 version supported model family: {}".format( + ", ".join(IPEX_OPT_LLM_SUPPORTED_DICT["2.3"]) + ) + ) + logger.info( + "The recommended transformers version is 4.38.1 if you used IPEX 2.3.0 version." + ) + IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.3"] +else: + logger.warning("Please check the intel_extension_for_pytorch version is 2.3.0+cpu.") + IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.3"] + +MODEL_TYPES_REQUIRING_POSITION_IDS = { + "codegen", + "gpt2", + "gpt-bigcode", + "gpt-neo", + "gpt-neox", + "gptj", + "imagegpt", + "llama", + "mistral", + "chatglm", +} + + +def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1): + """Generate the dummy past_key_values.""" + from optimum.utils import NormalizedConfigManager + + if config.model_type == "qwen": + new_shape = [ + input_bs, + 1, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + ] + num_layers = config.num_hidden_layers + elif config.model_type == "baichuan": + new_shape = [ + input_bs, + config.num_attention_heads, + 1, + config.hidden_size // config.num_attention_heads, + ] + num_layers = config.num_hidden_layers + elif config.model_type == "chatglm": + new_shape = [ + 1, + input_bs, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + ] + num_layers = config.num_layers + else: + normalized_config = NormalizedConfigManager.get_normalized_config_class( + config.model_type + )(config) + num_layers = normalized_config.num_layers + num_attention_heads = normalized_config.num_attention_heads + hidden_size = normalized_config.hidden_size + d_k = hidden_size // num_attention_heads + num_key_value_heads = num_attention_heads + nb_pkv = 2 + if hasattr(normalized_config, "num_key_value_heads"): + num_key_value_heads = normalized_config.num_key_value_heads + if hasattr(normalized_config, "multi_query_group_num"): + num_key_value_heads = normalized_config.multi_query_group_num + if config.model_type == "bloom": + for nb_pkv in range(nb_pkv): + if nb_pkv % 2 == 0: + new_shape = [input_bs * num_key_value_heads, d_k, 1] + else: + new_shape = [input_bs * num_key_value_heads, 1, d_k] + else: + new_shape = [input_bs, num_key_value_heads, 1, d_k] + + beam_idx_tmp = torch.zeros( + (2048, int(input_bs * num_beams)), dtype=torch.long + 
).contiguous() + past_key_values = [ + ( + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + torch.zeros(size=new_shape).contiguous(), + torch.zeros(size=new_shape).contiguous(), + beam_idx_tmp, + ) + for _ in range(num_layers) + ] + return tuple(past_key_values) + + +def generate_dummy_past_key_values(config, input_bs): + """Generate the dummy past_key_values.""" + from optimum.utils import NormalizedConfigManager + + if config.model_type == "qwen": + new_shape = [ + input_bs, + 1, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + ] + num_layers = config.num_hidden_layers + elif config.model_type == "baichuan": + new_shape = [ + input_bs, + config.num_attention_heads, + 1, + config.hidden_size // config.num_attention_heads, + ] + num_layers = config.num_hidden_layers + elif config.model_type == "chatglm": + new_shape = [ + 1, + input_bs, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + ] + num_layers = config.num_layers + else: + normalized_config = NormalizedConfigManager.get_normalized_config_class( + config.model_type + )(config) + nb_pkv = 2 + num_layers = normalized_config.num_layers + num_attention_heads = normalized_config.num_attention_heads + hidden_size = normalized_config.hidden_size + d_k = hidden_size // num_attention_heads + num_key_value_heads = num_attention_heads + if hasattr(normalized_config, "num_key_value_heads"): + num_key_value_heads = normalized_config.num_key_value_heads + if hasattr(normalized_config, "multi_query_group_num"): + num_key_value_heads = normalized_config.multi_query_group_num + + if config.model_type == "bloom": + shape_key = (input_bs * num_attention_heads, d_k, 1) + shape_value = (input_bs * num_attention_heads, 1, d_k) + key = torch.ones(size=shape_key) + value = torch.ones(size=shape_value) + past_key_values = tuple( + tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) + for _ in range(num_layers) + ) + return past_key_values + elif config.model_type == "gpt_bigcode": + new_shape = [input_bs, 0, d_k * 2] + dummy_tensor = torch.zeros(size=new_shape) + past_key_values = tuple([dummy_tensor] * num_layers) + return past_key_values + elif config.model_type == "falcon": + new_shape = [input_bs, 1, 0, d_k] + else: + new_shape = [input_bs, num_key_value_heads, 0, d_k] + past_key_values = [ + ( + torch.zeros(size=new_shape).contiguous(), + torch.zeros(size=new_shape).contiguous(), + ) + for _ in range(num_layers) + ] + return tuple(past_key_values) + + +def get_dataloader( + model_type, + quantization_config, + past_key_values, +): + shuffle=quantization_config.shuffle + padding=quantization_config.padding + seq_len=quantization_config.seq_len + + calib_dataset = load_dataset( + quantization_config.dataset, + split=( + "test" + if quantization_config.dataset in ["mbpp", "openai_humaneval"] + else "train" + ), + ) + if shuffle: + calib_dataset = calib_dataset.shuffle(seed=42) + + def tokenize_function(examples): + if "code" in examples: + example = quantization_config.tokenizer(examples["code"]) + elif "prompt" in examples: + example = quantization_config.tokenizer(examples["prompt"]) + elif "text" in examples: + example = quantization_config.tokenizer(examples["text"]) + else: + logger.error( + "Please check dataset prompt identifier," + + " NeelNanda/pile-10k is default used calibration dataset." 
+ ) + exit(0) + return example + + def collate_batch(batch): + position_ids_padded = [] + input_ids_padded = [] + last_ind = [] + attention_mask_padded = [] + for text in batch: + input_ids = text["input_ids"] + if not padding: + input_ids = ( + input_ids[: int(seq_len)] + if len(input_ids) > int(seq_len) + else input_ids + ) # no_padding + else: + pad_len = seq_len - input_ids.shape[0] + input_ids = pad(input_ids, (0, pad_len), value=seq_len) + + last_ind.append(input_ids.shape[0] - 1) + attention_mask = torch.ones(len(input_ids)) + position_ids = torch.arange(len(input_ids)) + input_ids_padded.append(input_ids) + attention_mask_padded.append(attention_mask) + position_ids_padded.append(position_ids) + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: + return ( + { + "input_ids": torch.vstack(input_ids_padded), + "attention_mask": torch.vstack(attention_mask_padded), + "position_ids": torch.vstack(position_ids_padded), + "past_key_values": past_key_values, + }, + torch.tensor(last_ind), + ) + else: + return ( + { + "input_ids": torch.vstack(input_ids_padded), + "attention_mask": torch.vstack(attention_mask_padded), + "past_key_values": past_key_values, + }, + torch.tensor(last_ind), + ) + + tokenized_dataset = calib_dataset.map(tokenize_function, batched=True) + tokenized_dataset.set_format(type="torch", columns=["input_ids"]) + + calib_dataloader = DataLoader( + tokenized_dataset, + batch_size=1, + shuffle=False, + collate_fn=collate_batch, + ) + return calib_dataloader + +from optimum.intel.generation.modeling import TSModelForCausalLM +class TSModelCausalLMForITREX(TSModelForCausalLM): + def _reorder_cache( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. + + This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + + if self.config.model_type == "chatglm": + return tuple( + tuple( + past_state.index_select(1, beam_idx.to(past_state.device)) + for past_state in layer_past + ) + for layer_past in past_key_values + ) + if len(past_key_values[0]) == 4: # discrete kv_cache + for layer_past in past_key_values: + layer_past[3][layer_past[0].size(-2) - 1] = beam_idx + return past_key_values + else: + return tuple( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) + for layer_past in past_key_values + ) + + # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + past_key_values = past_key_values or kwargs.get("past", None) + + if self.use_cache and past_key_values is not None: + input_ids = input_ids[:, -1:] + + # `past_key_values` may be in the standard format (e.g. 
in contrastive search), + # converts to bloom's format if needed + if past_key_values is not None and self.config.model_type == "bloom": + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = self._convert_to_bloom_cache(past_key_values) + position_ids = kwargs.get("position_ids", None) + + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": self.use_cache, + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": None, + } + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + position_ids: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + model_type = self.config.model_type.replace("_", "-") + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + input_bs, input_len = input_ids.shape + + if self.use_cache and past_key_values is None: + if model_type in IPEX_OPT_LLM_SUPPORTED: + past_key_values = generate_dummy_past_key_values_for_opt_llm( + config=self.config, input_bs=input_bs, num_beams=1 + ) + else: + past_key_values = generate_dummy_past_key_values( + config=self.config, input_bs=input_bs + ) + inputs["past_key_values"] = past_key_values + if attention_mask is None: + inputs["attention_mask"] = torch.ones_like(input_ids) + + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: + if position_ids is not None: + inputs["position_ids"] = position_ids + else: + inputs["position_ids"] = torch.arange(input_len).repeat(input_bs, 1) + outputs = self.model(**inputs) + + if isinstance(outputs, (list, tuple)): + logits = outputs[0] + past_key_values = outputs[1] if self.use_cache else None + else: + logits = outputs["logits"] + past_key_values = outputs["past_key_values"] if self.use_cache else None + return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) + +def loading_configure_file(model, json_file_path, example_inputs): + """Recover ipex model from JSON file. + + Args: + model (object): fp32 model need to do quantization. + json_file_path (json): configuration JSON file for ipex. + example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex function. 
+ + Returns: + (object): quantized model + """ + + ipex = LazyImport("intel_extension_for_pytorch") + from torch.ao.quantization.observer import MinMaxObserver + + if ipex.__version__ >= "2.1.100": + qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver) + else: + qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver()) + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True) + else: + model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True) + model.load_qconf_summary(qconf_summary=json_file_path) + model = ipex.quantization.convert(model, inplace=True) + model.eval() + with torch.no_grad(): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) + model = torch.jit.freeze(model.eval()) + + model(**example_inputs) + model(**example_inputs) + return model + +def recover_model_from_json(fp32_model_name_or_path, json_file_path, trust_remote_code=False): + """Recover ipex model from JSON file. + + Args: + fp32_model_name_or_path (str): name or path of the fp32 model to be quantized. + json_file_path (str): path to the saved ipex quantization configuration JSON file. + trust_remote_code (bool): whether to trust remote code. + + Returns: + (object): quantized model + """ + from transformers import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained(fp32_model_name_or_path, trust_remote_code=trust_remote_code) + if model.config.model_type in IPEX_OPT_LLM_SUPPORTED: + qconfig = ipex.quantization.default_static_qconfig_mapping + model = ipex.llm.optimize( + model.eval(), + dtype=torch.float, + inplace=True, + quantization_config=qconfig, + deployment_mode=False, + ) + # config + model.config.torchscript = True + config = model.config + + # example_inputs + + input_ids = model.dummy_inputs["input_ids"] + input_bs, input_len = input_ids.shape + attention_mask = torch.ones_like(input_ids) + position_ids = torch.arange(input_len).repeat(input_bs, 1) + num_beams = 1 + if config.model_type in IPEX_OPT_LLM_SUPPORTED: + past_key_values = generate_dummy_past_key_values_for_opt_llm( + config=config, input_bs=input_bs, num_beams=num_beams + ) + else: + past_key_values = generate_dummy_past_key_values( + config=config, input_bs=input_bs + ) + if config.model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: + example_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_key_values + } + else: + example_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values + } + + model = loading_configure_file(model, json_file_path, example_inputs) + model = TSModelCausalLMForITREX(model, config=config) + return model diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py index f912135db1a..0678c2eb72e 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py @@ -16,26 +16,48 @@ # limitations under the License.
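Before the `utils.py` changes, a minimal usage sketch for the `recover_model_from_json` helper added in `sq_utils.py` above. This is only an illustration of the intended call pattern, not part of the patch; the model id and JSON path are placeholders.

```python
# Sketch only: restore a SmoothQuant-quantized model from its saved ipex qconfig JSON.
# "facebook/opt-125m" and "./saved_results/qconfig.json" are hypothetical values.
from transformers import AutoTokenizer

from intel_extension_for_transformers.transformers.llm.quantization.sq_utils import (
    recover_model_from_json,
)

model = recover_model_from_json(
    "facebook/opt-125m",              # fp32 checkpoint the SQ model was built from
    "./saved_results/qconfig.json",   # qconfig summary saved at quantization time
    trust_remote_code=False,
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
inputs = tokenizer("Intel Extension for Transformers is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

The returned object is the `TSModelCausalLMForITREX` wrapper defined above, so it can be used with the regular `generate()` API; this mirrors the `--restore_sq_model_from_json` path in the example scripts.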
-import logging import gc +import logging import math import os from ....tools.utils import _ipex_version from accelerate import init_empty_weights -from neural_compressor import quantization -from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear +from datasets import load_dataset +from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear +from neural_compressor.torch.quantization import ( + AutoRoundConfig, + AWQConfig, + GPTQConfig, + RTNConfig, + SmoothQuantConfig, + TEQConfig, + convert, + prepare, + quantize, +) from neural_compressor.utils.utility import LazyImport -from neural_compressor.config import PostTrainingQuantConfig +from transformers import AutoTokenizer + from intel_extension_for_transformers.tools.utils import ( - is_ipex_available, is_autoround_available, + is_ipex_available, ) +from ...utils import CpuInfo + if is_ipex_available(): import intel_extension_for_pytorch as ipex + from .sq_utils import ( + IPEX_OPT_LLM_SUPPORTED, + MODEL_TYPES_REQUIRING_POSITION_IDS, + generate_dummy_past_key_values, + generate_dummy_past_key_values_for_opt_llm, + get_dataloader, + ) if is_autoround_available(): from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear as auto_round_woqlinear # pylint: disable=E0401 + from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader as get_autoround_dataloader torch = LazyImport("torch") @@ -46,7 +68,6 @@ DTYPE_BITS_MAPPING = { "nf4": 4, "fp4": 4, # fp4 == fp4_e2m1 - "fp4_e2m1_bnb": 4, "fp4_e2m1": 4, "int4": 4, "int4_fullrange": 4, @@ -62,27 +83,30 @@ def unpack_weight(qweight, scales, qzeros, q_config): sym = q_config.sym bits = q_config.bits wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0) + if qzeros is not None: + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) + if bits == 8: + zeros = zeros.to(torch.int8 if sym else torch.uint8) + # due to INC minus one + zeros = zeros + 1 + try: + zeros = zeros.reshape(scales.shape) + except: + # zeros and scales have different item numbers. + # remove 1 (due to 0 + 1 in line 68) + zeros = zeros[zeros != 1] + zeros = zeros.reshape(scales.shape) - zeros = torch.bitwise_right_shift( - torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) - ).to(torch.int16 if bits == 8 else torch.int8) - torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) - if bits == 8: - zeros = zeros.to(torch.int8 if sym else torch.uint8) - # due to INC minus one - zeros = zeros + 1 - try: - zeros = zeros.reshape(scales.shape) - except: - # zeros and scales have different iteam numbers.
- # remove 1 (due to 0 + 1 in line 68) - zeros = zeros[zeros != 1] - zeros = zeros.reshape(scales.shape) - - # due to INC asym return torch.uint8 but backend request int8, - # change it to int8 with offset 128 - if not sym and bits == 8: - zeros = (zeros.to(torch.int32) - 128).to(torch.int8) + # due to INC asym return torch.uint8 but backend request int8, + # change it to int8 with offset 128 + if not sym and bits == 8: + zeros = (zeros.to(torch.int32) - 128).to(torch.int8) + zeros = zeros.contiguous() + else: + zeros = None weight = torch.bitwise_right_shift( torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1) @@ -98,7 +122,7 @@ def unpack_weight(qweight, scales, qzeros, q_config): # change it to int8 with offset 128 if not sym: weight = (weight.to(torch.int32) - 128).to(torch.int8) - return weight, scales, zeros + return weight.contiguous(), scales.contiguous(), zeros def replace_linear( @@ -158,9 +182,6 @@ def _replace_linear( quantization_config.weight_dtype not in [ "fp8_e5m2", "fp8_e4m3", - "fp4_e2m1_bnb", - "fp4_e2m1", - "nf4", "int4_fullrange", ] @@ -186,48 +207,60 @@ def _replace_linear( or device == "auto" ): if is_ipex_available() and quantization_config.use_ipex: - from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear - from intel_extension_for_pytorch.utils.weight_only_quantization import \ - _convert_optimum_format_to_desired + from intel_extension_for_pytorch.nn.modules import ( + WeightOnlyQuantizedLinear as ipex_linear, + ) + from intel_extension_for_pytorch.utils.weight_only_quantization import ( + _convert_optimum_format_to_desired, + ) - qweight, scales, qzeros = _convert_optimum_format_to_desired(module.qweight, - module.scales, - module.qzeros) + qweight, scales, qzeros = ( + _convert_optimum_format_to_desired( + module.qweight, module.scales, module.qzeros + ) + ) weight_dtype = { 4: ipex.quantization.WoqWeightDtype.INT4, 8: ipex.quantization.WoqWeightDtype.INT8, } compute_dtype = { - "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. + "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. 
"bf16": ipex.quantization.WoqLowpMode.BF16, "fp16": ipex.quantization.WoqLowpMode.FP16, "int8": ipex.quantization.WoqLowpMode.INT8, - } - ipex_qconfig_mapping = ( - ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype[quantization_config.bits], - lowp_mode=compute_dtype[quantization_config.compute_dtype], - act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, - group_size=quantization_config.group_size, - ) + ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype[quantization_config.bits], + lowp_mode=compute_dtype[ + quantization_config.compute_dtype + ], + act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + group_size=quantization_config.group_size, ) tmp_linear = torch.nn.Linear( in_features, out_features, - True if hasattr(module, "bias") else False - ) + True if hasattr(module, "bias") else False, + ) tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig - model._modules[name] = ipex_linear.from_float_and_int4_weight( - mod = tmp_linear, - qweight = qweight, - scales = scales, - zero_points = qzeros, - bias = module.bias if hasattr(module, "bias") else None, - group_size = quantization_config.group_size, - g_idx = module.g_idx if hasattr(module, "g_idx") else None, + model._modules[name] = ( + ipex_linear.from_float_and_int4_weight( + mod=tmp_linear, + qweight=qweight, + scales=scales, + zero_points=qzeros, + bias=( + module.bias if hasattr(module, "bias") else None + ), + group_size=quantization_config.group_size, + g_idx=( + module.g_idx + if hasattr(module, "g_idx") + else None + ), + ) ) else: from .nn.modules import ( @@ -238,11 +271,9 @@ def _replace_linear( quantization_config.weight_dtype not in [ "fp8_e5m2", "fp8_e4m3", - "fp4_e2m1_bnb", - "fp4_e2m1", "nf4", + "fp4_e2m1", ] - model._modules[name] = QuantizedLinearQBits( in_features, out_features, @@ -250,17 +281,20 @@ def _replace_linear( compute_dtype=quantization_config.compute_dtype, compress_statistics=False, weight_dtype=quantization_config.weight_dtype, + bits=quantization_config.bits, scale_dtype=quantization_config.scale_dtype, blocksize=quantization_config.group_size, scheme=quantization_config.scheme, - compression_dtype=getattr(module, "compression_dtype", torch.int32), + compression_dtype=getattr( + module, "compression_dtype", torch.int32 + ), compression_dim=getattr(module, "compression_dim", 1), device=device, use_optimum_format=use_optimum_format, ) elif device == "xpu" or device == torch.device("xpu"): - from intel_extension_for_pytorch.nn.utils._quantize_convert \ - import WeightOnlyQuantizedLinear as ipex_linear # pylint: disable=E0401 + from intel_extension_for_pytorch.nn.utils._quantize_convert import \ + WeightOnlyQuantizedLinear as ipex_linear # pylint: disable=E0401 model._modules[name] = ipex_linear( in_features, out_features, @@ -279,7 +313,11 @@ def _replace_linear( False if _ipex_version < "2.3.10" else True), ) if quantization_config.quant_method.value == "gptq": - g_idx = getattr(module, "g_idx", torch.zeros(in_features, dtype=torch.int32).to(device)) + g_idx = getattr( + module, + "g_idx", + torch.zeros(in_features, dtype=torch.int32).to(device), + ) else: g_idx = None model._modules[name].set_scales_zps_gidx( @@ -327,26 +365,30 @@ def _replace_linear( model._modules[name].requires_grad_(False) if quantization_config.use_ipex: pass - elif (device == "cpu" or device == torch.device("cpu") or device == "auto"): + elif ( + device == "cpu" or device == torch.device("cpu") or device == 
"auto" + ): if quantization_config.weight_dtype in [ "fp8_e5m2", "fp8_e4m3", - "nf4", - "fp4_e2m1_bnb", - "fp4_e2m1", ]: model._modules[name].set_fp_weights_bias( module.weight.data, None if module.bias is None else module.bias.data, ) else: - int_weight, scales, zeros = unpack_weight( - module.qweight, - module.scales, - module.qzeros, - quantization_config, - ) - int_weight = int_weight.view(-1, int_weight.shape[-1]) + if quantization_config.weight_dtype in ["int4", "int4_clip", "int8"]: + int_weight, scales, zeros = unpack_weight( + module.qweight, + module.scales, + module.qzeros if hasattr(module, "qzeros") else None, + quantization_config, + ) + int_weight = int_weight.view(-1, int_weight.shape[-1]) + else: + int_weight = module.unpack_tensor_with_numpy(module.qweight) + scales = module.scales + zeros = module.qzeros if hasattr(module, "qzeros") else None model._modules[name].set_weights_bias( int_weight, @@ -391,6 +433,100 @@ def _replace_linear( return model, is_replaced +def default_run_fn( + model, tokenizer, dataset, max_length=512, n_samples=100, batch_size=8, algo="rtn" +): + from torch.utils.data import DataLoader + + if isinstance(dataset, (str, bytes, os.PathLike)): + calib_dataset = load_dataset(dataset, split="train") + calib_dataset = calib_dataset.shuffle(seed=42) + if tokenizer is None: + logger.error("Please provide the tokenizer in quantization_config.") + exit(0) + + def tokenize_function(examples): + if algo == "teq": + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if "prompt" in examples: + if algo == "teq": + example = tokenizer( + examples["prompt"], padding="max_length", max_length=max_length + ) + else: + example = tokenizer(examples["prompt"]) + elif "code" in examples: + if algo == "teq": + example = tokenizer( + examples["code"], padding="max_length", max_length=max_length + ) + else: + example = tokenizer(examples["code"]) + elif "text" in examples: + if algo == "teq": + example = tokenizer( + examples["text"], padding="max_length", max_length=max_length + ) + else: + example = tokenizer(examples["text"]) + else: + logger.error( + "Please check dataset prompt identifier," + + " NeelNanda/pile-10k is default used calibration dataset." + ) + exit(0) + return example + + tokenized_dataset = calib_dataset.map(tokenize_function, batched=True) + tokenized_dataset.set_format(type="torch", columns=["input_ids"]) + tokenized_dataset = tokenized_dataset.filter(lambda x: x["input_ids"].shape[-1] >= max_length) + + def collate_batch(batch): + input_ids_padded = [] + for text in batch: + input_ids = text["input_ids"] + if len(input_ids) >= max_length: + input_ids = input_ids[:max_length] + input_ids_padded.append(input_ids) + else: + continue + assert input_ids_padded != [], \ + "The dataset does not have data that meets the required input length. Please reduce seq_len." + return torch.vstack(input_ids_padded) + + + calib_dataloader = DataLoader( + tokenized_dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=collate_batch, + ) + total_cnt = 0 + for i, (input_ids) in enumerate(calib_dataloader): + if total_cnt + input_ids.shape[0] > n_samples: + input_ids = input_ids[: n_samples - total_cnt, ...] 
+ total_cnt += input_ids.shape[0] + if total_cnt >= n_samples: + break + + try: + model( + input_ids=input_ids, + ) + except ValueError: + pass + +@torch.no_grad() +def run_fn_for_autoround(model, dataloader): + for data in dataloader: + if isinstance(data, tuple) or isinstance(data, list): + model(*data) + elif isinstance(data, dict): + model(**data) + else: + model(data) + def convert_to_quantized_model(model, config, device="cpu"): if device == "xpu" or device == torch.device("xpu"): import intel_extension_for_pytorch @@ -398,101 +534,16 @@ def convert_to_quantized_model(model, config, device="cpu"): assert ( hasattr(torch, "xpu") and torch.xpu.is_available() ), "There is no xpu device in this system!" - config.post_init_xpu() - else: - config.post_init_cpu() - calib_dataloader = config.calib_dataloader - calib_func = config.calib_func - calib_iters = config.calib_iters - calib_dataset = config.dataset - model_device = next(model.parameters()).device - - if config.quant_method.value == "autoround": - from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader - calib_dataloader = get_dataloader(config.tokenizer, # pylint: disable=E1123 - seqlen=config.calib_len, - dataset_name=config.dataset, - seed=42, - bs=8, - n_samples=config.nsamples) - elif ( - calib_dataloader is None - and config.quant_method.value not in ["rtn"] - and calib_dataset is not None - ): - from datasets import load_dataset - from torch.utils.data import DataLoader - - if isinstance(calib_dataset, (str, bytes, os.PathLike)): - calib_dataset = load_dataset(calib_dataset, split="train") - calib_dataset = calib_dataset.shuffle(seed=42) - if config.tokenizer is None: - logger.error( - "Please provide the tokenizer or provide calib_func directly," - + " the following is how to get tokenizer. \n" - + " from transformer import AutoTokenizer \n" - + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n" - ) - exit(0) - def tokenize_function(examples): - if "prompt" in examples: - example = config.tokenizer(examples["prompt"]) - elif "code" in examples: - example = config.tokenizer(examples["code"]) - elif "text" in examples: - example = config.tokenizer(examples["text"]) - else: - logger.error( - "Please check dataset prompt identifier," - + " NeelNanda/pile-10k is default used calibration dataset." - ) - exit(0) - return example - - tokenized_dataset = calib_dataset.map(tokenize_function, batched=True) - tokenized_dataset.set_format(type="torch", columns=["input_ids"]) - - def collate_batch(batch): - input_ids_padded = [] - for text in batch: - input_ids = text["input_ids"] - input_ids = ( - input_ids[:512] - if (len(input_ids) > 512 and config.quant_method.value != "gptq") - else input_ids - ) - input_ids_padded.append(input_ids) - return torch.vstack(input_ids_padded) - - calib_dataloader = DataLoader( - tokenized_dataset, - batch_size=1, - shuffle=False, - collate_fn=collate_batch, - ) - if calib_func is None and config.quant_method.value == "awq": - - def default_calib_func(model): - """This is the default calibration function, the dataset is NeelNanda/pile-10k, - the default calib_iters is 100.""" - for i, (input_ids) in enumerate(calib_dataloader): - if i >= calib_iters: - break - model( - input_ids=input_ids, - ) - - calib_func = default_calib_func - logger.info( - "The default calibration function is used, " - + "the calibration dataset is NeelNanda/pile-10k," - + "batchsize is 1 and calibration iteration is 100." 
- ) + orig_dtype = torch.float32 + for param in model.parameters(): + orig_dtype = param.dtype + if orig_dtype != torch.float32: + model.to(dtype=torch.float32) + break if config.weight_dtype in ["fp8_e4m3", "fp8_e5m2"]: return replace_linear(model, None, None, config, device=device) else: - bits = DTYPE_BITS_MAPPING[config.weight_dtype] if config.weight_dtype == "int8": dtype = "int8" elif "int4" in config.weight_dtype: @@ -501,125 +552,146 @@ def default_calib_func(model): dtype = config.weight_dtype # mapping to INC config if config.quant_method.value == "rtn": - recipes = { - "layer_wise_quant": config.layer_wise, - "rtn_args": { - "enable_full_range": ( - True if "fullrange" in config.weight_dtype else False - ), - "enable_mse_search": config.mse_range, - }, - } - algorithm = "RTN" + quant_config = RTNConfig( + dtype=dtype, + bits=config.bits, + use_sym=config.sym, + group_size=config.group_size, + use_layer_wise=config.layer_wise, + ) + quant_config.set_local(".*lm_head", RTNConfig(dtype="fp32")) + quant_config.set_local(".*output_layer", RTNConfig(dtype="fp32")) + quant_config.set_local(".*embed_out", RTNConfig(dtype="fp32")) + model = prepare(model, quant_config) + model = convert(model) elif config.quant_method.value == "awq": - recipes = { - "rtn_args": { - "enable_full_range": ( - True if "fullrange" in config.weight_dtype else False - ), - "enable_mse_search": config.mse_range, - }, - "awq_args": {"folding": True}, - } - algorithm = "AWQ" + quant_config = AWQConfig( + dtype=dtype, + bits=config.bits, + use_sym=config.sym, + group_size=config.group_size, + use_layer_wise=config.layer_wise, + use_auto_scale=config.auto_scale, + use_auto_clip=config.auto_clip, + folding=True, + ) + quant_config.set_local(".*lm_head", AWQConfig(dtype="fp32")) + quant_config.set_local(".*output_layer", AWQConfig(dtype="fp32")) + quant_config.set_local(".*embed_out", AWQConfig(dtype="fp32")) + logger.info(f"Do AWQ algorithm with config {quant_config}") + run_fn = default_run_fn + run_args = ( + config.tokenizer, + config.dataset, + config.seq_len, # max_length + config.n_samples, # n_samples + config.batch_size, # batch_size + config.quant_method.value, # algo + ) + example_inputs = torch.ones([1, 512], dtype=torch.long).to(device) + model = prepare(model=model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(model, *run_args) + model = convert(model) elif config.quant_method.value == "teq": - recipes = {"teq_args": {}} - algorithm = "TEQ" + quant_config = TEQConfig( + dtype=dtype, + bits=config.bits, + use_sym=config.sym, + group_size=config.group_size, + use_layer_wise=config.layer_wise, + absorb_to_layer=config.absorb_to_layer + ) + quant_config.set_local(".*lm_head", TEQConfig(dtype="fp32")) + quant_config.set_local(".*output_layer", TEQConfig(dtype="fp32")) + quant_config.set_local(".*embed_out", TEQConfig(dtype="fp32")) + logger.info(f"Do TEQ algorithm with config {quant_config}") + run_fn = default_run_fn + run_args = ( + config.tokenizer, + config.dataset, + config.seq_len, # max_length + config.n_samples, # n_samples + config.batch_size, # batch_size + config.quant_method.value, # algo + ) + example_inputs = torch.ones([1, 512], dtype=torch.long).to(device) + model = prepare(model=model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(model, *run_args) + model = convert(model) + elif config.quant_method.value == "gptq": - recipes = { - "layer_wise_quant": config.layer_wise, - "gptq_args": { - "act_order": config.desc_act, - "percdamp": 
config.damp_percent, - "block_size": config.blocksize, - "nsamples": config.nsamples, - "use_max_length": True if config.max_input_length else False, - "pad_max_length": config.max_input_length, - "static_groups": config.static_groups, - "true_sequential": config.true_sequential, - }, - } - algorithm = "GPTQ" + model.seqlen = config.seq_len + quant_config = GPTQConfig( + dtype=dtype, + bits=config.bits, + use_sym=config.sym, + group_size=config.group_size, + use_layer_wise=config.layer_wise, + act_order=config.desc_act, + percdamp=config.damp_percent, + block_size=config.blocksize, + static_groups=config.static_groups, + ) + quant_config.set_local(".*lm_head", GPTQConfig(dtype="fp32")) + quant_config.set_local(".*output_layer", GPTQConfig(dtype="fp32")) + quant_config.set_local(".*embed_out", GPTQConfig(dtype="fp32")) + logger.info(f"Do GPTQ algorithm with config {quant_config}") + run_fn = default_run_fn + run_args = ( + config.tokenizer, + config.dataset, + config.seq_len, # max_length + config.n_samples, # n_samples + config.batch_size, # batch_size + config.quant_method.value, # algo + ) + model = prepare(model=model, quant_config=quant_config) + run_fn(model, *run_args) + model = convert(model) elif config.quant_method.value == "autoround": - recipes = { - "autoround_args": { - "n_samples": config.nsamples, - "seqlen": config.calib_len, - "iters": config.calib_iters, - "scale_dtype": config.scale_dtype, - "enable_quanted_input": not config.disable_quanted_input, - "lr": config.lr, - "minmax_lr": config.minmax_lr, - } - } - algorithm = "AUTOROUND" + quant_config = AutoRoundConfig( + dtype=dtype, + bits=config.bits, + use_sym=config.sym, + group_size=config.group_size, + enable_quanted_input=not config.disable_quanted_input, + lr=config.lr, + minmax_lr=config.minmax_lr, + seqlen=config.seq_len, + n_samples=config.n_samples, + iters=config.iters, + scale_dtype=config.scale_dtype, + ) + if config.quant_lm_head is False: + quant_config.set_local(".*lm_head", AutoRoundConfig(dtype="fp32")) + quant_config.set_local(".*output_layer", AutoRoundConfig(dtype="fp32")) + quant_config.set_local(".*embed_out", AutoRoundConfig(dtype="fp32")) + logger.info(f"Do AutoRound algorithm with config {quant_config}") + dataloader = get_autoround_dataloader(tokenizer=config.tokenizer, + seqlen=config.seq_len, + dataset_name="NeelNanda/pile-10k", + seed=42, + bs=config.batch_size, + n_samples=config.n_samples) + run_fn = run_fn_for_autoround + run_args = (dataloader,) + model = prepare(model=model, quant_config=quant_config) + run_fn(model, *run_args) + model = convert(model) else: assert False, "The Supported algorithm are RTN, AWQ, TEQ, GPTQ, AUTOROUND" - conf = PostTrainingQuantConfig( - approach="weight_only", - op_type_dict={ - ".*": { - "weight": { - "bits": bits, - "dtype": dtype, - "group_size": config.group_size, # -1 (per-channel) - "scheme": config.scheme, - "algorithm": algorithm, - }, - }, - }, - op_name_dict={ - ".*lm_head": { # re.match - "weight": {"dtype": "fp32"}, - }, - ".*output_layer": { # re.match - "weight": {"dtype": "fp32"}, - }, - ".*embed_out": { # re.match - "weight": {"dtype": "fp32"}, - }, - }, - recipes=recipes, - ) - # TEQ: set calib_func=None, use default training func as calib_func - # RTN: doesn't need calib_func - if config.quant_method.value not in ["awq"]: - calib_func = None - - orig_dtype = torch.float32 - for param in model.parameters(): - orig_dtype = param.dtype - if orig_dtype != torch.float32: - model.to(dtype=torch.float32) - break - inc_model = quantization.fit( - 
model, conf, calib_func=calib_func, calib_dataloader=calib_dataloader - ) - inc_model.eval() - if device == "xpu" or device == torch.device("xpu"): - model = inc_model.export_compressed_model( - compression_dtype=torch.int8, - compression_dim=0, - use_optimum_format=False, - scale_dtype=convert_dtype_str2torch(config.scale_dtype), - device="xpu", - ) if _ipex_version < "2.3.10" else inc_model.export_compressed_model(use_optimum_format=True, device="xpu") - - q_model = replace_linear(model, None, None, config, device=device) - else: - if config.weight_dtype not in ["nf4", "fp4_e2m1_bnb", "fp4_e2m1"]: - inc_model = inc_model.export_compressed_model(use_optimum_format=True) - inc_model.eval() - q_model = replace_linear(inc_model, None, None, config, device=device) - else: - q_model = replace_linear( - inc_model.model, None, None, config, device=device - ) + logger.warning("The recommended ipex version is higher than 2.3.10 for xpu device.") + + model.eval() + # INC attribute conflicted with transformers when use nf4/int8 training. + del model.is_quantized + q_model = replace_linear(model, None, None, config, device=device) if orig_dtype != torch.float32: q_model.to(dtype=orig_dtype) + return q_model.to(device) @@ -661,3 +733,99 @@ def get_bits(config): config.weight_dtype ) return bits + + +def convert_to_smoothquant_model(model, quantization_config): + model_type = model.config.model_type.replace("_", "-") + # ipex.optimize_transformers + if quantization_config.ipex_opt_llm is None: + if model_type in IPEX_OPT_LLM_SUPPORTED: + quantization_config.ipex_opt_llm = True + logger.info( + "quantization_config.ipex_opt_llm set to True and ipex.llm.optimize is used." + ) + else: + quantization_config.ipex_opt_llm = False + if quantization_config.ipex_opt_llm: + qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5) + model = ipex.llm.optimize( + model.eval(), + quantization_config=qconfig, + dtype=torch.float32, + inplace=True, + deployment_mode=False, + ) + model.eval() + # past_key_values + num_beams = quantization_config.num_beams + if quantization_config.ipex_opt_llm: + past_key_values = generate_dummy_past_key_values_for_opt_llm( + config=model.config, input_bs=1, num_beams=num_beams + ) + else: + past_key_values = generate_dummy_past_key_values( + config=model.config, input_bs=1 + ) + # get calibration dataloader + calib_dataloader = get_dataloader( + model_type, quantization_config, past_key_values=past_key_values + ) + + def calib_func(model): + with torch.no_grad(): + for i, (inputs, last_ind) in enumerate(calib_dataloader): + if i >= quantization_config.n_samples: + break + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: + model( + input_ids=inputs["input_ids"], + past_key_values=inputs["past_key_values"], + position_ids=inputs["position_ids"], + attention_mask=inputs["attention_mask"], + ) + else: + model( + input_ids=inputs["input_ids"], + past_key_values=inputs["past_key_values"], + attention_mask=inputs["attention_mask"], + ) + + # example_inputs + for i, (inputs, last_ind) in enumerate(calib_dataloader): + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: + example_inputs = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "position_ids": inputs["position_ids"], + "past_key_values": inputs["past_key_values"], + } + else: + example_inputs = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "past_key_values": inputs["past_key_values"], + } + break + quant_config = SmoothQuantConfig( + 
alpha=quantization_config.alpha, + init_alpha=quantization_config.init_alpha, + alpha_min=quantization_config.alpha_min, + alpha_max=quantization_config.alpha_max, + alpha_step=quantization_config.alpha_step, + shared_criterion=quantization_config.shared_criterion, + do_blockwise=quantization_config.do_blockwise, + excluded_precisions=quantization_config.excluded_precisions, + ) + # fallback + if model_type in ["gptj", "gpt_neox", "mpt"]: + quant_config = quant_config.set_local( + torch.add, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32") + ) + model = quantize( + model, + quant_config=quant_config, + run_fn=calib_func, + example_inputs=example_inputs, + ) + + return model diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index 1314e464eff..6251b308be2 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -55,10 +55,6 @@ ) from ..utils.utility import ( CpuInfo, - generate_dummy_past_key_values, - generate_dummy_past_key_values_for_opt_llm, - MODEL_TYPES_REQUIRING_POSITION_IDS, - IPEX_OPT_LLM_SUPPORTED, WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, @@ -68,12 +64,13 @@ convert_dtype_str2torch, convert_dtype_torch2str, convert_to_quantized_model, + convert_to_smoothquant_model, replace_linear, ) from ...tools.utils import is_intel_gpu_available, is_ipex_available from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear +from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear from neural_compressor.model.torch_model import PyTorchFXModel from threading import Thread from transformers.configuration_utils import PretrainedConfig @@ -125,27 +122,32 @@ def recover_export_model(model, current_key_name=None): zeros, int_weight, ) = module.recover_qparms() + dtype = "int4" if weight_dtype == "int4_clip" else weight_dtype + use_optimum_format = False if weight_dtype in ["fp4_e2m1", "fp4", "nf4"] else True model._modules[name] = WeightOnlyLinear( in_features, out_features, + dtype=dtype, bits=bits, - groupsize=groupsize, - dtype="int", + group_size=groupsize, zp=zp, bias=module.bias is not None, scale_dtype=scales_dtype, g_idx=desc_act, - use_optimum_format=True, + use_optimum_format=use_optimum_format, ) # Setting g_idx is invalid when use_optimum_format is True, so set it again when g_idx is not None. 
# https://github.com/intel/neural-compressor/blob/v2.5.dev2/neural_compressor/adaptor/torch_utils/ # model_wrapper.py#L343 model._modules[name].pack( - int_weight, scales, zeros, module.bias, g_idx=g_idx + int_weight.contiguous(), + scales.contiguous(), + zeros.contiguous() if zeros is not None else None, + module.bias.contiguous() if module.bias is not None else None, ) if g_idx is not None: - model._modules[name].g_idx = g_idx + model._modules[name].g_idx = g_idx.contiguous() if len(list(module.children())) > 0: # pylint: disable=E1101 _ = recover_export_model(module, current_key_name) @@ -156,23 +158,29 @@ def recover_export_model(model, current_key_name=None): def build_woq_model(model, quantization_config): from neural_compressor.adaptor.torch_utils.util import set_module - + weight_dtype = quantization_config.weight_dtype for n, m in model.named_modules(): if "lm_head" in n or "output_layer" in n or "embed_out" in n: continue if isinstance(m, torch.nn.Linear): - zp = getattr(quantization_config, "zero_point", not getattr(quantization_config, "sym", False)) + zp = getattr( + quantization_config, + "zero_point", + not getattr(quantization_config, "sym", False), + ) + dtype = "int4" if weight_dtype == "int4_clip" else weight_dtype + use_optimum_format = False if weight_dtype in ["nf4", "fp4", "fp4_e2m1"] else True with init_empty_weights(): new_module = WeightOnlyLinear( m.in_features, m.out_features, - quantization_config.bits, - quantization_config.group_size, - dtype="int", + dtype=dtype, + bits=quantization_config.bits, + group_size=quantization_config.group_size, zp=zp, bias=m.bias is not None, g_idx=True, - use_optimum_format=True, + use_optimum_format=use_optimum_format, ) set_module(model, n, new_module) return model @@ -192,18 +200,10 @@ def convert_model_to_public(model): elif model.quantization_config.weight_dtype not in [ "fp8_e5m2", "fp8_e4m3", - "nf4", - "fp4_e2m1", - "fp4_e2m1_bnb", ]: model = recover_export_model(model) -def make_contiguous(model): - for param in model.parameters(): - if param.data.ndimension() > 1: - param.data = param.data.contiguous() - def save_low_bit( self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs @@ -223,7 +223,8 @@ def save_low_bit( self.model.config.quantization_config = self.quantization_config self.model.config.save_pretrained(save_directory) weights_file = os.path.join( - os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME) + os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME + ) torch.save(self.quantized_state_dict(), weights_file) return @@ -231,31 +232,48 @@ def save_low_bit( os.makedirs(save_directory, exist_ok=True) # use transformers original `save_pretrained` function del self.save_pretrained - make_contiguous(self) + self.save_pretrained( save_directory=save_directory, push_to_hub=push_to_hub, **kwargs ) if self.quantization_config.use_ipex: + def save_linear_parameters(model, save_directory): # only can save to pytorch model.bin due to ipex. 
weights_file = os.path.join( - os.path.abspath(os.path.expanduser(save_directory)), SAFE_WEIGHTS_NAME) + os.path.abspath(os.path.expanduser(save_directory)), SAFE_WEIGHTS_NAME + ) os.remove(weights_file) weights_file = os.path.join( - os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME) + os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME + ) linear_parameters = {} - from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_cpu_linear + from intel_extension_for_pytorch.nn.modules import ( + WeightOnlyQuantizedLinear as ipex_cpu_linear, + ) + for name, module in model.named_modules(): if isinstance(module, ipex_cpu_linear): - linear_parameters[name + ".ipex_scales"] = module._op_context.get_scales().contiguous() - linear_parameters[name + ".ipex_weight"] = \ - module._op_context.to_public(module._op_context.get_weight()).contiguous() - linear_parameters[name + ".ipex_zeros"] = module._op_context.get_zero_points().contiguous() + linear_parameters[name + ".ipex_scales"] = ( + module._op_context.get_scales().contiguous() + ) + linear_parameters[name + ".ipex_weight"] = ( + module._op_context.to_public( + module._op_context.get_weight() + ).contiguous() + ) + linear_parameters[name + ".ipex_zeros"] = ( + module._op_context.get_zero_points().contiguous() + ) if module._op_context.get_bias() is not None: - linear_parameters[name + ".ipex_bias"] = module._op_context.get_bias().contiguous() + linear_parameters[name + ".ipex_bias"] = ( + module._op_context.get_bias().contiguous() + ) if module._op_context.get_g_idx() is not None: - linear_parameters[name + ".ipex_g_idx"] = module._op_context.get_g_idx().contiguous() + linear_parameters[name + ".ipex_g_idx"] = ( + module._op_context.get_g_idx().contiguous() + ) others_parameters = model.state_dict() linear_parameters.update(others_parameters) @@ -345,16 +363,19 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): use_vllm = kwargs.pop("use_vllm", None) if use_vllm is not None: logger.info("The backend is vLLM.") - from vllm import LLM # pylint: disable=E1101 + from vllm import LLM # pylint: disable=E1101 from vllm.model_executor.model_loader import get_model_loader # pylint: disable=E0611 - from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611 - from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ColumnParallelLinear, - RowParallelLinear) # pylint: disable=E1101 + from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611 + from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ColumnParallelLinear, + RowParallelLinear) # pylint: disable=E1101 os.environ["backend"] = "use_vllm" - llm = LLM(model=pretrained_model_name_or_path, trust_remote_code=True) # Create an vllm instance. + llm = LLM( + model=pretrained_model_name_or_path, trust_remote_code=True + ) # Create an vllm instance. 
model = llm.llm_engine.model_executor.driver_worker.model_runner.model # pylint: disable=E1101 print("Original model =", model) @@ -365,12 +386,22 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): if "qkv_proj" in name or "gate_up_proj" in name: input_dim = getattr(params, "input_dim", None) output_dim = getattr(params, "output_dim", None) - original_parameter_memo[name] = (input_dim, output_dim, params.weight_loader) + original_parameter_memo[name] = ( + input_dim, + output_dim, + params.weight_loader, + ) class linear_adaptor(torch.nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = True, \ - device=None, dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: super().__init__(in_features, out_features, bias, device, dtype) def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: @@ -378,34 +409,45 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: for name, module in model.named_modules(): bias_flag = False - if isinstance(module, QKVParallelLinear) or isinstance(module, MergedColumnParallelLinear) or \ - isinstance(module, RowParallelLinear) or isinstance(module, ColumnParallelLinear): + if ( + isinstance(module, QKVParallelLinear) + or isinstance(module, MergedColumnParallelLinear) + or isinstance(module, RowParallelLinear) + or isinstance(module, ColumnParallelLinear) + ): out_feature = module.weight.shape[0] in_feature = module.weight.shape[1] if getattr(module, "bias", False) != None: bias_flag = True weight_dtype = module.weight.dtype - torch_linear = linear_adaptor(in_features=in_feature, - out_features=out_feature, - bias=bias_flag, - dtype=weight_dtype) + torch_linear = linear_adaptor( + in_features=in_feature, + out_features=out_feature, + bias=bias_flag, + dtype=weight_dtype, + ) module_traversal = model - all_module_names = name.split('.') + all_module_names = name.split(".") all_module_names_except_last = all_module_names[:-1] for sub_module_name in all_module_names_except_last: module_traversal = module_traversal._modules[sub_module_name] - module_traversal._modules[all_module_names[-1]] = copy.deepcopy(torch_linear) + module_traversal._modules[all_module_names[-1]] = copy.deepcopy( + torch_linear + ) print("Optimized model =", model) loader = get_model_loader(llm.llm_engine.load_config) # pylint: disable=E1101 - weights_iterator = loader._get_weights_iterator(llm.llm_engine.model_config.model, - llm.llm_engine.model_config.revision, - fall_back_to_pt=True) + weights_iterator = loader._get_weights_iterator( + llm.llm_engine.model_config.model, + llm.llm_engine.model_config.revision, + fall_back_to_pt=True, + ) + + from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611 - from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611 params_dict = dict(model.named_parameters(remove_duplicate=False)) for name in params_dict.keys(): params = params_dict[name] @@ -423,11 +465,13 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: print("INC quantizing...") config = kwargs.pop("config", None) if config is None: - config = RtnConfig(compute_dtype="int8", - group_size=128, - scale_dtype="bf16", - weight_dtype="int4_clip", - bits=4) + config = RtnConfig( + compute_dtype="int8", + group_size=128, + scale_dtype="bf16", + weight_dtype="int4_clip", + bits=4, + ) 
print("using default RTNConfig = ", config) print("Using customized config = ", config) model = convert_to_quantized_model(model, config) @@ -462,7 +506,8 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: if model_type not in cls.model_type_list: logger.error( - "Can't support this model_type. Please set the correct model_type, supported model_type: {}".format( + "Can't support this model_type." + + "Please set the correct model_type, supported model_type: {}".format( cls.model_type_list ) ) @@ -488,8 +533,12 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: return model device_map = kwargs.get("device_map", "cpu") - use_cpu = True if device_map == torch.device("cpu") or device_map == "cpu" else False - use_xpu = True if device_map == torch.device("xpu") or device_map == "xpu" else False + use_cpu = ( + True if device_map == torch.device("cpu") or device_map == "cpu" else False + ) + use_xpu = ( + True if device_map == torch.device("xpu") or device_map == "xpu" else False + ) config = kwargs.pop("config", None) model_hub = kwargs.pop("model_hub", "huggingface") @@ -497,18 +546,30 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: quantization_config = kwargs.pop("quantization_config", None) if not isinstance(config, PretrainedConfig): if model_hub == "modelscope": - import modelscope # pylint: disable=E0401 - config = modelscope.AutoConfig.from_pretrained(pretrained_model_name_or_path, - trust_remote_code=True) + import modelscope # pylint: disable=E0401 + + config = modelscope.AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) else: config, _ = AutoConfig.from_pretrained( pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs, - ) - if kwargs.get("use_llm_runtime", None) is not None: + if quantization_config is not None and quantization_config.quant_method in [ + "sq" + ]: + use_neural_speed = False + elif ( + hasattr(config, "quantization_config") + and isinstance(config.quantization_config, dict) + and "quant_method" in config.quantization_config + and config.quantization_config["quant_method"] in ["sq"] + ): + use_neural_speed = False + elif kwargs.get("use_llm_runtime", None) is not None: use_neural_speed = kwargs.pop("use_llm_runtime", True) and not use_xpu logger.warning( "use_llm_runtime is deprecated in version 1.3.2, please use_neural_speed instead." @@ -539,30 +600,38 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: "Quantization_config loading failed. If you want to load saved " "low bit model, please check your quantizate_config.json." 
) - elif use_neural_speed and not config.quantization_config["quant_method"] in ["dynamic", "static", "qat"]: + elif use_neural_speed and not config.quantization_config[ + "quant_method" + ] in ["dynamic", "static", "qat"]: if not os.path.exists(pretrained_model_name_or_path): from huggingface_hub import snapshot_download - pretrained_model_name_or_path = snapshot_download(repo_id=pretrained_model_name_or_path, - allow_patterns=["*.pt", "*.safetensors", "*.json", ".model"], - ) + + pretrained_model_name_or_path = snapshot_download( + repo_id=pretrained_model_name_or_path, + allow_patterns=["*.pt", "*.safetensors", "*.json", ".model"], + ) if quantization_config is None: - ConfigInit = {"rtn": RtnConfig, - "awq": AwqConfig, - "teq": TeqConfig, - "gptq": GPTQConfig, - "autoround": AutoRoundConfig, - } + ConfigInit = { + "rtn": RtnConfig, + "awq": AwqConfig, + "teq": TeqConfig, + "gptq": GPTQConfig, + "autoround": AutoRoundConfig, + } quantization_config = config.quantization_config - assert quantization_config.get("quant_method", None) in ConfigInit, \ - "Detect this model is not a low-bit model." - quantization_config = ConfigInit[quantization_config["quant_method"]].from_dict(quantization_config) + assert ( + quantization_config.get("quant_method", None) in ConfigInit + ), "Detect this model is not a low-bit model." + quantization_config = ConfigInit[ + quantization_config["quant_method"] + ].from_dict(quantization_config) logger.info("Loading Low Bits model by Neural Speed.") quantization_config.post_init_runtime() from neural_speed import Model model = Model() - model.init( # pylint: disable=E1123 + model.init( # pylint: disable=E1123 pretrained_model_name_or_path, weight_dtype=quantization_config.weight_dtype, alg=quantization_config.scheme, @@ -653,10 +722,16 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: else: quantization_config = RtnConfig( bits=4, - compute_dtype=torch.float32 if - (use_cpu and not CpuInfo().bf16 - and torch_dtype == torch.bfloat16) else convert_dtype_torch2str(torch_dtype), - weight_dtype="nf4" if use_cpu else "int4_fullrange", + compute_dtype=( + torch.float32 + if ( + use_cpu + and not CpuInfo().bf16 + and torch_dtype == torch.bfloat16 + ) + else convert_dtype_torch2str(torch_dtype) + ), + weight_dtype="int4_clip" if use_cpu else "int4_fullrange", ) else: assert ( @@ -664,19 +739,26 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: and convert_dtype_str2torch(quantization_config.compute_dtype) == torch_dtype ), "Quantization_config.weight_dtype should be 'nf4' , 'int4', 'int4_fullrange', 'int4_clip', " - f"'fp4', 'fp4_e2m1' or 'fp4_e2m1_bnb' and compute_dtype should be {torch_dtype}." + f"'fp4', 'fp4_e2m1' and compute_dtype should be {torch_dtype}." 
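For context before the 8-bit branch continues below, a minimal sketch of the 4-bit weight-only path this default exercises. The top-level import path and the model id are assumptions for illustration and are not taken from this patch.

```python
# Sketch only: 4-bit weight-only quantization through the patched from_pretrained path.
# "facebook/opt-125m" is a placeholder model id.
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

# Comparable to load_in_4bit=True, which builds an RtnConfig with defaults like those
# shown above; here the RTN knobs are spelled out explicitly.
woq_config = RtnConfig(
    bits=4,
    compute_dtype="int8",
    weight_dtype="int4_clip",
    group_size=128,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=woq_config,   # or pass load_in_4bit=True to use the defaults
)
```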
elif load_in_8bit: if quantization_config is None: if use_neural_speed: quantization_config = RtnConfig( - compute_dtype="bf16" if CpuInfo().bf16 else "fp32", weight_dtype="int8" + compute_dtype="bf16" if CpuInfo().bf16 else "fp32", + weight_dtype="int8", ) else: quantization_config = RtnConfig( bits=8, - compute_dtype=torch.float32 if - (use_cpu and not CpuInfo().bf16 - and torch_dtype == torch.bfloat16) else convert_dtype_torch2str(torch_dtype), + compute_dtype=( + torch.float32 + if ( + use_cpu + and not CpuInfo().bf16 + and torch_dtype == torch.bfloat16 + ) + else convert_dtype_torch2str(torch_dtype) + ), weight_dtype="int8", ) else: @@ -726,7 +808,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: from neural_speed import Model model = Model() - model.init( # pylint: disable=E1123 + model.init( # pylint: disable=E1123 pretrained_model_name_or_path, weight_dtype=quantization_config.weight_dtype, alg=quantization_config.scheme, @@ -772,6 +854,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: **kwargs, ) model.config.update({"low_cpu_mem_usage": False}) + quantization_config.post_init_xpu() else: kwargs["low_cpu_mem_usage"] = True config.torchscript = ( @@ -786,6 +869,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: **kwargs, ) model.config.update({"low_cpu_mem_usage": True}) + quantization_config.post_init_cpu() model.eval() if use_xpu: @@ -823,7 +907,6 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: assert ( ipex.__version__ >= "2.2.0+cpu" ), "Please use Intel Extension for PyTorch >=2.2.0+cpu." - config.torchscript = True config.use_cache = True model = cls.ORIG_MODEL.from_pretrained( @@ -834,7 +917,6 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: torch_dtype=torch.float, **kwargs, ) - if ( not torch.cuda.is_available() or device_map == "cpu" @@ -848,254 +930,8 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: ) and model.config.model_type == "mpt": model.config.architectures = ["MptForCausalLM"] model.eval() - model_type = model.config.model_type.replace("_", "-") - logger.info("Applying SmoothQuant.") - # ipex.optimize_transformers - if quantization_config.ipex_opt_llm is None: - if model_type in IPEX_OPT_LLM_SUPPORTED: - quantization_config.ipex_opt_llm = True - logger.info( - "quantization_config.ipex_opt_llm set to True and ipex.optimize_transformers is used." - ) - logger.warning("The suggested transformers version is 4.38.1.") - else: - quantization_config.ipex_opt_llm = False - if quantization_config.ipex_opt_llm: - qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5) - model = ipex.optimize_transformers( - model.eval(), - quantization_config=qconfig, - dtype=torch.float32, - inplace=True, - deployment_mode=False, - ) - model.eval() - - # past_key_values - num_beams = quantization_config.num_beams - if quantization_config.ipex_opt_llm: - past_key_values = generate_dummy_past_key_values_for_opt_llm( - config=model.config, input_bs=1, num_beams=num_beams - ) - else: - past_key_values = generate_dummy_past_key_values( - config=model.config, input_bs=1 - ) - - # calibration function - calib_func = quantization_config.calib_func - tokenizer = quantization_config.tokenizer - if calib_func is None: - if quantization_config.tokenizer is None: - logger.error( - "Please provide the tokenizer or provide calib_func directly," - + " the following is how to get tokenizer. 
\n" - + " from transformer import AutoTokenizer \n" - + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n" - ) - exit(0) - - from datasets import load_dataset - from torch.utils.data import DataLoader - - calib_dataset = quantization_config.calib_dataset - calib_shuffle = quantization_config.calib_shuffle - calib_iters = quantization_config.calib_iters - calib_padding = quantization_config.calib_padding - calib_len = quantization_config.calib_len - calib_pad_val = quantization_config.calib_pad_val - from torch.nn.functional import pad - - calib_dataset = load_dataset( - calib_dataset, - split=( - "test" - if calib_dataset in ["mbpp", "openai_humaneval"] - else "train" - ), - ) - if calib_shuffle: - calib_dataset = calib_dataset.shuffle(seed=42) - - def tokenize_function(examples): - if "code" in examples: - example = tokenizer(examples["code"]) - elif "prompt" in examples: - example = tokenizer(examples["prompt"]) - elif "text" in examples: - example = tokenizer(examples["text"]) - else: - logger.error( - "Please check dataset prompt identifier," - + " NeelNanda/pile-10k is default used calibration dataset." - ) - exit(0) - return example - - def collate_batch(batch): - position_ids_padded = [] - input_ids_padded = [] - last_ind = [] - attention_mask_padded = [] - for text in batch: - input_ids = text["input_ids"] - if not calib_padding: - input_ids = ( - input_ids[: int(calib_len)] - if len(input_ids) > int(calib_len) - else input_ids - ) # no_padding - else: - pad_len = calib_len - input_ids.shape[0] - input_ids = pad( - input_ids, (0, pad_len), value=calib_pad_val - ) - - last_ind.append(input_ids.shape[0] - 1) - attention_mask = torch.ones(len(input_ids)) - position_ids = torch.arange(len(input_ids)) - input_ids_padded.append(input_ids) - attention_mask_padded.append(attention_mask) - position_ids_padded.append(position_ids) - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: - return ( - { - "input_ids": torch.vstack(input_ids_padded), - "attention_mask": torch.vstack(attention_mask_padded), - "position_ids": torch.vstack(position_ids_padded), - "past_key_values": past_key_values, - }, - torch.tensor(last_ind), - ) - else: - return ( - { - "input_ids": torch.vstack(input_ids_padded), - "attention_mask": torch.vstack(attention_mask_padded), - "past_key_values": past_key_values, - }, - torch.tensor(last_ind), - ) - - def collate_batch_for_chatglm(batch): - last_ind = [] - for text in batch: - input_ids = torch.vstack([text["input_ids"]]) - if re.search( - "THUDM/chatglm-6b", model.config.auto_map["AutoConfig"] - ): - input_ids = ( - input_ids[:, :calib_len] - if input_ids.shape[1] > calib_len - else input_ids - ) - eos = torch.tensor([130001, 130004]).repeat(1, 1) - input_ids = torch.cat((input_ids, eos), 1) - else: - input_ids = ( - input_ids[:, :calib_len] - if input_ids.shape[1] > calib_len - else input_ids - ) - prepared_inputs = model.prepare_inputs_for_generation(input_ids) - attention_mask = torch.ones_like(input_ids) - last_ind.append(input_ids.shape[1] - 1) - return ( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": prepared_inputs["position_ids"], - "past_key_values": past_key_values, - }, - torch.tensor(last_ind), - ) - - tokenized_dataset = calib_dataset.map(tokenize_function, batched=True) - tokenized_dataset.set_format(type="torch", columns=["input_ids"]) - if model_type == "chatglm": - calib_dataloader = DataLoader( - tokenized_dataset, - batch_size=1, - shuffle=False, - collate_fn=collate_batch_for_chatglm, - ) - 
else: - calib_dataloader = DataLoader( - tokenized_dataset, - batch_size=1, - shuffle=False, - collate_fn=collate_batch, - ) - - def calib_func(model): - with torch.no_grad(): - for i, (inputs, last_ind) in enumerate(calib_dataloader): - if i >= calib_iters: - break - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: - model( - input_ids=inputs["input_ids"], - past_key_values=inputs["past_key_values"], - position_ids=inputs["position_ids"], - attention_mask=inputs["attention_mask"], - ) - else: - model( - input_ids=inputs["input_ids"], - past_key_values=inputs["past_key_values"], - attention_mask=inputs["attention_mask"], - ) - - logger.info( - "The default calibration function is used, " - + "the calibration dataset is NeelNanda/pile-10k, " - + "batchsize is 1 and calibration iteration is 100." - ) - calib_func = calib_func - - # example_inputs - example_inputs = quantization_config.example_inputs - if example_inputs is None: - for i, (inputs, last_ind) in enumerate(calib_dataloader): - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: - example_inputs = { - "input_ids": inputs["input_ids"], - "attention_mask": inputs["attention_mask"], - "position_ids": inputs["position_ids"], - "past_key_values": inputs["past_key_values"], - } - else: - example_inputs = { - "input_ids": inputs["input_ids"], - "attention_mask": inputs["attention_mask"], - "past_key_values": inputs["past_key_values"], - } - break - - # call inc sq - from neural_compressor import PostTrainingQuantConfig, quantization - - conf = PostTrainingQuantConfig( - backend=quantization_config.backend, # default is ipex - excluded_precisions=quantization_config.excluded_precisions, - op_type_dict=quantization_config.op_type_dict, - op_name_dict=quantization_config.op_name_dict, - recipes=quantization_config.recipes, - example_inputs=example_inputs, - ) - - model = quantization.fit( - model, - conf, - calib_func=calib_func, - calib_dataloader=( - calib_dataloader - if quantization_config.recipes["smooth_quant_args"]["alpha"] - == "auto" - else None - ), - ) + model = convert_to_smoothquant_model(model, quantization_config) logger.info("SmoothQuant done.") elif isinstance(quantization_config, DynamicQuantConfig): model = cls.ORIG_MODEL.from_pretrained( @@ -1239,7 +1075,6 @@ def collate_batch(batch): torch.tensor(last_ind), ) - tokenized_dataset = calib_dataset.map(tokenize_function, batched=True) tokenized_dataset.set_format(type="torch", columns=["input_ids"]) calib_dataloader = DataLoader( @@ -1263,7 +1098,6 @@ def calib_func(model): ) calib_func = calib_func - # call inc static quant from neural_compressor import PostTrainingQuantConfig, quantization @@ -1378,7 +1212,6 @@ def collate_batch(batch): torch.tensor(last_ind), ) - tokenized_dataset = train_dataset.map(tokenize_function, batched=True) tokenized_dataset.set_format(type="torch", columns=["input_ids"]) train_dataloader = DataLoader( @@ -1405,7 +1238,7 @@ def train_func(model): optimizer.zero_grad() loss.backward() optimizer.step() - print('Iteration [{}], Loss: {:.4f}'.format(i+1, loss)) + print("Iteration [{}], Loss: {:.4f}".format(i + 1, loss)) return model logger.info( @@ -1415,10 +1248,10 @@ def train_func(model): ) train_func = train_func - # call inc static quant from neural_compressor import QuantizationAwareTrainingConfig, quantization from neural_compressor.training import prepare_compression + conf = QuantizationAwareTrainingConfig( backend=quantization_config.backend, excluded_precisions=quantization_config.excluded_precisions, @@ -1430,7 +1263,9 @@ def 
train_func(model): model = compression_manager.model train_func(model) compression_manager.callbacks.on_train_end() - compression_manager.model.save_pretrained = types.MethodType(save_low_bit, model) + compression_manager.model.save_pretrained = types.MethodType( + save_low_bit, model + ) quantization_config.remove_redundant_parameters() compression_manager.model.quantization_config = quantization_config logger.info("Quant Aware Training done.") @@ -1441,7 +1276,7 @@ def train_func(model): from neural_speed import Model model = Model() - model.init( # pylint: disable=E1123 + model.init( # pylint: disable=E1123 pretrained_model_name_or_path, weight_dtype="fp32", use_quant=False, @@ -1522,7 +1357,11 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): kwarg_attn_imp = kwargs.pop("attn_implementation", None) # lm-eval device map is dictionary - device_map = device_map[""] if isinstance(device_map, dict) and "" in device_map else device_map + device_map = ( + device_map[""] + if isinstance(device_map, dict) and "" in device_map + else device_map + ) if use_safetensors is None and not is_safetensors_available(): use_safetensors = False @@ -1538,8 +1377,12 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) token = use_auth_token - use_cpu = True if device_map == torch.device("cpu") or device_map == "cpu" else False - use_xpu = True if device_map == torch.device("xpu") or device_map == "xpu" else False + use_cpu = ( + True if device_map == torch.device("cpu") or device_map == "cpu" else False + ) + use_xpu = ( + True if device_map == torch.device("xpu") or device_map == "xpu" else False + ) user_agent = { "file_type": "model", @@ -1570,7 +1413,11 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): elif quantization_config["quant_method"] == "dynamic": quantization_config = DynamicQuantConfig.from_dict(quantization_config) elif quantization_config["quant_method"] == "qat": - quantization_config = QuantAwareTrainingConfig.from_dict(quantization_config) + quantization_config = QuantAwareTrainingConfig.from_dict( + quantization_config + ) + elif quantization_config["quant_method"] == "sq": + quantization_config = SmoothQuantConfig.from_dict(quantization_config) assert ( quantization_config is not None ), "Detect this model is not a low-bit model." @@ -1624,6 +1471,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): # index of the files. 
is_sharded = False sharded_metadata = None + if pretrained_model_name_or_path is not None: pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) @@ -1641,6 +1489,20 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): subfolder, _add_variant(WEIGHTS_NAME, variant), ) + # only for inc sq + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant("quantized_model.pt", variant), + ) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant("quantized_model.pt", variant), + ) elif os.path.isfile( os.path.join( pretrained_model_name_or_path, @@ -1709,11 +1571,15 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, } - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # result when internet is up, the repo and revision exist, but the file does not. - if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + if resolved_archive_file is None and filename == _add_variant( + SAFE_WEIGHTS_NAME, variant + ): # Maybe the checkpoint is sharded, we try to grab the index name in this case. resolved_archive_file = cached_file( pretrained_model_name_or_path, @@ -1734,9 +1600,13 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): # This repo has no safetensors file of any kind, we switch to PyTorch. filename = _add_variant(WEIGHTS_NAME, variant) resolved_archive_file = cached_file( - pretrained_model_name_or_path, filename, **cached_file_kwargs + pretrained_model_name_or_path, + filename, + **cached_file_kwargs, ) - if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + if resolved_archive_file is None and filename == _add_variant( + WEIGHTS_NAME, variant + ): # Maybe the checkpoint is sharded, we try to grab the index name in this case. resolved_archive_file = cached_file( pretrained_model_name_or_path, @@ -1755,7 +1625,9 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): "token": token, } if variant is not None and has_file( - pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + pretrained_model_name_or_path, + WEIGHTS_NAME, + **has_file_kwargs, ): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" @@ -1779,7 +1651,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." 
) from e - if is_local: logger.info(f"loading weights file {archive_file}") resolved_archive_file = archive_file @@ -1818,12 +1689,29 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): if quantization_config.quant_method in ["static", "dynamic", "qat"]: model = model_class(config, *model_args, **kwargs) from neural_compressor.utils.pytorch import load + weights_file = os.path.join( - os.path.abspath(os.path.expanduser(pretrained_model_name_or_path)), WEIGHTS_NAME) + os.path.abspath(os.path.expanduser(pretrained_model_name_or_path)), + WEIGHTS_NAME, + ) q_model = load(weights_file, model, dataloader=None) del model return q_model + if quantization_config.quant_method in ["sq"]: + print("Loading SmoothQuant model from: ", pretrained_model_name_or_path) + from intel_extension_for_transformers.transformers.llm.quantization.sq_utils import ( + TSModelCausalLMForITREX, + ) + q_model = torch.jit.load( + os.path.join(pretrained_model_name_or_path, "quantized_model.pt") + ) + origin_model_type = config.model_type + if origin_model_type in ["chatglm", "qwen", "baichuan"]: + config.model_type = "qwen2" + q_model = TSModelCausalLMForITREX(q_model, config=config) + q_model.config.model_type = origin_model_type + return q_model dtype_orig = None if torch_dtype is not None: if isinstance(torch_dtype, str): @@ -1847,19 +1735,25 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): dtype_orig = model_class._set_default_torch_dtype(torch_dtype) if quantization_config.compute_dtype is None: if use_xpu: - quantization_config.compute_dtype = \ - "fp16" if (torch_dtype is None or - torch_dtype == torch.bfloat16) \ + quantization_config.compute_dtype = ( + "fp16" + if (torch_dtype is None or torch_dtype == torch.bfloat16) else convert_dtype_torch2str(torch_dtype) + ) else: - quantization_config.compute_dtype = \ - "fp32" if (torch_dtype is None or - (not CpuInfo().bf16 and torch_dtype == torch.bfloat16) or - (torch_dtype == torch.float16)) \ + quantization_config.compute_dtype = ( + "fp32" + if ( + torch_dtype is None + or (not CpuInfo().bf16 and torch_dtype == torch.bfloat16) + or (torch_dtype == torch.float16) + ) else convert_dtype_torch2str(torch_dtype) + ) else: - if ((not CpuInfo().bf16 and quantization_config.compute_dtype == "bf16") - or (use_cpu and quantization_config.compute_dtype == "fp16")): + if (not CpuInfo().bf16 and quantization_config.compute_dtype == "bf16") or ( + use_cpu and quantization_config.compute_dtype == "fp16" + ): quantization_config.compute_dtype = "fp32" if quantization_config.scale_dtype is None: @@ -1867,7 +1761,9 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): if quantization_config.scale_dtype not in ["fp32", "fp16", "bf16"]: logger.warning("scale_dtype only supports fp32, bf16, fp16.") quantization_config.scale_dtype = "fp32" - logger.warning("fp32 scale_dtype is used, please change the config.json if you don't want to use it.") + logger.warning( + "fp32 scale_dtype is used, please change the config.json if you don't want to use it." + ) # weight dtype is higher priority than bits in config.json when both existed. 
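         # For illustration: a config.json carrying only {"bits": 4} resolves to
         # weight_dtype="int4_clip" below, {"bits": 8} resolves to "int8", while an
         # explicit weight_dtype entry (e.g. "nf4") is validated and used as-is.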
if quantization_config.weight_dtype is None: @@ -1878,36 +1774,47 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): quantization_config.weight_dtype = "int4_clip" logger.info( "{} quantization weight_dtype is used due to bits is 4 in config.json.".format( - quantization_config.weight_dtype) + quantization_config.weight_dtype ) + ) elif quantization_config.bits == 8: quantization_config.weight_dtype = "int8" logger.info( "{} quantization weight_dtype is used due to bits is 8 in config.json.".format( - quantization_config.weight_dtype) + quantization_config.weight_dtype ) + ) else: logger.warning("bits number only supports 4, 8.") quantization_config.weight_dtype = "int4_clip" logger.warning( - "int4_clip weight_dtype is used, please change the config.json if you don't want to use it.") + "int4_clip weight_dtype is used, please change the config.json if you don't want to use it." + ) else: - if quantization_config.weight_dtype not in ["int4_fullrange", - "int4_clip", - "int8", - "fp8_e5m2", - "fp8_e4m3", - "nf4", - "fp4_e2m1_bnb", - "fp4_e2m1"]: - logger.warning("Please provide the correct bits number or weight_dtype in config.json.") + if quantization_config.weight_dtype not in [ + "int4_fullrange", + "int4_clip", + "int8", + "fp8_e5m2", + "fp8_e4m3", + "nf4", + "fp4_e2m1_bnb", + "fp4_e2m1", + ]: + logger.warning( + "Please provide the correct bits number or weight_dtype in config.json." + ) raise ValueError( f"weight_dtype must be a string in " f"'int8', 'int4', 'int4_fullrange', 'int4_clip', 'nf4', " - f"'fp4', 'fp4_e2m1_bnb', 'fp4_e2m1', 'fp8', 'fp8_e5m2, fp8_e4m3'" + f"'fp4', 'fp4_e2m1', 'fp8', 'fp8_e5m2, fp8_e4m3'" ) else: - logger.info("{} quantization weight_dtype is used.".format(quantization_config.weight_dtype)) + logger.info( + "{} quantization weight_dtype is used.".format( + quantization_config.weight_dtype + ) + ) init_contexts = [no_init_weights(_enable=_fast_init)] init_contexts.append(init_empty_weights()) @@ -1918,9 +1825,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): if quantization_config.weight_dtype not in [ "fp8_e5m2", "fp8_e4m3", - "fp4_e2m1", - "fp4_e2m1_bnb", - "nf4", ]: model = build_woq_model(model, quantization_config) else: @@ -1944,7 +1848,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): if is_ipex_available() and quantization_config.use_ipex: import intel_extension_for_pytorch as ipex - from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear + from intel_extension_for_pytorch.nn.modules import ( + WeightOnlyQuantizedLinear as ipex_linear, + ) + def replace_ipex_cpu_woq_linear(model, current_name=[]): for name, module in model.named_children(): current_name.append(name) @@ -1954,37 +1861,46 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]): 8: ipex.quantization.WoqWeightDtype.INT8, } compute_dtype = { - "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. + "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. 
"bf16": ipex.quantization.WoqLowpMode.BF16, "fp16": ipex.quantization.WoqLowpMode.FP16, "int8": ipex.quantization.WoqLowpMode.INT8, - } - ipex_qconfig_mapping = ( - ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype[quantization_config.bits], - lowp_mode=compute_dtype[quantization_config.compute_dtype], - act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, - group_size=quantization_config.group_size, - ) + ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype[quantization_config.bits], + lowp_mode=compute_dtype[quantization_config.compute_dtype], + act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + group_size=quantization_config.group_size, ) tmp_linear = torch.nn.Linear( module.in_features, module.out_features, - True if hasattr(module, "bias") else False - ) + True if hasattr(module, "bias") else False, + ) tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig target_linear = ipex_linear.from_float_and_int4_weight( - mod = tmp_linear, - qweight = state_dict.pop('.'.join(current_name) + ".ipex_weight"), - scales = state_dict.pop('.'.join(current_name) + ".ipex_scales"), - zero_points = state_dict.pop('.'.join(current_name) + ".ipex_zeros"), - bias = state_dict.pop('.'.join(current_name) + ".ipex_bias") \ - if '.'.join(current_name) + ".ipex_bias" in state_dict else None, - group_size = quantization_config.group_size, - g_idx = state_dict.pop('.'.join(current_name) + ".ipex_g_idx") \ - if '.'.join(current_name) + ".ipex_g_idx" in state_dict else None, + mod=tmp_linear, + qweight=state_dict.pop( + ".".join(current_name) + ".ipex_weight" + ), + scales=state_dict.pop( + ".".join(current_name) + ".ipex_scales" + ), + zero_points=state_dict.pop( + ".".join(current_name) + ".ipex_zeros" + ), + bias=( + state_dict.pop(".".join(current_name) + ".ipex_bias") + if ".".join(current_name) + ".ipex_bias" in state_dict + else None + ), + group_size=quantization_config.group_size, + g_idx=( + state_dict.pop(".".join(current_name) + ".ipex_g_idx") + if ".".join(current_name) + ".ipex_g_idx" in state_dict + else None + ), ) setattr(model, name, target_linear) else: @@ -2025,9 +1941,6 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]): if quantization_config.weight_dtype not in [ "fp8_e5m2", "fp8_e4m3", - "nf4", - "fp4_e2m1", - "fp4_e2m1_bnb", ] and not quantization_config.use_ipex: model = replace_linear( model, @@ -2036,8 +1949,9 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]): empty_weights=True, ) - if (not use_xpu and torch_dtype == torch.float16) or (not use_xpu and not CpuInfo().bf16 - and torch_dtype == torch.bfloat16): + if (not use_xpu and torch_dtype == torch.float16) or ( + not use_xpu and not CpuInfo().bf16 and torch_dtype == torch.bfloat16 + ): model.to(dtype=torch.float32) # If it is a model with generation capabilities, attempt to load the generation config diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py index 65eb4158702..72953c95515 100644 --- a/intel_extension_for_transformers/transformers/utils/config.py +++ b/intel_extension_for_transformers/transformers/utils/config.py @@ -21,19 +21,20 @@ import os from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Union -from .utility import QUANT_CONFIG, SPARSITY_CONFIG, LazyImport, logger + import transformers from transformers import BitsAndBytesConfig, PretrainedConfig +from .utility 
import QUANT_CONFIG, SPARSITY_CONFIG, LazyImport, logger + torch = LazyImport("torch") -@dataclass -class MixedPrecisionConfig: - dtype: str = "bfloat16" + if transformers.__version__ >= "4.32.0": from transformers.utils.quantization_config import QuantizationConfigMixin + QuantizationConfig = QuantizationConfigMixin else: QuantizationConfig = PretrainedConfig @@ -52,8 +53,17 @@ class QuantizationMethod(str, Enum): STATIC = "static" SmoothQuant = "sq" QuantAwareTraining = "qat" + MixedPrecision = "mp" + +class MixedPrecisionConfig(QuantizationConfig): + quant_method = QuantizationMethod.MixedPrecision + def __init__( + self, + dtype = "bfloat16" + ): + self.dtype = dtype class SparsityConfig(PretrainedConfig): def __init__( @@ -237,6 +247,7 @@ def get_config_dict( pretrained_model_name_or_path, _configuration_file=SPARSITY_CONFIG, **kwargs ) + class ITREXQuantizationConfigMixin(QuantizationConfig): """Mixin class for quantization config.""" @@ -258,7 +269,9 @@ def update(self, **kwargs): to_remove.append(key) # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + unused_kwargs = { + key: value for key, value in kwargs.items() if key not in to_remove + } return unused_kwargs def post_init_cpu(self): @@ -290,7 +303,6 @@ def post_init_cpu(self): if self.bits == 4 and self.weight_dtype not in [ "int4_clip", "nf4", - "fp4_e2m1_bnb", "fp4_e2m1", ]: self.weight_dtype = "int4_clip" @@ -308,25 +320,24 @@ def post_init_cpu(self): "int8", "int4_clip", "nf4", - "fp4_e2m1_bnb", "fp4_e2m1", "fp8_e5m2", "fp8_e4m3", ]: raise ValueError( f"weight_dtype must be a string in " - f"'int8', 'int4', 'int4_clip', 'nf4', 'fp4', 'fp4_e2m1_bnb', 'fp4_e2m1', " + f"'int8', 'int4', 'int4_clip', 'nf4', 'fp4', 'fp4_e2m1', " f"'fp8', 'fp8_e5m2, fp8_e4m3'" ) if self.scale_dtype is not None and self.scale_dtype not in [ "fp32", "fp8_e8m0", - "bf16" + "bf16", ]: raise ValueError( - f"scale_dtype must be a string in 'fp32', 'fp8_e8m0', 'bf16' " - f"and fp8_e8m0 only used for weight_dtype 'fp8_e5m2', 'fp8_e4m3'" + "scale_dtype must be a string in 'fp32', 'fp8_e8m0', 'bf16' " + "and fp8_e8m0 only used for weight_dtype 'fp8_e5m2', 'fp8_e4m3'" ) elif self.scale_dtype is None: self.scale_dtype = "fp32" @@ -353,9 +364,9 @@ def post_init_cpu(self): or self.scale_dtype != "fp32" ): raise ValueError( - f"WeightOnlyQuantization doesn't support asym with " - f"compute_dtype int8 or weight_dtype float or scale_dtype non-fp32 now, " - f"please use sym scheme" + "WeightOnlyQuantization doesn't support asym with " + "compute_dtype int8 or weight_dtype float or scale_dtype non-fp32 now, " + "please use sym scheme" ) self.use_neural_speed = False @@ -384,10 +395,12 @@ def post_init_xpu(self): elif self.weight_dtype not in [ "int4_fullrange", ]: - raise ValueError(f"weight_dtype must be a string in 'int4_fullrange', but get {self.weight_dtype}.") + raise ValueError( + f"weight_dtype must be a string in 'int4_fullrange', but get {self.weight_dtype}." + ) if self.scale_dtype is not None and self.scale_dtype not in ["fp16"]: - raise ValueError(f"scale_dtype must be a string in 'fp16'") + raise ValueError("scale_dtype must be a string in 'fp16'") elif self.scale_dtype is None: self.scale_dtype = "fp16" @@ -420,7 +433,7 @@ def post_init_runtime(self): runtime_supported_weight_dtype = [ "int4", "int4_clip", # int4_clip will merge to int4 in next release. - "int4_fullrange", # int4_fullrange will merge to int4 in next release. 
+ "int4_fullrange", # int4_fullrange will merge to int4 in next release. "int8", "fp8", "fp8_e5m2", @@ -542,13 +555,48 @@ def to_json_file( writer.write(self.to_json_string(use_diff=use_diff)) def remove_redundant_parameters(self): - remove_parameters = ["calib_dataloader", "dataset", "calib_func", "calib_iters", "calib_len", - "double_quant_scale_dtype", "use_double_quant", "mse_range", "scheme", "tokenizer", "use_ggml", - "use_neural_speed", "use_quant", "layer_wise", "blocksize", "nsamples", "max_input_length", "static_groups", - "lr", "minmax_lr", "iters", "use_quant_input", "device", "calib_dataset", "calib_pad_val", "calib_shuffle", - "calib_padding", "example_inputs", "excluded_precisions", "op_name_dict", "op_type_dict", "train_dataloader", - "train_func", "train_iters", "train_len", "train_padding", "train_dataset", "train_pad_val", "train_shuffle", - "train_batch_size"] + remove_parameters = [ + "calib_dataloader", + "dataset", + "calib_func", + "calib_iters", + "calib_len", + "double_quant_scale_dtype", + "use_double_quant", + "mse_range", + "scheme", + "tokenizer", + "use_ggml", + "use_neural_speed", + "use_quant", + "layer_wise", + "blocksize", + "nsamples", + "max_input_length", + "static_groups", + "lr", + "minmax_lr", + "iters", + "use_quant_input", + "device", + "calib_dataset", + "calib_pad_val", + "calib_shuffle", + "calib_padding", + "example_inputs", + "excluded_precisions", + "op_name_dict", + "op_type_dict", + "train_dataloader", + "train_func", + "train_iters", + "train_len", + "train_padding", + "train_dataset", + "train_pad_val", + "train_shuffle", + "train_batch_size", + ] for parameter in remove_parameters: if hasattr(self, parameter): delattr(self, parameter) @@ -611,24 +659,25 @@ def get_config_dict( pretrained_model_name_or_path, _configuration_file=cf, **kwargs ) + class QuantAwareTrainingConfig(ITREXQuantizationConfigMixin): def __init__( - self, - backend="default", - tokenizer=None, - train_dataset="NeelNanda/pile-10k", - train_dataloader=None, - train_func=None, - train_shuffle=True, - train_iters=100, - train_padding=True, - train_batch_size=8, - train_len=512, - train_pad_val=1, - op_name_dict=None, - op_type_dict=None, - excluded_precisions=[], - **kwargs, + self, + backend="default", + tokenizer=None, + train_dataset="NeelNanda/pile-10k", + train_dataloader=None, + train_func=None, + train_shuffle=True, + train_iters=100, + train_padding=True, + train_batch_size=8, + train_len=512, + train_pad_val=1, + op_name_dict=None, + op_type_dict=None, + excluded_precisions=[], + **kwargs, ): self.quant_method = QuantizationMethod.QuantAwareTraining self.backend = backend @@ -649,35 +698,36 @@ def __init__( class DynamicQuantConfig(ITREXQuantizationConfigMixin): def __init__( - self, - excluded_precisions=[], - op_name_dict=None, - op_type_dict=None, - **kwargs, + self, + excluded_precisions=[], + op_name_dict=None, + op_type_dict=None, + **kwargs, ): self.quant_method = QuantizationMethod.DYNAMIC self.excluded_precisions = excluded_precisions self.op_name_dict = op_name_dict self.op_type_dict = op_type_dict + class StaticQuantConfig(ITREXQuantizationConfigMixin): def __init__( - self, - backend="default", - tokenizer=None, - calib_dataset="NeelNanda/pile-10k", - calib_dataloader=None, - calib_func=None, - calib_shuffle=True, - calib_iters=100, - calib_padding=False, - calib_len=512, - calib_pad_val=1, - op_name_dict=None, - op_type_dict=None, - excluded_precisions=[], - example_inputs=None, - **kwargs, + self, + backend="default", + tokenizer=None, + 
calib_dataset="NeelNanda/pile-10k", + calib_dataloader=None, + calib_func=None, + calib_shuffle=True, + calib_iters=100, + calib_padding=False, + calib_len=512, + calib_pad_val=1, + op_name_dict=None, + op_type_dict=None, + excluded_precisions=[], + example_inputs=None, + **kwargs, ): self.quant_method = QuantizationMethod.STATIC self.backend = backend @@ -695,93 +745,97 @@ def __init__( self.excluded_precisions = excluded_precisions self.example_inputs = example_inputs -class SmoothQuantConfig(StaticQuantConfig): + +class SmoothQuantConfig(ITREXQuantizationConfigMixin): def __init__( - self, - backend="ipex", - tokenizer=None, - calib_dataset="NeelNanda/pile-10k", - calib_dataloader=None, - calib_func=None, - calib_shuffle=True, - calib_iters=100, - calib_padding=False, - calib_len=512, - calib_pad_val=1, - op_name_dict=None, - op_type_dict=None, - excluded_precisions=[], - example_inputs=None, - ipex_opt_llm=None, - alpha=0.5, - num_beams=1, - recipes={"smooth_quant": True, "smooth_quant_args":{"alpha":0.5}}, - **kwargs, + self, + tokenizer=None, + dataset="NeelNanda/pile-10k", + alpha=0.5, + scale_sharing=False, + init_alpha=0.5, + alpha_min=0.0, + alpha_max=1.0, + alpha_step=0.1, + shared_criterion="max", + do_blockwise=False, + auto_alpha_args=None, + n_samples=100, + seq_len=512, + excluded_precisions=[], + ipex_opt_llm=None, + num_beams=1, + shuffle=False, + padding=False, + **kwargs, ): - super().__init__( - backend=backend, - tokenizer=tokenizer, - calib_dataset=calib_dataset, - calib_dataloader=calib_dataloader, - calib_func=calib_func, - calib_shuffle=calib_shuffle, - calib_iters=calib_iters, - calib_padding=calib_padding, - calib_len=calib_len, - calib_pad_val=calib_pad_val, - op_name_dict=op_name_dict, - op_type_dict=op_type_dict, - excluded_precisions=excluded_precisions, - example_inputs=example_inputs, - ) self.quant_method = QuantizationMethod.SmoothQuant - self.ipex_opt_llm = ipex_opt_llm + self.dataset = dataset + self.tokenizer = tokenizer self.alpha = alpha + self.scale_sharing = scale_sharing + self.init_alpha = init_alpha + self.alpha_min = alpha_min + self.alpha_max = alpha_max + self.alpha_step = alpha_step + self.shared_criterion = shared_criterion + self.do_blockwise = do_blockwise + self.auto_alpha_args = auto_alpha_args + self.n_samples = n_samples + self.seq_len = seq_len + self.ipex_opt_llm = ipex_opt_llm self.num_beams = num_beams - self.recipes = recipes + self.shuffle = shuffle + self.padding = padding + self.excluded_precisions = excluded_precisions + self.batch_size = kwargs.pop("batch_size", 1) + class RtnConfig(ITREXQuantizationConfigMixin): def __init__( self, bits: int = 4, group_size: int = 32, + group_dim: int = 1, compute_dtype: Any = None, weight_dtype: Any = None, scale_dtype: Any = None, + use_full_range: bool = False, mse_range: bool = False, - use_double_quant=False, - double_quant_scale_dtype=None, # reserve for double quant + use_double_quant: bool = False, + double_quant_dtype: str = "int", + double_quant_bits: int = 8, + double_quant_use_sym: bool = False, + double_quant_group_size: int = 256, sym: bool = True, layer_wise: bool = False, use_ggml: bool = False, use_quant: bool = True, use_neural_speed: bool = False, - llm_int8_skip_modules=None, **kwargs, ): self.quant_method = QuantizationMethod.RTN self.bits = bits + self.use_full_range = use_full_range self.mse_range = mse_range self.compute_dtype = compute_dtype self.weight_dtype = weight_dtype self.scale_dtype = scale_dtype self.group_size = group_size + self.group_dim = group_dim 
self.layer_wise = layer_wise self.sym = sym self.scheme = "sym" if self.sym else "asym" self.use_double_quant = use_double_quant - self.double_quant_scale_dtype = double_quant_scale_dtype - self.llm_int8_skip_modules = ( - llm_int8_skip_modules if llm_int8_skip_modules else [] - ) + self.double_quant_dtype = double_quant_dtype + self.double_quant_bits = double_quant_bits + self.double_quant_use_sym = double_quant_use_sym + self.double_quant_group_size = double_quant_group_size + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", []) self.use_ggml = use_ggml self.use_quant = use_quant self.use_neural_speed = use_neural_speed self.device = kwargs.get("device", "auto") - self.calib_dataloader = None - self.dataset = None - self.calib_func = None - self.calib_iters = None self.use_ipex = kwargs.pop("use_ipex", False) def to_diff_dict(self) -> Dict[str, Any]: @@ -811,6 +865,7 @@ def __init__( bits: int = 4, tokenizer: Any = None, dataset: str = "NeelNanda/pile-10k", + batch_size: int = 8, group_size: int = 32, compute_dtype: Any = None, weight_dtype: Any = None, @@ -821,15 +876,14 @@ def __init__( blocksize: int = 128, damp_percent: float = 0.1, desc_act: bool = False, - nsamples: int = 128, - max_input_length: Optional[int] = None, + n_samples: int = 128, + seq_len: int = 2048, static_groups: bool = False, true_sequential: bool = False, layer_wise: bool = False, use_ggml: bool = False, use_quant: bool = True, use_neural_speed: bool = False, - llm_int8_skip_modules=None, **kwargs, ): @@ -841,6 +895,7 @@ def __init__( self.bits = bits self.tokenizer = tokenizer self.dataset = dataset + self.batch_size = batch_size self.compute_dtype = compute_dtype self.weight_dtype = weight_dtype self.scale_dtype = scale_dtype @@ -848,24 +903,19 @@ def __init__( self.use_double_quant = use_double_quant self.double_quant_scale_dtype = double_quant_scale_dtype self.blocksize = blocksize - self.nsamples = nsamples + self.n_samples = n_samples self.group_size = group_size self.damp_percent = damp_percent self.desc_act = desc_act self.static_groups = static_groups self.true_sequential = true_sequential self.layer_wise = layer_wise - self.max_input_length = max_input_length - self.llm_int8_skip_modules = ( - llm_int8_skip_modules if llm_int8_skip_modules else [] - ) + self.seq_len = seq_len + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", []) self.use_ggml = use_ggml self.use_quant = use_quant self.use_neural_speed = use_neural_speed self.device = kwargs.get("device", "auto") - self.calib_dataloader = kwargs.get("calib_dataloader", None) - self.calib_func = kwargs.get("calib_func", None) - self.calib_iters = kwargs.get("calib_iters", 100) self.scheme = "sym" if self.sym else "asym" if isinstance(compute_dtype, torch.dtype): @@ -919,6 +969,7 @@ def to_diff_dict(self) -> Dict[str, Any]: return serializable_config_dict + class AwqConfig(ITREXQuantizationConfigMixin): def __init__( self, @@ -929,14 +980,17 @@ def __init__( compute_dtype: Any = None, weight_dtype: Any = None, scale_dtype: Any = None, + layer_wise: bool = False, + n_samples: int = 128, + seq_len: int = 2048, + auto_scale: bool = True, + auto_clip: bool = True, use_double_quant=False, double_quant_scale_dtype=None, # reserve for double quant zero_point: bool = True, - mse_range: bool = False, use_ggml: bool = False, use_quant: bool = True, use_neural_speed: bool = False, - llm_int8_skip_modules=None, **kwargs, ): self.quant_method = QuantizationMethod.AWQ @@ -948,21 +1002,21 @@ def __init__( self.scale_dtype = scale_dtype 
self.group_size = group_size self.zero_point = zero_point - self.mse_range = mse_range + self.auto_scale = auto_scale + self.auto_clip = auto_clip + self.layer_wise = layer_wise + self.n_samples = n_samples + self.seq_len = seq_len self.use_double_quant = use_double_quant self.double_quant_scale_dtype = double_quant_scale_dtype - self.llm_int8_skip_modules = ( - llm_int8_skip_modules if llm_int8_skip_modules else [] - ) + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", []) self.use_ggml = use_ggml self.use_quant = use_quant self.use_neural_speed = use_neural_speed self.device = kwargs.get("device", "auto") - self.calib_dataloader = kwargs.get("calib_dataloader", None) - self.calib_func = kwargs.get("calib_func", None) - self.calib_iters = kwargs.get("calib_iters", 100) self.scheme = "asym" if self.zero_point else "sym" self.sym = True if not self.zero_point else False + self.batch_size = kwargs.pop("batch_size", 8) self.use_ipex = kwargs.pop("use_ipex", False) def to_diff_dict(self) -> Dict[str, Any]: @@ -986,6 +1040,7 @@ def to_diff_dict(self) -> Dict[str, Any]: return serializable_config_dict + class TeqConfig(ITREXQuantizationConfigMixin): def __init__( self, @@ -996,12 +1051,15 @@ def __init__( compute_dtype: Any = None, weight_dtype: Any = None, scale_dtype: Any = None, + layer_wise: bool = False, + absorb_to_layer: dict = {}, + n_samples: int = 128, + seq_len: int = 2048, use_double_quant=False, double_quant_scale_dtype=None, # reserve for double quant sym: bool = True, use_ggml: bool = False, use_neural_speed: bool = False, - llm_int8_skip_modules=None, **kwargs, ): self.quant_method = QuantizationMethod.TEQ @@ -1012,19 +1070,19 @@ def __init__( self.weight_dtype = weight_dtype self.scale_dtype = scale_dtype self.group_size = group_size + self.absorb_to_layer = absorb_to_layer self.sym = sym self.scheme = "sym" if self.sym else "asym" + self.layer_wise = layer_wise + self.n_samples = n_samples + self.seq_len = seq_len self.use_double_quant = use_double_quant self.double_quant_scale_dtype = double_quant_scale_dtype - self.llm_int8_skip_modules = ( - llm_int8_skip_modules if llm_int8_skip_modules else [] - ) + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", []) self.use_ggml = use_ggml self.use_neural_speed = use_neural_speed self.device = kwargs.get("device", "auto") - self.calib_dataloader = kwargs.get("calib_dataloader", None) - self.calib_func = kwargs.get("calib_func", None) - self.calib_iters = kwargs.get("calib_iters", 100) + self.batch_size = kwargs.pop("batch_size", 8) self.use_ipex = kwargs.pop("use_ipex", False) def to_diff_dict(self) -> Dict[str, Any]: @@ -1048,6 +1106,7 @@ def to_diff_dict(self) -> Dict[str, Any]: return serializable_config_dict + class AutoRoundConfig(ITREXQuantizationConfigMixin): def __init__( self, @@ -1063,12 +1122,13 @@ def __init__( sym: bool = False, lr: float = None, minmax_lr: float = None, - disable_quanted_input: bool = False, - nsamples: int = 512, - iters: int = None, + disable_quanted_input: bool = True, + n_samples: int = 128, + seq_len: int = 2048, + iters: int = 200, + quant_lm_head: bool = False, use_ggml: bool = False, use_neural_speed: bool = False, - llm_int8_skip_modules=None, **kwargs, ): @@ -1086,20 +1146,19 @@ def __init__( self.sym = sym self.use_double_quant = use_double_quant self.double_quant_scale_dtype = double_quant_scale_dtype - self.nsamples = nsamples + self.n_samples = n_samples self.group_size = group_size self.lr = lr self.minmax_lr = minmax_lr self.disable_quanted_input = 
disable_quanted_input - self.llm_int8_skip_modules = ( - llm_int8_skip_modules if llm_int8_skip_modules else [] - ) + self.iters = iters + self.seq_len = seq_len + self.quant_lm_head = quant_lm_head + self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", []) self.use_ggml = use_ggml self.use_neural_speed = use_neural_speed + self.batch_size = kwargs.pop("batch_size", 8) self.device = kwargs.get("device", "auto") - self.calib_dataloader = kwargs.get("calib_dataloader", None) - self.calib_len = kwargs.get("calib_len", 2048) - self.calib_func = kwargs.get("calib_func", None) calib_iters = kwargs.get("calib_iters", None) if iters is not None: self.calib_iters = iters diff --git a/intel_extension_for_transformers/transformers/utils/utility.py b/intel_extension_for_transformers/transformers/utils/utility.py index 78fe5f2063d..527d8b097ff 100644 --- a/intel_extension_for_transformers/transformers/utils/utility.py +++ b/intel_extension_for_transformers/transformers/utils/utility.py @@ -18,7 +18,6 @@ import argparse import os -from typing import Optional, Tuple from neural_compressor.utils import logger from neural_compressor.utils.utility import LazyImport, CpuInfo from intel_extension_for_transformers.tools.utils import is_ipex_available @@ -96,407 +95,3 @@ def __init__(self) -> None: self.dataset = dataloader.dataset return INCDataLoader() - - -def generate_dummy_past_key_values(config, input_bs): - """Generate the dummy past_key_values.""" - from optimum.utils import NormalizedConfigManager - if config.model_type == "qwen": - new_shape = [ - input_bs, - 0, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "baichuan": - new_shape = [ - input_bs, - config.num_attention_heads, - 0, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "chatglm": - new_shape = [ - 0, - input_bs, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_layers - else: - normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.model_type - )(config) - nb_pkv = 2 - num_layers = normalized_config.num_layers - num_attention_heads = normalized_config.num_attention_heads - hidden_size = normalized_config.hidden_size - d_k = hidden_size // num_attention_heads - num_key_value_heads = num_attention_heads - if hasattr(normalized_config, "num_key_value_heads"): - num_key_value_heads = normalized_config.num_key_value_heads - if hasattr(normalized_config, "multi_query_group_num"): - num_key_value_heads = normalized_config.multi_query_group_num - - if config.model_type == "bloom": - shape_key = (input_bs * num_attention_heads, d_k, 1) - shape_value = (input_bs * num_attention_heads, 1, d_k) - key = torch.ones(size=shape_key) - value = torch.ones(size=shape_value) - past_key_values = tuple( - tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) - for _ in range(num_layers) - ) - return past_key_values - elif config.model_type == "gpt_bigcode": - new_shape = [input_bs, 0, d_k * 2] - dummy_tensor = torch.zeros(size=new_shape) - past_key_values = tuple([dummy_tensor] * num_layers) - return past_key_values - elif config.model_type == "falcon": - new_shape = [input_bs, 1, 0, d_k] - else: - new_shape = [input_bs, num_key_value_heads, 0, d_k] - past_key_values = [ - ( - torch.zeros(size=new_shape).contiguous(), - torch.zeros(size=new_shape).contiguous(), - ) - for _ in 
range(num_layers) - ] - return tuple(past_key_values) - -def generate_dummy_past_key_values_for_inference(config, input_bs): - """Generate the dummy past_key_values.""" - from optimum.utils import NormalizedConfigManager - if config.model_type == "qwen": - new_shape = [ - input_bs, - 0, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "baichuan": - new_shape = [ - input_bs, - config.num_attention_heads, - 0, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "chatglm": - new_shape = [ - 0, - input_bs, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_layers - else: - normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.model_type - )(config) - nb_pkv = 2 - num_layers = normalized_config.num_layers - num_attention_heads = normalized_config.num_attention_heads - hidden_size = normalized_config.hidden_size - d_k = hidden_size // num_attention_heads - num_key_value_heads = num_attention_heads - if hasattr(normalized_config, "num_key_value_heads"): - num_key_value_heads = normalized_config.num_key_value_heads - if hasattr(normalized_config, "multi_query_group_num"): - num_key_value_heads = normalized_config.multi_query_group_num - - if config.model_type == "bloom": - shape_key = (input_bs * num_attention_heads, d_k, 0) - shape_value = (input_bs * num_attention_heads, 0, d_k) - key = torch.empty(size=shape_key) - value = torch.empty(size=shape_value) - past_key_values = tuple( - tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) - for _ in range(num_layers) - ) - return past_key_values - elif config.model_type == "gpt_bigcode": - new_shape = [input_bs, 0, d_k * 2] - dummy_tensor = torch.zeros(size=new_shape) - past_key_values = tuple([dummy_tensor] * num_layers) - return past_key_values - elif config.model_type == "falcon": - new_shape = [input_bs, 1, 0, d_k] - else: - new_shape = [input_bs, num_key_value_heads, 0, d_k] - past_key_values = [ - ( - torch.zeros(size=new_shape).contiguous(), - torch.zeros(size=new_shape).contiguous(), - ) - for _ in range(num_layers) - ] - return tuple(past_key_values) - -def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1): - """Generate the dummy past_key_values.""" - from optimum.utils import NormalizedConfigManager - if config.model_type == "qwen": - new_shape = [ - input_bs, - 1, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "baichuan": - new_shape = [ - input_bs, - config.num_attention_heads, - 1, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_hidden_layers - elif config.model_type == "chatglm": - new_shape = [ - 1, - input_bs, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ] - num_layers = config.num_layers - else: - normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.model_type - )(config) - num_layers = normalized_config.num_layers - num_attention_heads = normalized_config.num_attention_heads - hidden_size = normalized_config.hidden_size - d_k = hidden_size // num_attention_heads - num_key_value_heads = num_attention_heads - nb_pkv = 2 - if hasattr(normalized_config, "num_key_value_heads"): - num_key_value_heads = 
normalized_config.num_key_value_heads - if hasattr(normalized_config, "multi_query_group_num"): - num_key_value_heads = normalized_config.multi_query_group_num - if config.model_type == "bloom": - for nb_pkv in range(nb_pkv): - if nb_pkv % 2 == 0: - new_shape = [input_bs * num_key_value_heads, d_k, 1] - else: - new_shape = [input_bs * num_key_value_heads, 1, d_k] - - else: - new_shape = [input_bs, num_key_value_heads, 1, d_k] - - beam_idx_tmp = torch.zeros( - (2048, int(input_bs * num_beams)), dtype=torch.long - ).contiguous() - past_key_values = [ - ( - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - torch.zeros(size=new_shape).contiguous(), - torch.zeros(size=new_shape).contiguous(), - beam_idx_tmp, - ) - for _ in range(num_layers) - ] - return tuple(past_key_values) - -IPEX_OPT_LLM_SUPPORTED_DICT = { - "2.2": ["gptj", "opt", "llama", "falcon", "chatglm", "baichuan", "gpt-neox"], - "2.3": [ - "gptj", - "opt", - "llama", - "falcon", - "chatglm", - "baichuan", - "qwen", - "bloom", - "codegen", - "gptbigcode", - "t5", - "mixtral", - "mpt", - ], -} - -MODEL_TYPES_REQUIRING_POSITION_IDS = { - "codegen", - "gpt2", - "gpt-bigcode", - "gpt-neo", - "gpt-neox", - "gptj", - "imagegpt", - "llama", - "mistral", - "chatglm", -} - -if is_ipex_available() and ipex.__version__ == "2.2.0+cpu": - logger.info( - "ipex.llm.optimize by 2.2.0 version supported model family: {}".format( - ",".join(IPEX_OPT_LLM_SUPPORTED_DICT["2.2"]) - ) - ) - logger.info( - "The recommended transformers version is 4.35.2 if you used IPEX 2.2.0 version." - ) - IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.2"] -elif is_ipex_available() and ipex.__version__ == "2.3.0+cpu": - logger.info( - "ipex.llm.optimize by 2.3.0 version supported model family: {}".format( - ", ".join(IPEX_OPT_LLM_SUPPORTED_DICT["2.3"]) - ) - ) - logger.info( - "The recommended transformers version is 4.38.1 if you used IPEX 2.3.0 version." - ) - IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.3"] -else: - logger.warning("Please check the intel_extension_for_pytorch version is 2.3.0+cpu.") - IPEX_OPT_LLM_SUPPORTED = IPEX_OPT_LLM_SUPPORTED_DICT["2.3"] - -def get_example_inputs(model_config, batch_size=1, tokenizer=None, num_beams=4): - """Generate the dummy example inputs.""" - prompt = "Welcome to use Intel Extension for Transformers." - prompt = [prompt] * batch_size - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - model_type = model_config.model_type.replace("_", "-") - if model_type in IPEX_OPT_LLM_SUPPORTED: - past_key_values = generate_dummy_past_key_values_for_opt_llm( - config=model_config, - input_bs=batch_size, - num_beams=num_beams - ) - else: - past_key_values = generate_dummy_past_key_values(config=model_config, input_bs=batch_size) - - input_ids = input_ids[:, :512] - attention_mask = torch.ones(input_ids.shape) - position_ids = torch.arange(input_ids.shape[1]).repeat(batch_size, 1) - - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: - example_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "past_key_values": past_key_values - } - else: - example_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values - } - return example_inputs - - -def make_torchscript_model(model, json_file_path, example_inputs): - """Recover ipex model from JSON file. - - Args: - model (object): fp32 model need to do quantization. - json_file_path (json): configuration JSON file for ipex. 
- example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex function. - - Returns: - (object): quantized model - """ - - ipex = LazyImport("intel_extension_for_pytorch") - from torch.ao.quantization.observer import MinMaxObserver - - if ipex.__version__ >= "2.1.100": - qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver) - else: - qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver()) - if isinstance(example_inputs, dict): - model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True) - else: - model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True) - model.load_qconf_summary(qconf_summary=json_file_path) - model = ipex.quantization.convert(model, inplace=True) - model.eval() - with torch.no_grad(): - try: - if isinstance(example_inputs, dict): - # pylint: disable=E1120,E1123 - model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) - else: - model = torch.jit.trace(model, example_inputs) - model = torch.jit.freeze(model.eval()) - except: - if isinstance(example_inputs, dict): - # pylint: disable=E1120,E1123 - model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) - else: - model = torch.jit.trace(model, example_inputs, strict=False) - model = torch.jit.freeze(model.eval()) - if isinstance(example_inputs, dict): - model(**example_inputs) - model(**example_inputs) - elif isinstance(example_inputs, tuple) or isinstance(example_inputs, list): - model(*example_inputs) - model(*example_inputs) - else: - model(example_inputs) - model(example_inputs) - return model - -def recover_model_from_json(fp32_model_name_or_path, json_file_path, trust_remote_code=False): - """Recover ipex model from JSON file. - - Args: - model (object): fp32 model need to do quantization. - json_file_path (json): configuration JSON file for ipex. - trust_remote_code (bool): trust remote code. - - Returns: - (object): quantized model - """ - from transformers import AutoModelForCausalLM - - # ipex recovered int8 model from configure.json requests float32 model input and on cpu device. 
- user_model = AutoModelForCausalLM.from_pretrained(fp32_model_name_or_path, - trust_remote_code=trust_remote_code).float() - if user_model.config.model_type in IPEX_OPT_LLM_SUPPORTED: - import intel_extension_for_pytorch as ipex - qconfig = ipex.quantization.default_static_qconfig_mapping - user_model = ipex.optimize_transformers( - user_model.eval(), - dtype=torch.float, - inplace=True, - quantization_config=qconfig, - deployment_mode=False, - ) - - # tokenizer - if user_model.config.model_type == "llama": - from transformers import LlamaTokenizer - tokenizer = LlamaTokenizer.from_pretrained(user_model.config.name_or_path) - else: - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained( - user_model.config.name_or_path, trust_remote_code=trust_remote_code - ) - - # example_inputs - example_inputs = get_example_inputs(user_model.config, tokenizer=tokenizer) - - # pylint: disable=E0611 - user_model.config.torchscript = True - config = user_model.config - user_model = make_torchscript_model(user_model, json_file_path, example_inputs) - import intel_extension_for_pytorch as ipex - from intel_extension_for_transformers.transformers.llm.evaluation.models import ( - TSModelCausalLMForITREX, - ) - origin_model_type = config.model_type - if origin_model_type in ["chatglm", "qwen", "baichuan"]: - config.model_type = "qwen2" - user_model = TSModelCausalLMForITREX(user_model, config=config) - user_model.config.model_type = origin_model_type - return user_model diff --git a/tests/CI/test_quantization.py b/tests/CI/test_quantization.py index 7f5911855cb..264d924efad 100644 --- a/tests/CI/test_quantization.py +++ b/tests/CI/test_quantization.py @@ -251,8 +251,7 @@ def test_quantization_for_llm(self): q_model.save_pretrained("./saved_results") output = q_model(dummy_input) self.assertTrue(isclose(float(output[0][0][0][0]), 0.17140813171863556, rel_tol=1e-04)) - q_model = AutoModelForCausalLM.from_pretrained("./saved_results" - ) + q_model = AutoModelForCausalLM.from_pretrained("./saved_results") output = q_model(dummy_input) self.assertTrue(isclose(float(output[0][0][0][0]), 0.17140813171863556, rel_tol=1e-04)) # Static quant @@ -290,32 +289,30 @@ def test_quantization_for_llm(self): # Smoothquant sq_config = SmoothQuantConfig( tokenizer=tokenizer, # either two of one, tokenizer or calib_func - calib_iters=2, + n_samples=2, ipex_opt_llm=False ) q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, quantization_config=sq_config, - use_neural_speed=False ) - self.assertTrue(isinstance(q_model.model, torch.jit.ScriptModule)) + self.assertTrue(isinstance(q_model, torch.jit.ScriptModule)) # Smoothquant auto - recipes = { - "smooth_quant": True, - "smooth_quant_args": { "alpha": "auto", "auto_alpha_args":{"alpha_max": 0.6, - "alpha_min":0.5, "alpha_step":0.1, "shared_criterion": "mean", "do_blockwise": False}}, - } sq_config = SmoothQuantConfig( tokenizer=tokenizer, # either two of one, tokenizer or calib_func - calib_iters=2, - recipes=recipes, + n_samples=2, + alpha="auto", + alpha_max=0.6, + alpha_min=0.5, + alpha_step=0.1, + shared_criterion="mean", + do_blockwise=False, ipex_opt_llm=False ) q_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, quantization_config=sq_config, - use_neural_speed=False ) - self.assertTrue(isinstance(q_model.model, torch.jit.ScriptModule)) + self.assertTrue(isinstance(q_model, torch.jit.ScriptModule)) # weight-only # RTN @@ -331,7 +328,9 @@ def test_quantization_for_llm(self): # AWQ woq_config = AwqConfig(bits=4, 
zero_point=False, - calib_iters=5, + n_samples=5, + batch_size=1, + seq_len=512, tokenizer=tokenizer ) @@ -341,19 +340,23 @@ def test_quantization_for_llm(self): ) woq_model.eval() output = woq_model(dummy_input) - self.assertTrue(isclose(float(output[0][0][0][0]), 0.18019595742225647 , rel_tol=1e-04)) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.20071472227573395 , rel_tol=1e-04)) + + # # TEQ + # need INC fix. + # woq_config = TeqConfig(bits=4, + # n_samples=5, + # batch_size=1, + # seq_len=512, + # tokenizer=tokenizer + # ) + # woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, + # quantization_config=woq_config, + # use_neural_speed=False + # ) + # woq_model.eval() + # output = woq_model(dummy_input) - # TEQ - woq_config = TeqConfig(bits=4, - calib_iters=5, - tokenizer=tokenizer, - ) - woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, - quantization_config=woq_config, - use_neural_speed=False - ) - woq_model.eval() - output = woq_model(dummy_input) # fp8 woq_config = RtnConfig(bits=8, weight_dtype="fp8_e5m2", scale_dtype="fp8_e8m0") @@ -383,7 +386,7 @@ def test_quantization_for_llm(self): ) bit4_model.eval() output = bit4_model(dummy_input) - self.assertTrue(isclose(float(output[0][0][0][0]), 0.18726778030395508, rel_tol=1e-04)) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.17631684243679047, rel_tol=1e-04)) # load_in_8bit bit8_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, @@ -402,9 +405,10 @@ def test_quantization_for_llm(self): desc_act=False, damp_percent=0.01, blocksize=32, - nsamples=3, - max_input_length=256, + n_samples=3, + seq_len=256, tokenizer=tokenizer, + batch_size=1 ) woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, quantization_config=woq_config, @@ -412,13 +416,13 @@ def test_quantization_for_llm(self): ) woq_model.eval() output = woq_model(dummy_input) - self.assertTrue(isclose(float(output[0][0][0][0]), 0.17126554250717163, rel_tol=1e-04)) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.1800851970911026, rel_tol=1e-04)) # AUTOROUND woq_config = AutoRoundConfig(bits=4, - weight_dtype="int4_clip", - nsamples=128, - calib_len=32, + weight_dtype="int4_clip", + n_samples=128, + seq_len=32, iters=5, tokenizer=tokenizer ) @@ -429,7 +433,7 @@ def test_quantization_for_llm(self): woq_model.eval() output = woq_model(dummy_input) if CpuInfo().bf16: - self.assertTrue(isclose(float(output[0][0][0][0]), 0.169921875, rel_tol=1e-04)) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.1513671875, rel_tol=1e-04)) def test_export(self): # test model with model_id diff --git a/tests/CI/test_weight_only.py b/tests/CI/test_weight_only.py index bca5f2ba169..eb73bb96e5e 100644 --- a/tests/CI/test_weight_only.py +++ b/tests/CI/test_weight_only.py @@ -41,6 +41,11 @@ from intel_extension_for_transformers.transformers.llm.utils.generation import _beam_search, _greedy_search from intel_extension_for_transformers.transformers import RtnConfig +import random +random.seed(1234) +torch.manual_seed(1234) +import numpy as np +np.random.seed(1234) class DummyDataset(data.Dataset): def __init__(self): @@ -124,29 +129,7 @@ def test_int8(self): output = model(activation) config = RtnConfig(bits=8, weight_dtype="int8", group_size=32) - convert_to_quantized_model(model, config) - output_quant = model(activation) - print(output) - print(output_quant) - assert torch.allclose(output, output_quant, rtol=0.01) - - def test_int4(self): - raw_wei = torch.rand(2, 32, dtype=torch.float) - compress_wei = 
qbits.quantize_to_packed_weight( - raw_wei, True, 32, "fp32", "nf4", "fp32", False) - revert_wei = torch.zeros(2, 32, dtype=torch.float) - qbits.dequantize_packed_weight(compress_wei, revert_wei, True, - "fp32", "nf4", "fp32") - for bias in [True, False]: - model = M(with_bias=bias) - with torch.no_grad(): - model.linear.weight = torch.nn.Parameter(revert_wei) - activation = torch.rand(1, 5, 32, dtype=torch.float) - output = model(activation) - with torch.no_grad(): - model.linear.weight = torch.nn.Parameter(raw_wei) - config = RtnConfig( - bits=4, weight_dtype="nf4", group_size=32) + config.post_init_cpu() convert_to_quantized_model(model, config) output_quant = model(activation) print(output) @@ -226,8 +209,9 @@ def test_auto_model_saving_loading(self): self.assertTrue(len(module_list) > 0) def test_nf4_training(self): + quantization_config = RtnConfig(bits=4, weight_dtype="nf4", scale_dtype="fp32") model = AutoModelForCausalLM.from_pretrained( - llama_model_path, load_in_4bit=True, use_neural_speed=False) + llama_model_path, quantization_config=quantization_config, use_neural_speed=False) peft_config = LoraConfig( r=8, lora_alpha=16, diff --git a/tests/CI/test_weight_only_gpu.py b/tests/CI/test_weight_only_gpu.py index a73715a222e..b34d30a6b83 100644 --- a/tests/CI/test_weight_only_gpu.py +++ b/tests/CI/test_weight_only_gpu.py @@ -23,7 +23,7 @@ from intel_extension_for_transformers.transformers import GPTQConfig, RtnConfig from math import isclose from transformers import AutoTokenizer -from intel_extension_for_transformers.tools.utils import is_intel_gpu_available, is_ipex_available +from intel_extension_for_transformers.tools.utils import is_intel_gpu_available, is_ipex_available, _ipex_version from torch.utils.data import DataLoader @@ -68,7 +68,7 @@ def forward(self, x): return self.linear(x) -@unittest.skipIf(not is_ipex_available() or not is_intel_gpu_available(), +@unittest.skipIf(not is_ipex_available() or not _ipex_version >= "2.3.10" or not is_intel_gpu_available(), "There is no Intel GPU in this machine, skip this test!") class TestArcWeightOnly(unittest.TestCase):
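Taken together with the config changes above, the old recipes/`calib_*` driven SmoothQuant flow is replaced by flat INC 3.x style options (`alpha`, `n_samples`, `seq_len`, `alpha_min`/`alpha_max`/`alpha_step`, `shared_criterion`, ...). A minimal end-to-end sketch of the new usage, inferred from the updated tests; the model id is a placeholder and `ipex_opt_llm=False` is assumed to keep the example small:

```python
import torch
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    SmoothQuantConfig,
)

model_name = "facebook/opt-125m"                 # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_name)

# INC 3.x style SmoothQuant: calibration and alpha tuning are plain constructor
# arguments instead of a nested recipes dict.
sq_config = SmoothQuantConfig(
    tokenizer=tokenizer,
    alpha="auto",            # or a fixed float such as 0.5
    alpha_min=0.5,
    alpha_max=0.6,
    alpha_step=0.1,
    shared_criterion="mean",
    n_samples=2,
    seq_len=512,
    ipex_opt_llm=False,
)

q_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=sq_config,
)

# As the updated test asserts, the quantized model is now a TorchScript module.
# A saved SmoothQuant directory (containing quantized_model.pt) is reloaded
# through the new "sq" branch of load_low_bit and wrapped in TSModelCausalLMForITREX.
assert isinstance(q_model, torch.jit.ScriptModule)
```

The sketch mirrors the updated unit test; real calibration would use more samples (the constructor default is `n_samples=100`).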