Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

h2o for kv cache compression #1468

Merged
merged 90 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
41d8647
h2o for kv cache compression
n1ck-guo Apr 10, 2024
eb7f564
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
c46ea7d
rebuild
BiaoFangAIA Apr 23, 2024
95ff9ae
merge
BiaoFangAIA Apr 23, 2024
9d27733
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
444490d
update
n1ck-guo Apr 25, 2024
4309089
update
n1ck-guo Apr 25, 2024
8c5272e
merge
n1ck-guo Apr 25, 2024
1b83e52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2024
3fd73cb
update
n1ck-guo May 7, 2024
a2d3ae0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
ddf5445
Merge branch 'main' into hengguo/h2o
VincyZhang May 13, 2024
91d4394
Merge branch 'main' into hengguo/h2o
n1ck-guo May 14, 2024
a83e6d6
real drop
n1ck-guo May 14, 2024
92c8a62
modify real drop code
n1ck-guo May 15, 2024
70a1cf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2024
bc9eade
fix
BiaoFangAIA May 16, 2024
9aa25f6
update for real drop and sim mode, using the same api
n1ck-guo May 16, 2024
03cdc8d
Merge branch 'main' into hengguo/h2o
n1ck-guo May 16, 2024
e51b5b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2024
b8e9df2
support for sdpa and flash attention
n1ck-guo May 16, 2024
02f31b2
Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio…
n1ck-guo May 16, 2024
274b7ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2024
877329d
change to new api
n1ck-guo May 17, 2024
b435e3e
Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio…
n1ck-guo May 17, 2024
5068552
clean code
n1ck-guo May 17, 2024
5e5f589
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2024
d0dce7d
fix
n1ck-guo May 20, 2024
febb76a
Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio…
n1ck-guo May 20, 2024
24c4725
add example
n1ck-guo May 20, 2024
5bd3f16
Merge branch 'main' into hengguo/h2o
n1ck-guo May 20, 2024
955e132
clean
n1ck-guo May 20, 2024
91efe57
pylint
n1ck-guo May 20, 2024
4190edb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2024
e71bf92
pylint
n1ck-guo May 20, 2024
41f016c
pylint
n1ck-guo May 20, 2024
d49487f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2024
9ac5eca
fix import error
n1ck-guo May 21, 2024
09def0b
update
n1ck-guo May 21, 2024
5cae1fd
Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio…
n1ck-guo May 21, 2024
3042dd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2024
3a992ab
pylint
n1ck-guo May 21, 2024
8c89cbc
pylint
n1ck-guo May 21, 2024
072ad76
merge
n1ck-guo May 21, 2024
4c26487
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2024
741c7cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2024
0800df2
add example readme
n1ck-guo May 27, 2024
558dfd9
update
n1ck-guo May 27, 2024
d6de2b3
merge
n1ck-guo May 27, 2024
693983f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
30bed25
fix acc bug
n1ck-guo Jun 6, 2024
f8a64fc
fix
n1ck-guo Jun 7, 2024
d892e74
Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio…
n1ck-guo Jun 11, 2024
7a12ec6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2024
124bb72
refactor code
n1ck-guo Jun 18, 2024
5181afc
fix
n1ck-guo Jun 18, 2024
93ad39b
merge
n1ck-guo Jun 18, 2024
58bfcd0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2024
76656c9
Merge branch 'main' into hengguo/h2o
n1ck-guo Jun 18, 2024
91c5f3c
new api
n1ck-guo Jun 20, 2024
4884b3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 20, 2024
812a838
support for gaudi
n1ck-guo Jun 24, 2024
2cc6a8f
merge
n1ck-guo Jun 24, 2024
2d82bb5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2024
a9488b4
update
n1ck-guo Jun 25, 2024
3c185ad
Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio…
n1ck-guo Jun 25, 2024
a6d3fc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 25, 2024
9698230
pylint
n1ck-guo Jun 26, 2024
14f5a6d
pylint
n1ck-guo Jun 27, 2024
dd6ee3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
523ca76
Merge branch 'main' into hengguo/h2o
n1ck-guo Jun 28, 2024
2618e6f
Merge branch 'main' into hengguo/h2o
n1ck-guo Jul 2, 2024
cedcd43
Merge branch 'main' into hengguo/h2o
changwangss Jul 3, 2024
eb8441c
Merge branch 'main' into hengguo/h2o
n1ck-guo Jul 12, 2024
0894b6d
Merge branch 'main' into hengguo/h2o
n1ck-guo Jul 15, 2024
d241c25
add desc to h2o in readme
n1ck-guo Jul 15, 2024
0c547c5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
3723158
add doc for h2o
n1ck-guo Jul 15, 2024
7da0cf5
Merge branch 'hengguo/h2o' of https://github.com/intel/intel-extensio…
n1ck-guo Jul 15, 2024
d600112
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
b1ab771
update
n1ck-guo Jul 16, 2024
4bee0f0
add ut
n1ck-guo Jul 16, 2024
3fad641
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2024
46fa4d7
Merge branch 'main' into hengguo/h2o
XuehaoSun Jul 16, 2024
fda316d
fix
n1ck-guo Jul 17, 2024
268fb80
Merge branch 'main' into hengguo/h2o
n1ck-guo Jul 22, 2024
b7a1bd1
update
n1ck-guo Jul 22, 2024
2ae387a
merge
n1ck-guo Jul 22, 2024
2706d5d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2024
b2d5768
Merge branch 'main' into hengguo/h2o
n1ck-guo Jul 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ omit =
*/intel_extension_for_transformers/langchain/**
*/intel_extension_for_transformers/llama_index/**
*/intel_extension_for_transformers/transformers/utils/get_throughput.py
*/intel_extension_for_transformers/transformers/kv_cache_compression/**
exclude_lines =
pragma: no cover
raise NotImplementedError
Expand Down
49 changes: 49 additions & 0 deletions docs/h2o.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models
1. [Introduction](#introduction)
2. [Usage](#usage)

## Introduction
**Heavy-Hitter Oracal (H2O)** is a novel approach for implementing the KV cache which significantly reduces memory footprint.

This methods base on the fact that the accumulated attention scores of all tokens in attention blocks adhere to a power-law distribution. It suggests that there exists a small set of influential tokens that are critical during generation, named heavy-hitters (H2). H2 provides an opportunity to step away from the combinatorial search problem and identify an eviction policy that maintains accuracy.

H2O can dynamically retains the balance of recent and H2 tokens. Significantly increase model throughput while ensuring accuracy.


For more info, please refer to the paper [H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models](https://arxiv.org/pdf/2306.14048).


![](./imgs/h2o.png)


## Usage
Using simulation mode
```python
from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM
h2o_config = H2OConfig(
heavy_ratio=heavy_ratio,
recent_ratio=recent_ratio,
h2o_min_seqlen=h2o_min_seqlen,
real_drop=False,
)
user_model = LlamaForCausalLM.from_pretrained(
args.model,
prune_config=h2o_config,
trust_remote_code=args.trust_remote_code)
```
To run the real_drop mode
```python
from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM
h2o_config = H2OConfig(
heavy_ratio=heavy_ratio,
recent_ratio=recent_ratio,
h2o_min_seqlen=h2o_min_seqlen,
real_drop=True,
)
user_model = LlamaForCausalLM.from_pretrained(
args.model,
prune_config=h2o_config,
trust_remote_code=args.trust_remote_code)
```

Please refer to [h2o example](../examples/huggingface/pytorch/text-generation/h2o/run_generation.py) for the details.
Binary file added docs/imgs/h2o.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 47 additions & 0 deletions examples/huggingface/pytorch/text-generation/h2o/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models

**Heavy-Hitter Oracal (H2O)** is a novel approach for implementing the KV cache which significantly reduces memory footprint.

This methods base on the fact that the accumulated attention scores of all tokens in attention blocks adhere to a power-law distribution. It suggests that there exists a small set of influential tokens that are critical during generation, named heavy-hitters (H2). H2 provides an opportunity to step away from the combinatorial search problem and identify an eviction policy that maintains accuracy.

H2O can dynamically retains the balance of recent and H2 tokens. Significantly increase model throughput while ensuring accuracy.


For more info, please refer to the paper [H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models](https://arxiv.org/pdf/2306.14048).


![](./imgs/1.png)


## Usage and Examples
### Evaluation on tasks from [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) framework
Using simulation mode
```bash
python run_generation.py \
--model meta-llama/Meta-Llama-3-8B \
--accuracy \
--batch_size 16 \
--h2o \
--heavy_ratio 0.1 \
--recent_ratio 0.1 \
--device 0
```
To run the real_drop mode
```bash
python run_generation.py \
--model meta-llama/Meta-Llama-3-8B \
--accuracy \
--batch_size 16 \
--h2o \
--heavy_ratio 0.1 \
--recent_ratio 0.1 \
--device 0
--real_drop
```
Get the accuracy of dense model
```bash
python run_generation.py \
--model meta-llama/Meta-Llama-3-8B \
--accuracy \
--batch_size 16
```
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
238 changes: 238 additions & 0 deletions examples/huggingface/pytorch/text-generation/h2o/run_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import argparse
import sys
import time
import json
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.utils import check_min_version

parser = argparse.ArgumentParser()
parser.add_argument("--model", default=None)
parser.add_argument(
"--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k"
)
parser.add_argument(
"--max_new_tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--int8", action="store_true")
parser.add_argument(
"--int8_bf16_mixed",
action="store_true",
help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
)
parser.add_argument(
"--restore",
action="store_true",
help="restore ipex quantized model from output_dir/best_configure.json",
)
parser.add_argument(
"--peft_model_id", type=str, default=None, help="model_name_or_path of peft model"
)
parser.add_argument("--_commit_hash", default=None, type=str)
parser.add_argument("--trust_remote_code", action="store_true")
parser.add_argument("--use_neural_speed", action="store_true")
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--iters", default=100, type=int, help="num iter")
parser.add_argument("--num_warmup", default=10, type=int, help="num warmup")
# ============Accuracy configs==============
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--batch_size", default=16, type=int, help="batch size num.")
parser.add_argument(
"--save_accuracy_path", default=None, help="Save accuracy results path."
)
parser.add_argument("--output_excel", default=None, type=str)
parser.add_argument("--eval_bs", default=4, type=int,
help="eval batch size")
parser.add_argument("--tasks", nargs='+', default=["winogrande", "copa", "piqa", "rte", "hellaswag", \
"openbookqa", "lambada_openai", "lambada_standard", "wikitext"], type=str, \
help="tasks list for accuracy validation")
parser.add_argument("--num_fewshot", default=0, type=int, help="num few shot.")
# ============MixedPrecision configs==============
parser.add_argument("--mixed_precision", action="store_true")

# ============h2o configs==============
parser.add_argument('--h2o', action='store_true')
parser.add_argument('--is_gen', action='store_true')
parser.add_argument('--real_drop', action='store_true')
parser.add_argument("--heavy_ratio", type=float, default=0.1)
parser.add_argument("--recent_ratio", type=float, default=0.1)
parser.add_argument("--device", type=str, default='cpu')
parser.add_argument("--h2o_min_seqlen", type=int, default=0)

args = parser.parse_args()
# transformers version >= 4.32.0 contained the mpt modeling definition.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
# 4.31.0 for ipex.optimize_transformers
# get model config
if args.peft_model_id:
from peft import PeftConfig

peft_config = PeftConfig.from_pretrained(args.peft_model_id)
if args.model is None:
args.model = peft_config.base_model_name_or_path
print("we will use peft base_model_name_or_path to get tokenizer.")

config = AutoConfig.from_pretrained(
args.model,
torchscript=False,
use_cache=True, # to use kv cache.
trust_remote_code=args.trust_remote_code,
_commit_hash=args._commit_hash,
)

# chatglm
if config.model_type == "chatglm":
AutoModelForCausalLM = AutoModel
# tokenizer
if config.model_type == "llama":
from transformers import LlamaTokenizer

# tokenizer = LlamaTokenizer.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model)
else:
tokenizer = AutoTokenizer.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code
)

# use peft
args.model = args.peft_model_id if args.peft_model_id is not None else args.model

# Generation
if args.use_neural_speed:
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=1)
else:
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)

if 'cpu' in args.device:
device = args.device
else:
device = f"cuda:{args.device}"

# get optimized model
if args.h2o:
print('Enable Small Cache Size')
from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM
h2o_config = H2OConfig(
heavy_ratio=args.heavy_ratio,
recent_ratio=args.recent_ratio,
h2o_min_seqlen=args.h2o_min_seqlen,
real_drop=args.real_drop,
mean=False,
)
user_model = LlamaForCausalLM.from_pretrained(
args.model,
prune_config=h2o_config,
trust_remote_code=args.trust_remote_code)
print("converted model: ", user_model)
else:
user_model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
user_model.to(device)

# save model
# if args.output_dir is not None:
# tokenizer.save_pretrained(args.output_dir)
# user_model.save_pretrained(args.output_dir)

if args.benchmark:
user_model = (
user_model.eval() if (not (args.int8 or args.int8_bf16_mixed) and hasattr(user_model, "eval")) else user_model
)
prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun."
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)

# start
total_time = 0.0
num_iter = args.iters
num_warmup = args.num_warmup
total_token_num = 0
eos_token_id = tokenizer.eos_token_id
with torch.inference_mode(), torch.no_grad():
for i in range(num_iter):
tic = time.time()
if hasattr(tokenizer, "build_chat_input"):
input_ids = tokenizer.build_chat_input(prompt)["input_ids"]
input_ids = input_ids.repeat(args.batch_size, 1)
eos_token_id = [
tokenizer.eos_token_id,
tokenizer.get_command("<|user|>"),
tokenizer.get_command("<|observation|>"),
]
elif hasattr(tokenizer, "build_prompt"):
build_prompt = tokenizer.build_prompt(prompt)
input_ids = tokenizer(
[build_prompt] * args.batch_size, return_tensors="pt"
).input_ids
else:
input_ids = tokenizer(
[prompt] * args.batch_size, return_tensors="pt"
).input_ids
gen_ids = user_model.generate(
input_ids,
max_new_tokens=args.max_new_tokens,
**generate_kwargs,
eos_token_id=eos_token_id
)
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
toc = time.time()
# please check the gen_ids if include input_ids.
input_tokens_num = input_ids.numel()
output_tokens_num = torch.tensor(gen_ids).numel() - input_tokens_num
print(gen_text, flush=True)
if i >= num_warmup:
total_time += toc - tic
total_token_num += output_tokens_num

print("\n", "-" * 10, "Summary:", "-" * 10)
latency = total_time / total_token_num
print("Inference latency: %.3f sec." % latency)
throughput = total_token_num / total_time
print("Throughput: {} samples/sec".format(throughput))

if args.accuracy:
user_model = (user_model.eval() if (not (args.int8 or args.int8_bf16_mixed) and hasattr(user_model, "eval")) \
else user_model)
# from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
# model_args="pretrained="+args.model+",trust_remote_code="+str(args.trust_remote_code)
# args.tasks = ",".join(args.tasks)
# tokenizer.pad_token = tokenizer.eos_token
# eval_args = LMEvalParser(model = "hf",
# user_model=user_model,
# tokenizer=tokenizer,
# model_args=model_args,
# tasks = args.tasks,
# device = device,
# num_fewshot=args.num_fewshot,
# output_path=args.save_accuracy_path,
# batch_size = args.batch_size)
# print("using device:", device)
# results = evaluate(eval_args)


# original lm_eval
from lm_eval.evaluator import simple_evaluate
from lm_eval.tasks import TaskManager
import lm_eval

verbosity = 'INFO'
task_manager = TaskManager(verbosity)
limit = None
cache_requests = False
lm = lm_eval.api.registry.get_model("hf")(
pretrained=user_model,
batch_size=args.batch_size,
max_batch_size=None,
)
model_args="pretrained="+ args.model+ ",tokenizer="+ args.model + ",dtype=float32"
use_cache = None
results = simple_evaluate(
model=lm,
model_args=model_args,
tasks=args.tasks,
num_fewshot=args.num_fewshot,
device=device
)
import pprint
pprint.pprint(results["results"])
Loading
Loading