From 0c19eecc38992a05506aa05573410fd240bf03af Mon Sep 17 00:00:00 2001 From: Ziqing Yang Date: Fri, 20 Oct 2023 14:17:08 +0800 Subject: [PATCH] Add Speculative sampling support (#328) * update readme * add speculative sample * update inference scripts about speculative sampling * update README.md * update readme * update readme * update readme * update readme * update HF links * Update speculative_sample.py * Update gradio_demo.py * Update README.md * Update README_EN.md * Update speculative_sample.py * Update speculative_sample.py * Update speculative_sample.py * Update speculative_sample.py * Update gradio_demo.py * fix bugs in speculative sampling --------- Co-authored-by: GoGoJoestar Co-authored-by: GoGoJoestar <58219543+GoGoJoestar@users.noreply.github.com> --- README.md | 40 +- README_EN.md | 42 +- scripts/attn_and_long_ctx_patches.py | 27 +- scripts/inference/gradio_demo.py | 97 ++++- scripts/inference/inference_hf.py | 166 ++++++-- scripts/inference/speculative_sample.py | 486 ++++++++++++++++++++++++ 6 files changed, 790 insertions(+), 68 deletions(-) create mode 100644 scripts/inference/speculative_sample.py diff --git a/README.md b/README.md index 890756d..2f0fd68 100644 --- a/README.md +++ b/README.md @@ -24,8 +24,8 @@ #### 已开源的模型 -- 基座模型:Chinese-LLaMA-2-7B, Chinese-LLaMA-2-13B -- 聊天模型:Chinese-Alpaca-2-7B, Chinese-Alpaca-2-13B +- 基座模型:Chinese-LLaMA-2-1.3B, Chinese-LLaMA-2-7B, Chinese-LLaMA-2-13B +- 聊天模型:Chinese-Alpaca-2-1.3B, Chinese-Alpaca-2-7B, Chinese-Alpaca-2-13B - 长上下文模型:Chinese-LLaMA-2-7B-16K, Chinese-LLaMA-2-13B-16K, Chinese-Alpaca-2-7B-16K, Chinese-Alpaca-2-13B-16K ![](./pics/screencast.gif) @@ -101,9 +101,9 @@ | 对比项 | 中文LLaMA-2 | 中文Alpaca-2 | | :-------------------- | :----------------------------------------------------: | :----------------------------------------------------------: | | 模型类型 | **基座模型** | **指令/Chat模型(类ChatGPT)** | -| 已开源大小 | 7B、13B | 7B、13B | +| 已开源大小 | 1.3B、7B、13B | 1.3B、7B、13B | | 训练类型 | Causal-LM (CLM) | 指令精调 | -| 训练方式 | LoRA + 全量emb/lm-head | LoRA + 全量emb/lm-head | +| 训练方式 | 7B、13B:LoRA + 全量emb/lm-head
1.3B:全量 | 7B、13B:LoRA + 全量emb/lm-head
1.3B:全量 | | 基于什么模型训练 | [原版Llama-2](https://github.com/facebookresearch/llama)(非chat版) | 中文LLaMA-2 | | 训练语料 | 无标注通用语料(120G纯文本) | 有标注指令数据(500万条) | | 词表大小[1] | 55,296 | 55,296 | @@ -116,6 +116,7 @@ > [1] *本项目一代模型和二代模型的词表不同,请勿混用。二代LLaMA和Alpaca的词表相同。*
> [2] *括号内表示基于NTK上下文扩展支持的最大长度。*
> [3] *Alpaca-2采用了Llama-2-chat系列模板(格式相同,提示语不同),而不是一代Alpaca的模板,请勿混用。*
+> [4] *不建议单独使用1.3B模型,而是通过投机采样搭配更大的模型(7B、13B)使用。*
### 完整模型下载 @@ -125,8 +126,10 @@ | :------------------------ | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | Chinese-LLaMA-2-13B | 基座模型 | 24.7 GB | [[百度]](https://pan.baidu.com/s/1T3RqEUSmyg6ZuBwMhwSmoQ?pwd=e9qy) [[Google]](https://drive.google.com/drive/folders/1YNa5qJ0x59OEOI7tNODxea-1YvMPoH05?usp=share_link) [[🤗HF]](https://huggingface.co/ziqingyang/chinese-llama-2-13b) | | Chinese-LLaMA-2-7B | 基座模型 | 12.9 GB | [[百度]](https://pan.baidu.com/s/1E5NI3nlQpx1j8z3eIzbIlg?pwd=n8k3) [[Google]](https://drive.google.com/drive/folders/18pp4I-mvQxRA7b8vF9gP-2cH_ocnXVKh?usp=share_link) [[🤗HF]](https://huggingface.co/ziqingyang/chinese-llama-2-7b) | +| Chinese-LLaMA-2-1.3B | 基座模型 | 2.4 GB | [[🤗HF]](https://huggingface.co/ziqingyang/chinese-llama-2-1.3b) | | Chinese-Alpaca-2-13B | 指令模型 | 24.7 GB | [[百度]](https://pan.baidu.com/s/1MT_Zlap1OtdYMgoBNTS3dg?pwd=9xja) [[Google]](https://drive.google.com/drive/folders/1MTsKlzR61xmbTR4hBWzQas_MOpUZsogN?usp=share_link) [[🤗HF]](https://huggingface.co/ziqingyang/chinese-alpaca-2-13b) | | Chinese-Alpaca-2-7B | 指令模型 | 12.9 GB | [[百度]](https://pan.baidu.com/s/1wxx-CdgbMupXVRBcaN4Slw?pwd=kpn9) [[Google]](https://drive.google.com/drive/folders/1JsJDVs7tE2y31PBNleBlDPsB7S0ZrY8d?usp=share_link) [[🤗HF]](https://huggingface.co/ziqingyang/chinese-alpaca-2-7b) | +| Chinese-Alpaca-2-1.3B | 指令模型 | 2.4 GB | [[🤗HF]](https://huggingface.co/ziqingyang/chinese-alpaca-2-1.3b) | 以下是长上下文版模型,**推荐以长文本为主的下游任务使用**,否则建议使用上述标准版。 @@ -172,15 +175,15 @@ 本项目中的相关模型主要支持以下量化、推理和部署方式,具体内容请参考对应教程。 -| 工具 | 特点 | CPU | GPU | 量化 | GUI | API | vLLM§ | 16K | 教程 | -| :----------------------------------------------------------- | ---------------------------- | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | -| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | 丰富的量化选项和高效本地推理 | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh) | -| [**🤗Transformers**](https://github.com/huggingface/transformers) | 原生transformers推理接口 | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh) | -| [**Colab Demo**](https://colab.research.google.com/drive/1yu0eZ3a66by8Zqm883LLtRQrguBAb9MR?usp=sharing) | 在Colab中启动交互界面 | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | [link](https://colab.research.google.com/drive/1yu0eZ3a66by8Zqm883LLtRQrguBAb9MR?usp=sharing) | -| [**仿OpenAI API调用**](https://platform.openai.com/docs/api-reference) | 仿OpenAI API接口的服务器Demo | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_zh) | -| [**text-generation-webui**](https://github.com/oobabooga/text-generation-webui) | 前端Web UI界面的部署方式 | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/text-generation-webui_zh) | -| [**LangChain**](https://github.com/hwchase17/langchain) | 适合二次开发的大模型应用开源框架 | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/langchain_zh) | -| [**privateGPT**](https://github.com/imartinez/privateGPT) | 基于LangChain的多文档本地问答框架 | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_zh) | +| 工具 | 特点 | CPU | GPU | 量化 | GUI | API | vLLM§ | 16K | 投机采样 | 教程 | +| :----------------------------------------------------------- | ---------------------------- | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | +| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | 
丰富的量化选项和高效本地推理 | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh) | +| [**🤗Transformers**](https://github.com/huggingface/transformers) | 原生transformers推理接口 | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh) | +| [**Colab Demo**](https://colab.research.google.com/drive/1yu0eZ3a66by8Zqm883LLtRQrguBAb9MR?usp=sharing) | 在Colab中启动交互界面 | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | [link](https://colab.research.google.com/drive/1yu0eZ3a66by8Zqm883LLtRQrguBAb9MR?usp=sharing) | +| [**仿OpenAI API调用**](https://platform.openai.com/docs/api-reference) | 仿OpenAI API接口的服务器Demo | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_zh) | +| [**text-generation-webui**](https://github.com/oobabooga/text-generation-webui) | 前端Web UI界面的部署方式 | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/text-generation-webui_zh) | +| [**LangChain**](https://github.com/hwchase17/langchain) | 适合二次开发的大模型应用开源框架 | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/langchain_zh) | +| [**privateGPT**](https://github.com/imartinez/privateGPT) | 基于LangChain的多文档本地问答框架 | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_zh) | > [!NOTE] > 工具支持该特性,但教程中未实现,详细说明请参考对应官方文档
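The 投机采样加速效果评测 section further below reports decoding speeds in ms/token for the draft, target, and combined setups. The benchmark harness itself is not part of this patch; the snippet below is only a hypothetical sketch of how such a per-token latency could be measured, where `generate_fn` stands in for a wrapper around either plain `model.generate(...)` or `speculative_sample(...)`.

```python
import time
import torch

def ms_per_token(generate_fn, input_ids: torch.Tensor, n_runs: int = 3) -> float:
    """Average decoding latency in ms per newly generated token.

    generate_fn: a callable that takes input_ids and returns the full output ids,
    e.g. a wrapper around model.generate(...) or speculative_sample(...).
    """
    total_ms, total_new_tokens = 0.0, 0
    for _ in range(n_runs):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        output_ids = generate_fn(input_ids)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        total_ms += (time.perf_counter() - start) * 1000.0
        total_new_tokens += output_ids.shape[1] - input_ids.shape[1]
    return total_ms / max(total_new_tokens, 1)
```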
@@ -273,6 +276,17 @@ | CPU Speed | 117 | 42 | 51 | 39 | 44 | 43 | 48 | 51 | 50 | 54 | 65 | | GPU Speed | 53 | 19 | 21 | 17 | 18 | 20 | x | x | 25 | 26 | x | +### 投机采样加速效果评测 + +通过投机采样方法并借助Chinese-LLaMA-2-1.3B和Chinese-Alpaca-2-1.3B,可以分别加速7B、13B的LLaMA和Alpaca模型的推理速度。以下是使用[投机采样脚本](scripts/inference/speculative_sample.py)在1*A40-48G上解码[生成效果评测](#生成效果评测)中的问题测得的平均速度(速度以ms/token计,模型均为fp16精度),供用户参考。详细说明见[📖GitHub Wiki](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh#投机采样解码)。 + +| 草稿模型 | 草稿模型速度 | 目标模型 | 目标模型速度 | 投机采样速度(加速比) | +| :---------- | :-----------------: | :----------- | :-----------------: | :--------: | +| Chinese-LLaMA-2-1.3B | 7.6 | Chinese-LLaMA-2-7B | 49.3 | 36.0(1.37x) | +| Chinese-LLaMA-2-1.3B | 7.6 | Chinese-LLaMA-2-13B | 66.0 | 47.1(1.40x) | +| Chinese-Alpaca-2-1.3B | 8.1 | Chinese-Alpaca-2-7B | 50.2 | 34.9(1.44x) | +| Chinese-Alpaca-2-1.3B | 8.2 | Chinese-Alpaca-2-13B | 67.0 | 41.6(1.61x) | + ## 训练与精调 ### 预训练 diff --git a/README_EN.md b/README_EN.md index 89ab6db..ed62d99 100644 --- a/README_EN.md +++ b/README_EN.md @@ -24,8 +24,8 @@ This project is based on the Llama-2, released by Meta, and it is the second gen #### Open-sourced Models -- Base model: Chinese-LLaMA-2-7B, Chinese-LLaMA-2-13B -- Instruction/chat model: Chinese-Alpaca-2-7B, Chinese-Alpaca-2-13B +- Base model: Chinese-LLaMA-2-1.3B, Chinese-LLaMA-2-7B, Chinese-LLaMA-2-13B +- Instruction/chat model: Chinese-Alpaca-2-1.3B, Chinese-Alpaca-2-7B, Chinese-Alpaca-2-13B - Long context model: Chinese-LLaMA-2-7B-16K, Chinese-LLaMA-2-13B-16K, Chinese-Alpaca-2-7B-16K, Chinese-Alpaca-2-13B-16K ![](./pics/screencast.gif) @@ -97,9 +97,9 @@ Below is a basic comparison between the Chinese LLaMA-2 and Alpaca-2 models, as | Comparison | Chinese LLaMA-2 | Chinese Alpaca-2 | | :---------------------------- | :----------------------------------------------------------: | :----------------------------------------------------------: | | Model Type | **Base Model** | **Instruction/Chat Model (like ChatGPT)** | -| Released Sizes | 7B, 13B | 7B, 13B | +| Released Sizes | 1.3B, 7B, 13B | 1.3B, 7B, 13B | | Training Method | Causal-LM (CLM) | Instruction fine-tuning | -| Training Parts | LoRA + emb/lm-head | LoRA + emb/lm-head | +| Training Parts | 7B, 13B: LoRA + emb/lm-head
1.3B: full params | 7B, 13B: LoRA + emb/lm-head
1.3B: full params | | Trained on | [Original Llama-2](https://github.com/facebookresearch/llama) (non-chat) | Chinese LLaMA-2 | | Training Corpus | Unlabeled general corpus (120G raw text) | Labeled instruction data (5M samples) | | Vocabulary Size[1] | 55,296 | 55,296 | @@ -112,6 +112,7 @@ Below is a basic comparison between the Chinese LLaMA-2 and Alpaca-2 models, as > [1] *The vocabulary of the first and second generation models in this project are different, do not mix them. The vocabularies of the second generation LLaMA and Alpaca are the same.*
> [2] *The maximum context length supported with NTK context extension is shown in brackets.*
> [3] *Alpaca-2 uses the Llama-2-chat series templates (same format, different prompts), not the templates of the first-generation Alpaca; do not mix them.*
+> [4] *The 1.3B models are not intended for standalone use; instead, use them together with the larger models (7B, 13B) through speculative sampling.*
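The pairing described in note [4] can be driven through `scripts/inference/inference_hf.py` (via the new `--speculative_sampling`, `--draft_base_model`, and `--draft_k` flags) or called directly from Python. Below is a minimal sketch of the direct route, closely mirroring the usage example at the bottom of the new `scripts/inference/speculative_sample.py`; the model paths and the prompt are placeholders, and `scripts/inference` is assumed to be on `PYTHONPATH` so that `speculative_sample` is importable.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from speculative_sample import speculative_sample  # scripts/inference/speculative_sample.py

draft_path = "path/to/chinese-alpaca-2-1.3b"    # placeholder
target_path = "path/to/chinese-alpaca-2-7b"     # placeholder

tokenizer = AutoTokenizer.from_pretrained(target_path)
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
target_model = AutoModelForCausalLM.from_pretrained(
    target_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
draft_model.eval()
target_model.eval()

# The real scripts wrap the instruction in the Alpaca-2 chat template; this is simplified.
input_ids = tokenizer.encode("[INST] 我能用lightning数据线给安卓手机充电吗? [/INST]",
                             return_tensors="pt").to(target_model.device)

outputs = speculative_sample(
    input_ids=input_ids,
    target_model=target_model,
    draft_model=draft_model,
    generation_config=GenerationConfig(do_sample=True, temperature=0.2, top_k=40, top_p=0.9,
                                       repetition_penalty=1.1, max_new_tokens=128),
    draft_k=4,                              # non-positive values switch to the adaptive schedule
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```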
### Full Model Download @@ -121,8 +122,10 @@ Below are the full models, which can be used directly afterwards, without additi | :-------------------- | :---------------: | :-----: | :----------------------------------------------------------: | | Chinese-LLaMA-2-13B | Base model | 24.7 GB | [[Baidu]](https://pan.baidu.com/s/1T3RqEUSmyg6ZuBwMhwSmoQ?pwd=e9qy) [[Google]](https://drive.google.com/drive/folders/1YNa5qJ0x59OEOI7tNODxea-1YvMPoH05?usp=share_link) [[🤗HF]](https://huggingface.co/ziqingyang/chinese-llama-2-13b) | | Chinese-LLaMA-2-7B | Base model | 12.9 GB | [[Baidu]](https://pan.baidu.com/s/1E5NI3nlQpx1j8z3eIzbIlg?pwd=n8k3) [[Google]](https://drive.google.com/drive/folders/18pp4I-mvQxRA7b8vF9gP-2cH_ocnXVKh?usp=share_link) [[🤗HF]](https://huggingface.co/ziqingyang/chinese-llama-2-7b) | +| Chinese-LLaMA-2-1.3B | Base model | 2.4 GB | [[🤗HF]](https://huggingface.co/ziqingyang/chinese-llama-2-1.3b) | | Chinese-Alpaca-2-13B | Chat Model | 24.7 GB | [[Baidu]](https://pan.baidu.com/s/1MT_Zlap1OtdYMgoBNTS3dg?pwd=9xja) [[Google]](https://drive.google.com/drive/folders/1MTsKlzR61xmbTR4hBWzQas_MOpUZsogN?usp=share_link) [[🤗HF]](https://huggingface.co/ziqingyang/chinese-alpaca-2-13b) | | Chinese-Alpaca-2-7B | Chat Model | 12.9 GB | [[Baidu]](https://pan.baidu.com/s/1wxx-CdgbMupXVRBcaN4Slw?pwd=kpn9) [[Google]](https://drive.google.com/drive/folders/1JsJDVs7tE2y31PBNleBlDPsB7S0ZrY8d?usp=share_link) [[🤗HF]](https://huggingface.co/ziqingyang/chinese-alpaca-2-7b) | +| Chinese-Alpaca-2-1.3B | Chat model | 2.4 GB | [[🤗HF]](https://huggingface.co/ziqingyang/chinese-alpaca-2-1.3b) | The followings are long context models, which are recommended for long context tasks. @@ -167,15 +170,15 @@ The followings are long context models, which are recommended for long context t The models in this project mainly support the following quantization, inference, and deployment methods. 
-| Tool | Features | CPU | GPU | Quant | GUI | API | vLLM§ | 16K | Tutorial | -| :----------------------------------------------------------- | ------------------------------------------------------- | :--: | :--: | :---: | :--: | :--: | :--: | :----------------------------------------------------------: | :----------------------------------------------------------: | -| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | Rich quantization options and efficient local inference | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_en) | -| [**🤗Transformers**](https://github.com/huggingface/transformers) | Native transformers inference interface | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_en) | -| [**Colab Demo**](https://colab.research.google.com/drive/1yu0eZ3a66by8Zqm883LLtRQrguBAb9MR?usp=sharing) | Running a Gradio web demo in Colab | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | [link](https://colab.research.google.com/drive/1yu0eZ3a66by8Zqm883LLtRQrguBAb9MR?usp=sharing) | -| [**OpenAI API Calls**](https://platform.openai.com/docs/api-reference) | A server that implements OpenAI API | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_en) | -| [**text-generation-webui**](https://github.com/oobabooga/text-generation-webui) | A tool for deploying model as a web UI | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/text-generation-webui_en) | -| [**LangChain**](https://github.com/hwchase17/langchain) | LLM application development framework, suitable for secondary development | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/langchain_en) | -| [**privateGPT**](https://github.com/imartinez/privateGPT) | LangChain-based multi-document QA framework | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_en) | +| Tool | Features | CPU | GPU | Quant | GUI | API | vLLM§ | 16K | Speculative Sampling | Tutorial | +| :----------------------------------------------------------- | ------------------------------------------------------- | :--: | :--: | :---: | :--: | :--: | :--: | :----------------------------------------------------------: | :----------------------------------------------------------: | ------------------------------------------------------------ | +| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | Rich quantization options and efficient local inference | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_en) | +| [**🤗Transformers**](https://github.com/huggingface/transformers) | Native transformers inference interface | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_en) | +| [**Colab Demo**](https://colab.research.google.com/drive/1yu0eZ3a66by8Zqm883LLtRQrguBAb9MR?usp=sharing) | Running a Gradio web demo in Colab | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | [link](https://colab.research.google.com/drive/1yu0eZ3a66by8Zqm883LLtRQrguBAb9MR?usp=sharing) | +| [**OpenAI API Calls**](https://platform.openai.com/docs/api-reference) | A server that implements OpenAI API | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_en) | +| [**text-generation-webui**](https://github.com/oobabooga/text-generation-webui) | A tool for deploying model as a web UI | 
✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/text-generation-webui_en) | +| [**LangChain**](https://github.com/hwchase17/langchain) | LLM application development framework, suitable for secondary development | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/langchain_en) | +| [**privateGPT**](https://github.com/imartinez/privateGPT) | LangChain-based multi-document QA framework | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_en) | > [!NOTE] > : Supported by this tool, but not implemented in the tutorial. Please refer to the official documentation for details.
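As background for the Speculative Sampling Evaluation further below: the scheme added in `scripts/inference/speculative_sample.py` follows the DeepMind speculative sampling paper referenced in that script (arXiv:2302.01318). The draft model proposes `draft_k` tokens, the target model scores them in a single forward pass, and each proposal is kept or replaced by the rule sketched here. This is a simplified single-position illustration, not the script itself; `p_target` and `q_draft` are assumed to be probability vectors over the vocabulary at the same position.

```python
import torch

def accept_or_resample(p_target: torch.Tensor, q_draft: torch.Tensor, token: int):
    """Verification step of speculative sampling for one drafted token.

    Returns (accepted, replacement_token_or_None).
    """
    r = torch.rand(())
    # Keep the drafted token with probability min(1, p/q).
    if r < torch.clamp(p_target[token] / q_draft[token], max=1.0):
        return True, None
    # Otherwise resample from the normalized residual distribution max(p - q, 0),
    # which keeps the overall output distribution equal to the target model's.
    residual = torch.clamp(p_target - q_draft, min=0.0)
    residual = residual / residual.sum()
    return False, torch.multinomial(residual, 1).item()
```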
@@ -264,6 +267,19 @@ Specifically, the followings are the benchmark for different quantization method | CPU Speed | 117 | 42 | 51 | 39 | 44 | 43 | 48 | 51 | 50 | 54 | 65 | | GPU Speed | 53 | 19 | 21 | 17 | 18 | 20 | x | x | 25 | 26 | x | +### Speculative Sampling Evaluation + +Using speculative sampling and leveraging Chinese-LLaMA-2-1.3B and Chinese-Alpaca-2-1.3B can accelerate the inference speed of 7B and 13B LLaMA and Alpaca models. The followings are the inference speeds (ms/token) evaluated on the questions in [Generation Performance Evaluation](#Generation-Performance-Evaluation) on 1*A40-48G. All the models are in fp16 format. For details, see our [Wiki](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_en#Speculative-Sampling). + +| Draft Model | Draft Model Speed | Target Model | Target Model Speed | Speculative Sampling Speed | +| :-------------------- | :---------------: | :------------------- | :----------------: | :------------------------: | +| Chinese-LLaMA-2-1.3B | 7.6 | Chinese-LLaMA-2-7B | 49.3 | 36.0(1.37x) | +| Chinese-LLaMA-2-1.3B | 7.6 | Chinese-LLaMA-2-13B | 66.0 | 47.1(1.40x) | +| Chinese-Alpaca-2-1.3B | 8.1 | Chinese-Alpaca-2-7B | 50.2 | 34.9(1.44x) | +| Chinese-Alpaca-2-1.3B | 8.2 | Chinese-Alpaca-2-13B | 67.0 | 41.6(1.61x) | + + + ## Training and Fine-tuning Please refer to the corresponding Wiki for information on pre-training (Chinese LLaMA-2 training) and instruction fine-tuning (Chinese Alpaca-2 training). diff --git a/scripts/attn_and_long_ctx_patches.py b/scripts/attn_and_long_ctx_patches.py index dcebb8c..d971923 100644 --- a/scripts/attn_and_long_ctx_patches.py +++ b/scripts/attn_and_long_ctx_patches.py @@ -47,8 +47,10 @@ def xformers_forward( value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + past_kv_len = 0 if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + past_kv_len = past_key_value[0].shape[-2] + kv_seq_len += past_kv_len if STORE_KV_BEFORE_ROPE is False: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -75,12 +77,31 @@ def xformers_forward( position_ids = position_ids.unsqueeze(0).view(-1, kv_seq_len) key_states = apply_rotary_pos_emb_single(key_states, cos, sin, position_ids) + pad_query = False if xops is not None and USE_MEM_EFF_ATTENTION: attn_weights = None query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_bias = None if (query_states.size(1)==1 and key_states.size(1)>1) else xops.LowerTriangularMask() + if query_states.size(1)==1 and key_states.size(1)>1: + attn_bias = None + elif query_states.size(1)1 and past_kv_len > 0: + attn_bias = xops.LowerTriangularMask() + query_states = torch.cat( + ( + torch.full( + (bsz, past_kv_len, self.num_heads, self.head_dim), + 0.0, + dtype=query_states.dtype, + device=query_states.device, + ), + query_states, + ), + dim=1, + ) + pad_query = True + else: + attn_bias = xops.LowerTriangularMask() attn_output = xops.memory_efficient_attention( query_states, key_states, value_states, attn_bias=attn_bias, p=0) else: @@ -113,6 +134,8 @@ def xformers_forward( ) attn_output = attn_output.transpose(1, 2) + if pad_query: + attn_output = attn_output[:,past_kv_len:] attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py index 474ca07..fafddee 100644 
--- a/scripts/inference/gradio_demo.py +++ b/scripts/inference/gradio_demo.py @@ -3,7 +3,8 @@ LlamaForCausalLM, LlamaTokenizer, StoppingCriteria, - BitsAndBytesConfig + BitsAndBytesConfig, + GenerationConfig ) import gradio as gr import argparse @@ -87,6 +88,29 @@ type=int, default=8000, help="Port of vLLM service.") +parser.add_argument( + "--speculative_sampling", + action='store_true', + help="Use speculative sampling to speed up inference.") +parser.add_argument( + "--draft_base_model", + default=None, + type=str, + help="Draft base model used in speculative sampling.") +parser.add_argument( + "--draft_lora_model", + default=None, + type=str, + help="If None, perform inference on the draft base model") +parser.add_argument( + "--draft_model_load_in_8bit", + action='store_true', + help="Load the draft model in the 8bit mode") +parser.add_argument( + "--draft_model_load_in_4bit", + action='store_true', + help="Load the draft model in the 4bit mode") + args = parser.parse_args() ENABLE_CFG_SAMPLING = True @@ -112,6 +136,12 @@ if not args.only_cpu: apply_attention_patch(use_memory_efficient_attention=True) apply_ntk_scaling_patch(args.alpha) +if args.speculative_sampling: + if args.draft_base_model == None: + raise ValueError("Speculative sampling requires a draft model. Please specify the draft model.") + if args.draft_model_load_in_8bit and args.draft_model_load_in_4bit: + raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments") + from speculative_sample import speculative_sample # Set CUDA devices if available os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus @@ -125,11 +155,13 @@ def setup(): global tokenizer, model, device, share, port, max_memory + if args.speculative_sampling: + global draft_model if args.use_vllm: # global share, port, max_memory max_memory = args.max_memory port = args.port - share = args.share + share = args.share == 'True' or args.share is True if args.lora_model is not None: raise ValueError("vLLM currently does not support LoRA, please merge the LoRA weights to the base model.") @@ -137,6 +169,8 @@ def setup(): raise ValueError("vLLM currently does not support quantization, please use fp16 (default) or unuse --use_vllm.") if args.only_cpu: raise ValueError("vLLM requires GPUs with compute capability not less than 7.0. If you want to run only on CPU, please unuse --use_vllm.") + if args.speculative_sampling: + raise ValueError("speculative_sampling is set, but vLLM does not support speculative sampling. Please unset speculative_sampling. 
") if args.tokenizer_path is None: args.tokenizer_path = args.base_model @@ -155,7 +189,7 @@ def setup(): else: max_memory = args.max_memory port = args.port - share = args.share + share = args.share == 'True' or args.share is True load_type = torch.float16 if torch.cuda.is_available(): device = torch.device(0) @@ -183,6 +217,23 @@ def setup(): quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None ) + if args.speculative_sampling: + if args.load_in_4bit or args.load_in_8bit: + draft_quantization_config = BitsAndBytesConfig( + load_in_4bit=args.draft_model_load_in_4bit, + load_in_8bit=args.draft_model_load_in_8bit, + bnb_4bit_compute_dtype=load_type, + ) + draft_base_model = LlamaForCausalLM.from_pretrained( + args.draft_base_model, + torch_dtype=load_type, + low_cpu_mem_usage=True, + device_map='auto', + load_in_4bit=args.draft_model_load_in_4bit, + load_in_8bit=args.draft_model_load_in_8bit, + quantization_config=draft_quantization_config if (args.draft_model_load_in_4bit or args.draft_model_load_in_8bit) else None + ) + model_vocab_size = base_model.get_input_embeddings().weight.size(0) tokenizer_vocab_size = len(tokenizer) print(f"Vocab of the base model: {model_vocab_size}") @@ -190,6 +241,12 @@ def setup(): if model_vocab_size != tokenizer_vocab_size: print("Resize model embeddings to fit tokenizer") base_model.resize_token_embeddings(tokenizer_vocab_size) + if args.speculative_sampling: + draft_model_vocab_size = draft_base_model.get_input_embeddings().weight.size(0) + print(f"Vocab of the draft base model: {draft_model_vocab_size}") + if draft_model_vocab_size!=tokenizer_vocab_size: + print("Resize draft model embeddings to fit tokenizer") + draft_base_model.resize_token_embeddings(tokenizer_vocab_size) if args.lora_model is not None: print("loading peft model") model = PeftModel.from_pretrained( @@ -200,11 +257,20 @@ def setup(): ).half() else: model = base_model + if args.speculative_sampling: + if args.draft_lora_model is not None: + print("loading peft draft model") + draft_model = PeftModel.from_pretrained(draft_base_model, args.draft_lora_model,torch_dtype=load_type,device_map='auto',).half() + else: + draft_model = draft_base_model if device == torch.device('cpu'): model.float() - model.eval() + if args.speculative_sampling: + if device==torch.device('cpu'): + draft_model.float() + draft_model.eval() # Reset the user input @@ -359,6 +425,7 @@ def predict( repetition_penalty=1.1, guidance_scale=1.0, presence_penalty=0.0, + draft_k=0, ): if len(system_prompt) == 0: system_prompt = DEFAULT_SYSTEM_PROMPT @@ -431,11 +498,17 @@ def predict( 'top_k': top_k, 'do_sample': do_sample, 'repetition_penalty': repetition_penalty, + 'eos_token_id': tokenizer.eos_token_id, } if ENABLE_CFG_SAMPLING is True: generate_params['guidance_scale'] = guidance_scale generate_params['negative_prompt_ids'] = negative_prompt_ids generate_params['negative_prompt_attention_mask'] = negative_prompt_attention_mask + if args.speculative_sampling: + generate_params['target_model'] = model + generate_params['draft_model'] = draft_model + generate_params['draft_k'] = draft_k + generate_params['generation_config'] = GenerationConfig() def generate_with_callback(callback=None, **kwargs): if 'stopping_criteria' in kwargs: @@ -444,7 +517,10 @@ def generate_with_callback(callback=None, **kwargs): kwargs['stopping_criteria'] = [Stream(callback_func=callback)] clear_torch_cache() with torch.no_grad(): - model.generate(**kwargs) + if not args.speculative_sampling: + 
model.generate(**kwargs) + else: # enable speculative sampling + speculative_sample(**kwargs) def generate_with_streaming(**kwargs): return Iteratorize(generate_with_callback, kwargs, callback=None) @@ -549,6 +625,14 @@ def generate_with_streaming(**kwargs): label="Presence Penalty", interactive=True, visible=True if args.use_vllm else False) + draft_k = gr.Slider( + 0, + 10, + value=0, + step=1.0, + label="Draft K", + interactive=True, + visible=args.speculative_sampling==True) params = [user_input, chatbot] predict_params = [ @@ -562,7 +646,8 @@ def generate_with_streaming(**kwargs): do_sample, repetition_penalty, guidance_scale, - presence_penalty] + presence_penalty, + draft_k] submitBtn.click( user, diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py index ef0e658..1cb029c 100644 --- a/scripts/inference/inference_hf.py +++ b/scripts/inference/inference_hf.py @@ -27,6 +27,12 @@ parser.add_argument('--system_prompt', type=str, default=DEFAULT_SYSTEM_PROMPT, help="The system prompt of the prompt template.") parser.add_argument('--negative_prompt', type=str, default=None, help="Negative prompt in CFG sampling.") parser.add_argument('--guidance_scale', type=float, default=1.0, help="The guidance scale for CFG sampling. CFG is enabled by setting `guidance_scale > 1`.") +parser.add_argument('--speculative_sampling', action='store_true', help="Use speculative sampling to speed up inference.") +parser.add_argument('--draft_k', type=int, default=-1, help="Number of new tokens the draft model generates each times. Should be a positive integer. Using adaptive number K if `draft_k <= 0`.") +parser.add_argument('--draft_base_model', default=None, type=str, help="Draft base model used in speculative sampling.") +parser.add_argument('--draft_lora_model', default=None, type=str, help="If None, perform inference on the draft base model") +parser.add_argument('--draft_model_load_in_8bit', action='store_true', help="Load the draft model in the 8bit mode") +parser.add_argument('--draft_model_load_in_4bit', action='store_true', help="Load the draft model in the 4bit mode") args = parser.parse_args() if args.guidance_scale > 1: @@ -44,12 +50,15 @@ raise ValueError("vLLM requires GPUs with compute capability not less than 7.0. If you want to run only on CPU, please unuse --use_vllm.") if args.guidance_scale > 1: raise ValueError("guidance_scale > 1, but vLLM does not support CFG sampling. Please unset guidance_scale. ") + if args.speculative_sampling: + raise ValueError("speculative_sampling is set, but vLLM does not support speculative sampling. Please unset speculative_sampling. ") if args.load_in_8bit and args.load_in_4bit: raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments") if args.only_cpu is True: args.gpus = "" if args.load_in_8bit or args.load_in_4bit: raise ValueError("Quantization is unavailable on CPU.") + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus import torch from transformers import LlamaForCausalLM, LlamaTokenizer @@ -66,6 +75,12 @@ if not args.only_cpu: apply_attention_patch(use_memory_efficient_attention=True) apply_ntk_scaling_patch(args.alpha) +if args.speculative_sampling: + if args.draft_base_model == None: + raise ValueError("Speculative sampling requires a draft model. Please specify the draft model.") + if args.draft_model_load_in_8bit and args.draft_model_load_in_4bit: + raise ValueError("Only one quantization method can be chosen for inference. 
Please check your arguments") + from speculative_sample import speculative_sample if args.use_vllm: generation_config = dict( @@ -125,6 +140,23 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT): load_in_4bit=args.load_in_4bit, load_in_8bit=args.load_in_8bit, quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None + ) + + if args.speculative_sampling: + if args.load_in_4bit or args.load_in_8bit: + draft_quantization_config = BitsAndBytesConfig( + load_in_4bit=args.draft_model_load_in_4bit, + load_in_8bit=args.draft_model_load_in_8bit, + bnb_4bit_compute_dtype=load_type, + ) + draft_base_model = LlamaForCausalLM.from_pretrained( + args.draft_base_model, + torch_dtype=load_type, + low_cpu_mem_usage=True, + device_map='auto', + load_in_4bit=args.draft_model_load_in_4bit, + load_in_8bit=args.draft_model_load_in_8bit, + quantization_config=draft_quantization_config if (args.draft_model_load_in_4bit or args.draft_model_load_in_8bit) else None ) model_vocab_size = base_model.get_input_embeddings().weight.size(0) @@ -134,15 +166,31 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT): if model_vocab_size!=tokenizer_vocab_size: print("Resize model embeddings to fit tokenizer") base_model.resize_token_embeddings(tokenizer_vocab_size) + if args.speculative_sampling: + draft_model_vocab_size = draft_base_model.get_input_embeddings().weight.size(0) + print(f"Vocab of the draft base model: {draft_model_vocab_size}") + if draft_model_vocab_size!=tokenizer_vocab_size: + print("Resize draft model embeddings to fit tokenizer") + draft_base_model.resize_token_embeddings(tokenizer_vocab_size) if args.lora_model is not None: print("loading peft model") model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',).half() else: model = base_model + if args.speculative_sampling: + if args.draft_lora_model is not None: + print("loading peft draft model") + draft_model = PeftModel.from_pretrained(draft_base_model, args.draft_lora_model,torch_dtype=load_type,device_map='auto',).half() + else: + draft_model = draft_base_model if device==torch.device('cpu'): model.float() model.eval() + if args.speculative_sampling: + if device==torch.device('cpu'): + draft_model.float() + draft_model.eval() # test data if args.data_file is None: @@ -184,13 +232,24 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT): else: inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ? 
if args.guidance_scale ==1: - generation_output = model.generate( - input_ids = inputs["input_ids"].to(device), - attention_mask = inputs['attention_mask'].to(device), - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config = generation_config - ) + if not args.speculative_sampling: + generation_output = model.generate( + input_ids = inputs["input_ids"].to(device), + attention_mask = inputs['attention_mask'].to(device), + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config = generation_config + ) + else: # enable speculative sampling + generation_output = speculative_sample( + input_ids=inputs["input_ids"].to(device), + target_model=model, + draft_model=draft_model, + draft_k=args.draft_k, + generation_config=generation_config, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) else: # enable CFG sampling if negative_text is None: negative_prompt_ids = None @@ -199,16 +258,30 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT): negative_inputs = tokenizer(negative_text,return_tensors="pt") negative_prompt_ids = negative_inputs["input_ids"].to(device) negative_prompt_attention_mask = negative_inputs["attention_mask"].to(device) - generation_output = model.generate( - input_ids = inputs["input_ids"].to(device), - attention_mask = inputs['attention_mask'].to(device), - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config = generation_config, - guidance_scale = args.guidance_scale, - negative_prompt_ids = negative_prompt_ids, - negative_prompt_attention_mask = negative_prompt_attention_mask - ) + if not args.speculative_sampling: + generation_output = model.generate( + input_ids = inputs["input_ids"].to(device), + attention_mask = inputs['attention_mask'].to(device), + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config = generation_config, + guidance_scale = args.guidance_scale, + negative_prompt_ids = negative_prompt_ids, + negative_prompt_attention_mask = negative_prompt_attention_mask + ) + else: # enable speculative sampling + generation_output = speculative_sample( + input_ids=inputs["input_ids"].to(device), + target_model=model, + draft_model=draft_model, + draft_k=args.draft_k, + generation_config=generation_config, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + guidance_scale=args.guidance_scale, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) s = generation_output[0] output = tokenizer.decode(s,skip_special_tokens=True) if args.with_prompt: @@ -247,13 +320,24 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT): negative_text = args.negative_prompt inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ? 
if args.guidance_scale == 1: - generation_output = model.generate( - input_ids = inputs["input_ids"].to(device), - attention_mask = inputs['attention_mask'].to(device), - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config = generation_config - ) + if not args.speculative_sampling: + generation_output = model.generate( + input_ids = inputs["input_ids"].to(device), + attention_mask = inputs['attention_mask'].to(device), + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config = generation_config + ) + else: # enable speculative sampling + generation_output = speculative_sample( + input_ids=inputs["input_ids"].to(device), + target_model=model, + draft_model=draft_model, + draft_k=args.draft_k, + generation_config=generation_config, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) else: # enable CFG sampling if negative_text is None: negative_prompt_ids = None @@ -262,16 +346,30 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT): negative_inputs = tokenizer(negative_text,return_tensors="pt") negative_prompt_ids = negative_inputs["input_ids"].to(device) negative_prompt_attention_mask = negative_inputs["attention_mask"].to(device) - generation_output = model.generate( - input_ids = inputs["input_ids"].to(device), - attention_mask = inputs['attention_mask'].to(device), - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config = generation_config, - guidance_scale = args.guidance_scale, - negative_prompt_ids = negative_prompt_ids, - negative_prompt_attention_mask = negative_prompt_attention_mask - ) + if not args.speculative_sampling: + generation_output = model.generate( + input_ids = inputs["input_ids"].to(device), + attention_mask = inputs['attention_mask'].to(device), + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config = generation_config, + guidance_scale = args.guidance_scale, + negative_prompt_ids = negative_prompt_ids, + negative_prompt_attention_mask = negative_prompt_attention_mask + ) + else: # enable speculative sampling + generation_output = speculative_sample( + input_ids=inputs["input_ids"].to(device), + target_model=model, + draft_model=draft_model, + draft_k=args.draft_k, + generation_config=generation_config, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + guidance_scale=args.guidance_scale, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) s = generation_output[0] output = tokenizer.decode(s,skip_special_tokens=True) if args.with_prompt: diff --git a/scripts/inference/speculative_sample.py b/scripts/inference/speculative_sample.py new file mode 100644 index 0000000..2c1299e --- /dev/null +++ b/scripts/inference/speculative_sample.py @@ -0,0 +1,486 @@ +from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, PreTrainedModel +from transformers import ( + LogitsProcessorList, + StoppingCriteriaList, +) +from transformers.generation.streamers import BaseStreamer +import torch +from typing import Tuple, List, Optional +import copy + + +def norm_logits( + x: torch.Tensor, + logits: torch.Tensor, + logits_processor: LogitsProcessorList, + logits_warper: LogitsProcessorList, + do_sample: bool = False, + cur_len=None, +) -> torch.Tensor: + """ + Args: + x (`torch.Tensor`): input ids, shape (batch, seqlen) + logits `(`torch.Tensor`): shape (batch, 
seqlen, vocab) + do_sample ('bool'): whether do sample + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + logits_warper (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from + [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + do_sample ('boo;'): whether do sample. + cur_len ('int'): length of current decoded tokens. + + Returns: + `torch.Tensor`: probs with shape as (batch, seq_len) + """ + new_logits = logits[:,:] + if len(logits_processor) > 0: + for i in range(x.shape[1]-cur_len+1): + new_logits[:,i,:] = logits_processor(x[:,:cur_len+i], new_logits[:,i,:]) + if do_sample and len(logits_warper) > 0: + for i in range(x.shape[1]-cur_len+1): + new_logits[:,i,:] = logits_warper(x[:,:cur_len+i], new_logits[:,i,:]) + + probs = new_logits.softmax(dim=-1) + + return probs + + +def sample(probs : torch.Tensor, do_sample : bool = False, num_samples: int = 1): + if do_sample: + new_token = torch.multinomial(probs, num_samples=num_samples) + else: + new_token = torch.argmax(probs, keepdim=True) + return new_token + + +def max_fn(x): + """ + norm(max (x, 0)) + """ + x_max = torch.where(x > 0, x, torch.zeros_like(x)) + x_max_sum = torch.sum(x_max, dim=1, keepdim=True) + return x_max / x_max_sum + + +def _draft_model_serial_forward( + prefix : torch.Tensor, + draft_k : int, + draft_model : torch.nn.Module, + logits_processor, + logits_warper, + do_sample=False, + past_key_values=None, + rejected=False, + eos_token_id_tensor = None +) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]], torch.Tensor or bool]: + """ forward draft model draft_k times + + Args: + prefix (`torch.Tensor`): the original input ids + draft_k (`int`): how many times draft model forward and sample + draft_model (`torch.nn.Module`): an draft model + logits_processor (`LogitsProcessorList`, *optional*): Custom logits processors that complement the default logits processors built from arguments and + generation config. 
+ logits_warper: List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution + do_sample (`bool`): whether do sample + past_key_values: kv cache of draft model in last iteration + rejected (`bool`): whether any of tokens in last iteration was rejected + eos_token_id_tensor (`torch.Tensor`): eos token id in tokenizer + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]], torch.Tensor or bool]: + generated tokens, probability distribution of draft model's output, + past_key_values of draft model, flag of whether last token is eos + """ + x = prefix + x = x.to(draft_model.device) + input_ids = x + probs = None + + if past_key_values != None: + if rejected == False: + output = draft_model(input_ids[:,-2:-1], past_key_values = past_key_values, use_cache=True) + past_key_values = output.past_key_values + input_ids = input_ids[:,-1:] + probs = norm_logits(x[:,:-1], output.logits, logits_processor, logits_warper, do_sample, x.shape[1]-1) + else: + input_ids = input_ids[:,-1:] + + for _ in range(draft_k): + output = draft_model(input_ids, past_key_values = past_key_values, use_cache=True) + new_probs = norm_logits(x, output.logits[:,-1:], logits_processor, logits_warper, do_sample, x.shape[1]) + next_tok = sample(new_probs[:, -1, :], do_sample=do_sample) + if eos_token_id_tensor is not None: + last_token_is_eos = next_tok.tile(eos_token_id_tensor.shape[0], 1) + last_token_is_eos = ( + ~last_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() + ) + if last_token_is_eos: + break + else: + last_token_is_eos = False + past_key_values = output.past_key_values + probs = torch.cat((probs, new_probs), dim=1) if probs != None else torch.cat((output.logits[:,:-1], new_probs), dim=1) + input_ids = next_tok + x = torch.cat((x, next_tok), dim=1) + + return x, probs, past_key_values, last_token_is_eos + +def _speculative_sampling( + prefix : torch.Tensor, + target_model : torch.nn.Module, + draft_model : torch.nn.Module, + max_new_tokens : int , + draft_k : int = 4, + logits_processor: LogitsProcessorList = None, + logits_warper : LogitsProcessorList = None, + do_sample = False, + eos_token_id = None, + stopping_criteria = None, + streamer: Optional["BaseStreamer"] = None, +) -> torch.Tensor: + """ + DeepMind version Speculative Sampling. + Accelerating Large Language Model Decoding with Speculative Sampling + https://arxiv.org/abs/2302.01318 + + Args: + prefix (torch.Tensor): input sequence, (batch, prefix_seqlen), Note that the batch dim is always 1 now. + target_model (torch.nn.Module): target model, the large one + draft_model (torch.nn.Module): draft model, the small one + max_new_tokens (int): the max overall generated tokens number. + draft_k (int): the token number small model guesses. + logits_processor (`LogitsProcessorList`, *optional*): Custom logits processors that complement the default logits processors built from arguments and + generation config. + logits_warper: List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution + do_sample (`bool`): whether do sample + eos_token_id: eos token id in tokenizer + stopping_criteria: An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. 
Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + + Returns: + torch.Tensor: generated tokens (batch, target_seqlen) + """ + input_seq_len = prefix.shape[1] + T = input_seq_len + max_new_tokens + assert prefix.shape[0] == 1, "input batch size must be 1" + + if draft_k <= 0: + draft_k = 4 + adaptive_k = True + else: + adaptive_k = False + + draft_past_key_values = None + draft_probs = None + target_past_key_values = None + target_probs = None + rejected = False + unfinished_sequences = prefix.new(prefix.shape[0]).fill_(1) + + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(prefix.device) if eos_token_id is not None else None + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + while prefix.shape[1] < T: + prefix_len = prefix.shape[1] + x, new_draft_probs, draft_past_key_values, _ = _draft_model_serial_forward( + prefix, + draft_k, + draft_model, + logits_processor, + logits_warper, + do_sample, + draft_past_key_values, + rejected, + eos_token_id_tensor + ) + + if draft_probs != None and new_draft_probs != None: + draft_probs = torch.concat((draft_probs, new_draft_probs), dim=1) + elif new_draft_probs == None: + draft_probs = draft_probs + else: + draft_probs = new_draft_probs + + if target_past_key_values != None: + unchecked_token_count = x.shape[1] - target_probs.shape[1] - 1 + outputs = target_model(x[:,-(unchecked_token_count+1):], past_key_values=target_past_key_values, use_cache=True) + else: + unchecked_token_count = x.shape[1] - prefix_len + outputs = target_model(x, use_cache=True) + new_target_probs = norm_logits(x, outputs.logits[:,-(unchecked_token_count+1):], logits_processor, logits_warper, do_sample, prefix_len) + target_probs = torch.cat((target_probs, new_target_probs), dim=1) if target_probs != None else torch.cat((outputs.logits[:,:-(unchecked_token_count+1)], new_target_probs), dim=1) + target_past_key_values = outputs.past_key_values + + # n_valid: the length of the valid prefix + is_all_accept = True + n_valid = prefix_len + for i in range(unchecked_token_count): + r = torch.rand(1, device = target_probs.device) + cur_token_id = x[:, prefix_len + i] + cur_pos = prefix_len + i - 1 + + if r < torch.min( + torch.tensor([1], device=draft_probs.device), + target_probs[:, cur_pos, cur_token_id] / draft_probs[:, cur_pos, cur_token_id] + ): + # accept, and update n_valid + n_valid += 1 + else: + # reject + target_new_token = sample( + max_fn( + target_probs[:, n_valid-1, :] - draft_probs[:, n_valid-1, :] + ), do_sample=do_sample + ) + is_all_accept = False + rejected = True + break + + n_valid = min(n_valid, T - 1) + prefix = x[:, :n_valid] + + if is_all_accept: + target_new_token = sample(target_probs[:, -1, :], do_sample=do_sample) + rejected = False + else: + draft_probs = draft_probs[:,:n_valid,:] + target_probs = target_probs[:,:n_valid,:] + if "bloom" in draft_model.__class__.__name__.lower() or ( + draft_model.config.architectures is not None and "bloom" in draft_model.config.architectures[0].lower() + ): + draft_past_key_values = [ + (key[:,:,:n_valid], value[:,:n_valid,:]) + for key,value in draft_past_key_values + ] + target_past_key_values = [ + (key[:,:,:n_valid], value[:,:n_valid,:]) + 
for key,value in target_past_key_values + ] + else: + draft_past_key_values = [ + (key[:,:,:n_valid,:], value[:,:,:n_valid,:]) + for key,value in draft_past_key_values + ] + target_past_key_values = [ + (key[:,:,:n_valid,:], value[:,:,:n_valid,:]) + for key,value in target_past_key_values + ] + if adaptive_k: + if is_all_accept: + draft_k += 2 + else: + draft_k = max(1, draft_k - 1) + prefix = torch.cat((prefix, target_new_token), dim=1) + if streamer is not None: + streamer.put(prefix.cpu()) + if stopping_criteria(prefix, target_probs): + # this_peer_finished = True + break + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + prefix[:, -1] + .tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) + ) + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + # this_peer_finished = True + break + + if streamer is not None: + streamer.end() + + return prefix + + +def speculative_sample( + input_ids, + target_model: Optional["PreTrainedModel"], + draft_model: Optional["PreTrainedModel"], + generation_config: GenerationConfig, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + draft_k: int = 4, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, +): + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() + + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + inputs_tensor, _, model_kwargs = target_model._prepare_model_inputs( + input_ids, generation_config.bos_token_id, model_kwargs + ) + + model_kwargs["use_cache"] = generation_config.use_cache + + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + # warnings.warn( + # f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + # "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + # " recommend using `max_new_tokens` to control the maximum length of the generation.", + # UserWarning, + # ) + pass + elif generation_config.max_new_tokens is not None: + # if not has_default_max_length: + # logger.warning( + # f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + # f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + # "Please refer to the documentation for more information. 
" + # "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + # ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + # input_ids_string = "decoder_input_ids" if target_model.config.is_encoder_decoder else "input_ids" + # logger.warning( + # f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + # f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + # " increasing `max_new_tokens`." + # ) + pass + # prepare logis_processor, stopping_criteria, logits_warper + try: + logits_processor = target_model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=None, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + except TypeError: + # Please install the latest transformers (commit equal or later than d533465) to enable CFG sampling. + logits_processor = target_model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=None, + logits_processor=logits_processor, + ) + stopping_criteria = target_model._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper=target_model._get_logits_warper(generation_config) if generation_config.do_sample else None + + outputs = _speculative_sampling( + prefix=input_ids, + target_model=target_model, + draft_model=draft_model, + max_new_tokens=generation_config.max_new_tokens, + draft_k=draft_k, + logits_processor=logits_processor, + logits_warper=logits_warper, + do_sample=generation_config.do_sample, + eos_token_id=generation_config.eos_token_id, + stopping_criteria=stopping_criteria, + streamer=streamer, + ) + + return outputs + + +if __name__ == "__main__": + # A usage example + draft_model_name = 'Draft/Model/Path' + target_model_name = 'Target/Model/Path' + + DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 
你是一个乐于助人的助手。""" + + TEMPLATE = ( + "[INST] <>\n" + "{system_prompt}\n" + "<>\n\n" + "{instruction} [/INST]" + ) + + def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT): + return TEMPLATE.format_map({'instruction': instruction,'system_prompt': system_prompt}) + + inputs = ["我能用lightning数据线给安卓手机充电吗?"] + + negative_text = generate_prompt(inputs[0], system_prompt="回复尽可能多的内容。") + inputs = [generate_prompt(text) for text in inputs] + + tokenizer = AutoTokenizer.from_pretrained(target_model_name) + + print("begin loading models") + draft_model = AutoModelForCausalLM.from_pretrained( + draft_model_name, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map='auto', + load_in_8bit=False + ) + draft_model.resize_token_embeddings(len(tokenizer)) + print(f"Load {draft_model_name}") + target_model = AutoModelForCausalLM.from_pretrained( + target_model_name, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map='auto', + load_in_8bit=False + ) + print(f"Load {target_model_name}") + draft_model.eval() + target_model.eval() + print("finish loading models") + + torch_device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + input_ids = tokenizer.encode(inputs[0], return_tensors='pt').to(torch_device) + + negative_inputs = tokenizer(negative_text,return_tensors="pt") + negative_prompt_ids = negative_inputs["input_ids"].to(torch_device) + negative_prompt_attention_mask = negative_inputs["attention_mask"].to(torch_device) + + generation_config = GenerationConfig( + temperature=0.2, + top_k=40, + top_p=0.9, + do_sample=True, + num_beams=1, + repetition_penalty=1.1, + max_new_tokens=128 + ) + + outputs = speculative_sample( + input_ids=input_ids, + target_model=target_model, + draft_model=draft_model, + generation_config=generation_config, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + # draft_k=4, + # guidance_scale=1.5, + # negative_prompt_ids=negative_prompt_ids, + # negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(generated_text)
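One implementation detail worth calling out: when `draft_k` is passed as a non-positive value (the default of `--draft_k` in `inference_hf.py` is -1), `_speculative_sampling` starts from 4 and adapts the draft length between iterations. A minimal sketch of that heuristic, with the surrounding decoding loop omitted:

```python
def update_draft_k(draft_k: int, all_accepted: bool) -> int:
    """Adaptive draft length used when a non-positive draft_k is requested:
    grow the speculative window after a fully accepted block of draft tokens,
    shrink it (never below 1) after a rejection."""
    return draft_k + 2 if all_accepted else max(1, draft_k - 1)


# Starting from the default of 4, two fully accepted rounds followed by a
# rejection give 4 -> 6 -> 8 -> 7.
k = 4
for accepted in (True, True, False):
    k = update_draft_k(k, accepted)
print(k)  # 7
```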