diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
index d54dd5f127f..f41f05df7c2 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
@@ -28,22 +28,17 @@
     "--num_beams", default=1, type=int, help="number of beams"
 )
 parser.add_argument("--output_dir", nargs="?", default="./saved_results")
-parser.add_argument("--int8", action="store_true")
-parser.add_argument(
-    "--int8_bf16_mixed",
-    action="store_true",
-    help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
-)
-parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
 # ============Benchmark configs==============
 parser.add_argument("--benchmark", action="store_true")
+parser.add_argument("--benchmark_batch_size", default=1, type=int,
+                    help="batch size num.")
 parser.add_argument("--do_profiling", action="store_true")
 parser.add_argument("--profile_token_latency", action="store_true")
-parser.add_argument("--iters", default=10, type=int, help="num iter")
+parser.add_argument("--benchmark_iters", default=10, type=int, help="num iter")
 parser.add_argument("--num_warmup", default=3, type=int, help="num warmup")
 # ============Accuracy configs==============
 parser.add_argument("--accuracy", action="store_true")
-parser.add_argument("--batch_size", default=1, type=int,
+parser.add_argument("--eval_batch_size", default=56, type=int,
                     help="batch size num.")
 parser.add_argument("--save_accuracy_path", default=None,
                     help="Save accuracy results path.")
@@ -60,12 +55,12 @@
         "int4_fullrange",
     ]
 )
+parser.add_argument("--batch_size", default=8, type=int,
+                    help="calibration batch size num.")
 parser.add_argument("--group_size", type=int, default=128)
 parser.add_argument("--scheme", default="sym")
-parser.add_argument("--woq_enable_mse_search", action="store_true")
 parser.add_argument("--device", default="xpu")
 parser.add_argument("--compute_dtype", default="fp16")
-parser.add_argument("--calib_iters", default=200, type=int, help="Calibration iters.")
 parser.add_argument("--load_in_4bit", type=bool, default=False)
 parser.add_argument("--load_in_8bit", type=bool, default=False)
 # ============GPTQ configs==============
@@ -87,10 +82,10 @@
     help="Block size. sub weight matrix size to run GPTQ.",
 )
 parser.add_argument(
-    "--nsamples", type=int, default=512, help="Number of calibration data samples."
+    "--n_samples", type=int, default=512, help="Number of calibration data samples."
 )
 parser.add_argument(
-    "--max_input_length",
+    "--seq_len",
     type=int,
     default=2048,
     help="Calibration dataset sequence max length, this should align with your model config",
@@ -102,8 +97,2 @@
 )
 # ============AutoRound==================
-parser.add_argument(
-    "--calib_len",
-    default=2048,
-    type=int,
-    help="Calibration dataset max or padding max length for AutoRound.",
-)
@@ -119,11 +108,17 @@
     default=None,
     help="minmax learning rate, if None,it will beset to be the same with lr",
 )
+parser.add_argument("--autoround_iters", default=200, type=int, help="num iters for autoround calibration.")
 parser.add_argument(
-    "--use_quant_input",
+    "--disable_quanted_input",
     action="store_true",
     help="whether to use the output of quantized block to tune the next block",
 )
+parser.add_argument(
+    "--quant_lm_head",
+    action="store_true",
+    help="whether to quant the lm head layer",
+)
 # =======================================
 args = parser.parse_args()
 torch_dtype = convert_dtype_str2torch(args.compute_dtype)
@@ -155,14 +150,14 @@
         damp_percent=args.damp_percent,
         sym=True if args.scheme == "sym" else False,
         blocksize=args.blocksize,
-        nsamples=args.nsamples,
+        n_samples=args.n_samples,
         static_groups=args.static_groups,
         group_size=args.group_size,
-        max_input_length=args.max_input_length,
+        seq_len=args.seq_len,
         compute_dtype=args.compute_dtype,
         scale_dtype=args.compute_dtype,
         weight_dtype=args.weight_dtype,
-        calib_iters=args.calib_iters,
+        batch_size=args.batch_size,
     )
 elif args.woq_algo.lower() == "autoround":
     quantization_config = AutoRoundConfig(
@@ -171,16 +166,16 @@
         bits=args.bits,
         sym=True if args.scheme == "sym" else False,
         group_size=args.group_size,
-        max_input_length=args.max_input_length,
+        seq_len=args.seq_len,
         compute_dtype=args.compute_dtype,
         scale_dtype=args.compute_dtype,
         weight_dtype=args.weight_dtype,
-        calib_iters=args.calib_iters,
-        calib_len=args.calib_len,
-        nsamples=args.nsamples,
+        iters=args.autoround_iters,
+        n_samples=args.n_samples,
         lr=args.lr,
         minmax_lr=args.minmax_lr,
-        use_quant_input=args.use_quant_input,
+        disable_quanted_input=args.disable_quanted_input,
+        quant_lm_head=args.quant_lm_head,
     )
 elif args.woq_algo.lower() == "rtn":
     quantization_config = RtnConfig(
@@ -237,9 +232,9 @@
     else:
         print("Disabled optimization with IPEX...")
     # start
-    num_iter = args.iters
+    num_iter = args.benchmark_iters
     num_warmup = args.num_warmup
-    prompt = [prompt] * args.batch_size
+    prompt = [prompt] * args.benchmark_batch_size
     amp_enabled = True
     amp_dtype = torch_dtype
 
@@ -336,7 +331,7 @@
                         user_model = user_model,
                         tasks = args.tasks,
                         device = args.device,
-                        batch_size = args.batch_size)
+                        batch_size = args.eval_batch_size)
     results = evaluate(args)
     for task_name in args.tasks.split(","):
         if task_name == "wikitext":
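The renames above break old command lines (`--iters`, `--nsamples`, `--max_input_length`, etc. no longer exist, and `--batch_size` now means the calibration batch size). The snippet below is a minimal, hypothetical migration sketch and is not part of this diff: the `FLAG_RENAMES` table and `migrate_argv` helper are illustrative names, and flags whose semantics changed (such as `--use_quant_input`, replaced by the inverted `--disable_quanted_input`) are only flagged for manual review rather than rewritten.

```python
# Hypothetical helper (not part of this PR): rewrite an old-style command line for
# run_generation_gpu_woq.py using the flag renames introduced by the diff above.
import sys

# Straight 1:1 renames taken from the diff.
FLAG_RENAMES = {
    "--iters": "--benchmark_iters",
    "--nsamples": "--n_samples",
    "--max_input_length": "--seq_len",
}

# Flags that were removed or whose meaning changed; these need manual attention.
# --batch_size is now the calibration batch size; use --benchmark_batch_size or
# --eval_batch_size for benchmarking and accuracy runs respectively.
NEEDS_REVIEW = {
    "--use_quant_input",        # replaced by --disable_quanted_input (inverted meaning)
    "--calib_len",              # dropped; --seq_len now covers calibration length
    "--calib_iters",            # dropped; AutoRound uses --autoround_iters
    "--woq_enable_mse_search",  # removed
    "--batch_size",
}

def migrate_argv(argv):
    """Return argv with renamed flags substituted; warn about dropped/changed ones."""
    out = []
    for token in argv:
        if token in FLAG_RENAMES:
            out.append(FLAG_RENAMES[token])
        elif token in NEEDS_REVIEW:
            print(f"review manually: {token}", file=sys.stderr)
            out.append(token)
        else:
            out.append(token)
    return out

if __name__ == "__main__":
    old = ["--benchmark", "--iters", "20", "--nsamples", "256", "--max_input_length", "1024"]
    print(" ".join(migrate_argv(old)))
```

Running the sketch on the example argv prints `--benchmark --benchmark_iters 20 --n_samples 256 --seq_len 1024`, i.e. a command line that matches the renamed arguments in this diff.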