feat: add optional param use_flash_attn
fumiama committed Jul 2, 2024
1 parent b607174 commit 46200b3
Showing 5 changed files with 26 additions and 16 deletions.
5 changes: 4 additions & 1 deletion ChatTTS/core.py
@@ -128,6 +128,7 @@ def load(
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
device: Optional[torch.device] = None,
coef: Optional[torch.Tensor] = None,
use_flash_attn=False,
) -> bool:
download_path = self.download_models(source, force_redownload, custom_path)
if download_path is None:
@@ -136,6 +137,7 @@ def load(
device=device,
compile=compile,
coef=coef,
use_flash_attn=use_flash_attn,
**{
k: os.path.join(download_path, v)
for k, v in OmegaConf.load(
@@ -255,6 +257,7 @@ def _load(
device: Optional[torch.device] = None,
compile: bool = True,
coef: Optional[str] = None,
use_flash_attn=False,
):
if device is None:
device = select_device()
@@ -292,7 +295,7 @@ def _load(

if gpt_config_path:
cfg = OmegaConf.load(gpt_config_path)
gpt = GPT(**cfg, device=device, logger=self.logger).eval()
gpt = GPT(**cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger).eval()
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
gpt.prepare(compile=compile and "cuda" in str(device))
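With this change, FlashAttention-2 becomes opt-in instead of being picked up automatically whenever the kernels are importable. A minimal usage sketch (assuming the public `ChatTTS.Chat` entry point; `use_flash_attn` defaults to `False` as added above):

```python
import ChatTTS

chat = ChatTTS.Chat()

# FlashAttention-2 stays disabled unless explicitly requested, even when
# flash-attn is installed; pass use_flash_attn=True to opt in.
ok = chat.load(use_flash_attn=True)
print("models loaded:", ok)
```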
12 changes: 8 additions & 4 deletions ChatTTS/model/gpt.py
@@ -32,6 +32,7 @@ def __init__(
num_audio_tokens: int,
num_text_tokens: int,
num_vq=4,
use_flash_attn=False,
device=torch.device("cpu"),
logger=logging.getLogger(__name__),
):
@@ -45,6 +46,8 @@ def __init__(
self.num_vq = num_vq
self.num_audio_tokens = num_audio_tokens

self.use_flash_attn = use_flash_attn

self.gpt = self._build_llama(gpt_config, self.device_gpt)
self.model_dim = int(self.gpt.config.hidden_size)
self.emb_code = nn.ModuleList(
@@ -96,7 +99,7 @@ def get(self) -> bool:
return self._interrupt

def _build_llama(
self, config: omegaconf.DictConfig, device: torch.device
self, config: omegaconf.DictConfig, device: torch.device,
) -> LlamaModel:

model = None
@@ -114,11 +117,12 @@ def _build_llama(
)

if model is None:
if is_flash_attn_2_available():
if self.use_flash_attn and is_flash_attn_2_available():
llama_config = LlamaConfig(
**config,
attn_implementation="flash_attention_2",
)
self.logger.warn("enabling flash_attention_2 may make gpt even slower")
else:
llama_config = LlamaConfig(**config)
model = LlamaModel(llama_config)
@@ -127,7 +131,7 @@ def _build_llama(
return model.to(device)

def prepare(self, compile=False):
if is_flash_attn_2_available():
if self.use_flash_attn and is_flash_attn_2_available():
self.gpt = self.gpt.to(dtype=torch.float16)
if compile:
try:
@@ -435,7 +439,7 @@ def generate(
)
del_all(model_input)
attentions.append(outputs.attentions)
hidden_states = outputs.last_hidden_state.to(self.device) # 🐻
hidden_states = outputs.last_hidden_state.to(self.device, dtype=torch.float) # 🐻
past_key_values = outputs.past_key_values
del_all(outputs)
if return_hidden:
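The two `use_flash_attn and is_flash_attn_2_available()` checks above gate both model construction and the later fp16 cast in `prepare()`. A simplified, standalone sketch of that gating (assuming a recent `transformers` that accepts `attn_implementation` in the config constructor; the real `_build_llama` first tries other backends before falling through to the `model is None` branch):

```python
import torch
from transformers import LlamaConfig, LlamaModel
from transformers.utils import is_flash_attn_2_available


def build_llama(config: dict, use_flash_attn: bool, device: torch.device) -> LlamaModel:
    # Request FlashAttention-2 only when the caller opted in AND the kernels
    # are actually importable; otherwise use the default attention backend.
    if use_flash_attn and is_flash_attn_2_available():
        llama_config = LlamaConfig(**config, attn_implementation="flash_attention_2")
    else:
        llama_config = LlamaConfig(**config)
    return LlamaModel(llama_config).to(device)
```

Note the matching change in `prepare()`: when the flag is set and the kernels are present, the model is cast to `float16`, since the FlashAttention-2 path runs in half precision.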
4 changes: 4 additions & 0 deletions README.md
@@ -107,6 +107,10 @@ pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
> [!Note]
> See supported devices at the [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
> [!Warning]
> Currently, FlashAttention-2 will slow down generation speed according to [this issue](https://github.com/huggingface/transformers/issues/26990).
> Install it only for development purposes.
```bash
pip install flash-attn --no-build-isolation
```
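After installing, a quick way to confirm that `transformers` can actually see the FlashAttention-2 kernels (this is the same helper the model code checks before enabling them):

```python
from transformers.utils import is_flash_attn_2_available

# chat.load(use_flash_attn=True) only takes effect when this prints True.
print("FlashAttention-2 available:", is_flash_attn_2_available())
```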
13 changes: 6 additions & 7 deletions examples/ipynb/colab.ipynb
@@ -36,8 +36,7 @@
"outputs": [],
"source": [
"!pip install -r /content/ChatTTS/requirements.txt\n",
"!ldconfig /usr/lib64-nvidia\n",
"!pip install flash-attn --no-build-isolation"
"!ldconfig /usr/lib64-nvidia"
]
},
{
@@ -116,7 +115,7 @@
"id": "3Ty427FZNH30"
},
"source": [
"### Here are three choices for loading models:"
"### Here are three choices for loading models,"
]
},
{
@@ -125,7 +124,7 @@
"id": "NInF7Lk1NH30"
},
"source": [
"#### 1. Load models from Hugging Face:"
"#### 1. Load models from Hugging Face (recommend)"
]
},
{
@@ -137,7 +136,7 @@
"outputs": [],
"source": [
"# use force_redownload=True if the weights have been updated.\n",
"chat.load(source=\"huggingface\", force_redownload=True)"
"chat.load(source=\"huggingface\")"
]
},
{
Expand All @@ -146,7 +145,7 @@
"id": "AhBD5WUPNH30"
},
"source": [
"#### 2. Load models from local directories 'asset' and 'config':"
"#### 2. Load models from local directories 'asset' and 'config'"
]
},
{
Expand All @@ -167,7 +166,7 @@
"id": "c0qjGPNkNH31"
},
"source": [
"#### 3. Load models from a custom path:"
"#### 3. Load models from a custom path"
]
},
{
8 changes: 4 additions & 4 deletions examples/ipynb/example.ipynb
@@ -79,14 +79,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Here are three choices for loading models:"
"### Here are three choices for loading models,"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1. Load models from Hugging Face:"
"#### 1. Load models from Hugging Face (not suitable in CN)"
]
},
{
@@ -103,7 +103,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2. Load models from local directories 'asset' and 'config':"
"#### 2. Load models from local directories 'asset' and 'config' (recommend)"
]
},
{
@@ -120,7 +120,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3. Load models from a custom path:"
"#### 3. Load models from a custom path"
]
},
{

0 comments on commit 46200b3
