Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
Update run_generation_gpu_woq.py (#1454)
Browse files Browse the repository at this point in the history
  • Loading branch information
a32543254 authored Apr 3, 2024
1 parent d6e6e9f commit 02a6984
Showing 1 changed file with 20 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@
type=int,
help="Calibration dataset max or padding max length for AutoRound.",
)
# AutoRound tuning hyper-parameters.
# NOTE(review): the help text for --lr/--minmax_lr describes a "None" fallback,
# but the argparse default is 0.0025, so the None path is never taken from the
# CLI — confirm whether default=None was intended.
parser.add_argument(
    "--lr",
    type=float,
    default=0.0025,
    help="Learning rate; if None, it will be set to 1.0/iters automatically.",
)
parser.add_argument(
    "--minmax_lr",
    type=float,
    default=0.0025,
    help="Min-max learning rate; if None, it will be set to the same value as --lr.",
)
parser.add_argument(
    "--use_quant_input",
    action="store_true",
    help="Whether to use the output of the quantized block to tune the next block.",
)
# =======================================
# Parse the CLI arguments defined above into a namespace.
args = parser.parse_args()
# Map the user-supplied compute-dtype string to a torch dtype via the project
# helper; presumably accepts strings like "fp16"/"fp32" — TODO confirm against
# convert_dtype_str2torch's definition.
torch_dtype = convert_dtype_str2torch(args.compute_dtype)
Expand Down Expand Up @@ -162,6 +179,9 @@
calib_iters=args.calib_iters,
calib_len=args.calib_len,
nsamples=args.nsamples,
lr=args.lr,
minmax_lr=args.minmax_lr,
use_quant_input=args.use_quant_input,
)
elif args.woq_algo.lower() == "rtn":
quantization_config = RtnConfig(
Expand Down

0 comments on commit 02a6984

Please sign in to comment.