From 02a69843005e8b530b3eb1d04912759b178c8f90 Mon Sep 17 00:00:00 2001 From: "Dong, Bo" Date: Wed, 3 Apr 2024 13:59:15 +0800 Subject: [PATCH] Update run_generation_gpu_woq.py (#1454) --- .../quantization/run_generation_gpu_woq.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py index 8fe65721bc1..09ef3869e19 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py @@ -103,6 +103,23 @@ type=int, help="Calibration dataset max or padding max length for AutoRound.", ) +parser.add_argument( + "--lr", + type=float, + default=0.0025, + help="learning rate, if None, it will be set to 1.0/iters automatically", +) +parser.add_argument( + "--minmax_lr", + type=float, + default=0.0025, + help="minmax learning rate, if None, it will be set to be the same as lr", +) +parser.add_argument( + "--use_quant_input", + action="store_true", + help="whether to use the output of quantized block to tune the next block", +) # ======================================= args = parser.parse_args() torch_dtype = convert_dtype_str2torch(args.compute_dtype) @@ -162,6 +179,9 @@ calib_iters=args.calib_iters, calib_len=args.calib_len, nsamples=args.nsamples, + lr=args.lr, + minmax_lr=args.minmax_lr, + use_quant_input=args.use_quant_input, ) elif args.woq_algo.lower() == "rtn": quantization_config = RtnConfig(