This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

improve text-generation gpu example
Signed-off-by: changwangss <[email protected]>
changwangss committed Jul 11, 2024
1 parent 79277b4 commit 745c3f2
Showing 1 changed file with 28 additions and 26 deletions.
@@ -28,22 +28,17 @@
"--num_beams", default=1, type=int, help="number of beams"
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--int8", action="store_true")
parser.add_argument(
"--int8_bf16_mixed",
action="store_true",
help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
)
parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--eval_batch_size", default=1, type=int,
help="batch size num.")
parser.add_argument("--do_profiling", action="store_true")
parser.add_argument("--profile_token_latency", action="store_true")
parser.add_argument("--iters", default=10, type=int, help="num iter")
parser.add_argument("--benchmark_iters", default=10, type=int, help="num iter")
parser.add_argument("--num_warmup", default=3, type=int, help="num warmup")
# ============Accuracy configs==============
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--batch_size", default=1, type=int,
parser.add_argument("--eval_batch_size", default=56, type=int,
help="batch size num.")
parser.add_argument("--save_accuracy_path", default=None,
help="Save accuracy results path.")
@@ -60,12 +55,12 @@
"int4_fullrange",
]
)
parser.add_argument("--batch_size", default=8, type=int,
help="calibration batch size num.")
parser.add_argument("--group_size", type=int, default=128)
parser.add_argument("--scheme", default="sym")
parser.add_argument("--woq_enable_mse_search", action="store_true")
parser.add_argument("--device", default="xpu")
parser.add_argument("--compute_dtype", default="fp16")
parser.add_argument("--calib_iters", default=200, type=int, help="Calibration iters.")
parser.add_argument("--load_in_4bit", type=bool, default=False)
parser.add_argument("--load_in_8bit", type=bool, default=False)
# ============GPTQ configs==============
@@ -87,10 +82,10 @@
help="Block size. sub weight matrix size to run GPTQ.",
)
parser.add_argument(
"--nsamples", type=int, default=512, help="Number of calibration data samples."
"--n_samples", type=int, default=512, help="Number of calibration data samples."
)
parser.add_argument(
"--max_input_length",
"--seq_len",
type=int,
default=2048,
help="Calibration dataset sequence max length, this should align with your model config",
@@ -102,7 +97,7 @@
)
# ============AutoRound==================
parser.add_argument(
"--calib_len",
"--autoround_iters",
default=2048,
type=int,
help="Calibration dataset max or padding max length for AutoRound.",
@@ -119,11 +114,17 @@
default=None,
help="minmax learning rate, if None,it will beset to be the same with lr",
)
parser.add_argument("--autoround_iters", default=200, type=int, help="num iters for autoround calibration.")
parser.add_argument(
"--use_quant_input",
"--disable_quanted_input",
action="store_true",
help="whether to use the output of quantized block to tune the next block",
)
parser.add_argument(
"--quant_lm_head",
action="store_true",
help="whether to quant the lm head layer",
)
# =======================================
args = parser.parse_args()
torch_dtype = convert_dtype_str2torch(args.compute_dtype)
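`convert_dtype_str2torch` is defined outside the visible hunks; a hypothetical minimal equivalent (the mapping below is an assumption, not the commit's code) might look like:

```python
import torch

def convert_dtype_str2torch(dtype_str: str) -> torch.dtype:
    # Map a CLI string such as --compute_dtype fp16 to the torch dtype.
    mapping = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    return mapping[dtype_str.lower()]
```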
@@ -155,14 +156,14 @@
damp_percent=args.damp_percent,
sym=True if args.scheme == "sym" else False,
blocksize=args.blocksize,
nsamples=args.nsamples,
n_samples=args.n_samples,
static_groups=args.static_groups,
group_size=args.group_size,
max_input_length=args.max_input_length,
seq_len=args.seq_len,
compute_dtype=args.compute_dtype,
scale_dtype=args.compute_dtype,
weight_dtype=args.weight_dtype,
calib_iters=args.calib_iters,
batch_size=args.batch_size,
)
elif args.woq_algo.lower() == "autoround":
quantization_config = AutoRoundConfig(
Expand All @@ -171,16 +172,17 @@
bits=args.bits,
sym=True if args.scheme == "sym" else False,
group_size=args.group_size,
max_input_length=args.max_input_length,
seq_len=args.seq_len,
compute_dtype=args.compute_dtype,
scale_dtype=args.compute_dtype,
weight_dtype=args.weight_dtype,
calib_iters=args.calib_iters,
calib_len=args.calib_len,
nsamples=args.nsamples,
iters=args.autoround_iters,
seq_len=args.seq_len,
n_samples=args.n_samples,
lr=args.lr,
minmax_lr=args.minmax_lr,
use_quant_input=args.use_quant_input,
disable_quanted_input=args.disable_quanted_input,
quant_lm_head=args.quant_lm_head,
)
elif args.woq_algo.lower() == "rtn":
quantization_config = RtnConfig(
Expand Down Expand Up @@ -237,9 +239,9 @@
else:
print("Disabled optimization with IPEX...")
# start
num_iter = args.iters
num_iter = args.benchmark_iters
num_warmup = args.num_warmup
prompt = [prompt] * args.batch_size
prompt = [prompt] * args.benchmark_batch_size
amp_enabled = True
amp_dtype = torch_dtype

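Pulling the renames together, here is a sketch of the two config constructions with the post-commit keyword names. It assumes `GPTQConfig` and `AutoRoundConfig` are importable from `intel_extension_for_transformers.transformers` and uses only kwargs and defaults visible in this diff; everything else is left at library defaults:

```python
# Sketch only; values are the argparse defaults shown above, not a recipe.
from intel_extension_for_transformers.transformers import AutoRoundConfig, GPTQConfig

gptq_config = GPTQConfig(
    sym=True,                     # --scheme sym
    n_samples=512,                # renamed from nsamples
    group_size=128,
    seq_len=2048,                 # renamed from max_input_length
    compute_dtype="fp16",
    scale_dtype="fp16",
    weight_dtype="int4_fullrange",
    batch_size=8,                 # calibration batch size (replaces calib_iters)
)

autoround_config = AutoRoundConfig(
    bits=4,                       # illustrative; the --bits default is not shown here
    sym=True,
    group_size=128,
    compute_dtype="fp16",
    scale_dtype="fp16",
    weight_dtype="int4_fullrange",
    iters=200,                    # fed from --autoround_iters
    seq_len=2048,
    n_samples=512,
    disable_quanted_input=False,  # replaces use_quant_input, with inverted sense
    quant_lm_head=False,          # newly exposed option
)
```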
Expand Down Expand Up @@ -336,7 +338,7 @@
user_model = user_model,
tasks = args.tasks,
device = args.device,
batch_size = args.batch_size)
batch_size = args.eval_batch_size)
results = evaluate(args)
for task_name in args.tasks.split(","):
if task_name == "wikitext":
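Finally, the accuracy path now feeds `--eval_batch_size` into the evaluation call shown above. A hedged sketch of that wiring; the `LMEvalParser`/`evaluate` import path is an assumption based on sibling ITREX examples and is not shown in this diff:

```python
def run_accuracy(user_model, tasks: str, device: str = "xpu", batch_size: int = 56):
    # Import path assumed; the diff only shows the call site.
    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import (
        LMEvalParser,
        evaluate,
    )

    eval_args = LMEvalParser(
        user_model=user_model,
        tasks=tasks,            # e.g. "lambada_openai" or "wikitext"
        device=device,
        batch_size=batch_size,  # now taken from --eval_batch_size
    )
    return evaluate(eval_args)
```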