
Commit e2d2a56
Add workaround for FusedRoPE (#473)
regisss authored Oct 18, 2023
1 parent 47be147 · commit e2d2a56
Showing 4 changed files with 48 additions and 12 deletions.
optimum/habana/transformers/models/falcon/modeling_falcon.py (10 additions, 4 deletions)
```diff
@@ -3,17 +3,23 @@
 
 import torch
 
+from ....utils import get_device_name
+
 
 try:
     from habana_frameworks.torch.hpex.kernels import FusedSDPA
 except ImportError:
     print("Not using HPU fused kernel for scaled_dot_product_attention")
     FusedSDPA = None
 
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV1 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
+# TODO: remove this workaround when FusedRoPE properly works on Gaudi
+if get_device_name() == "gaudi2":
+    try:
+        from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV1 as FusedRoPE
+    except ImportError:
+        print("Not using HPU fused kernel for apply_rotary_pos_emb")
+        FusedRoPE = None
+else:
     FusedRoPE = None
 
 import habana_frameworks.torch.core as htcore
```
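The net effect of this guard is that `FusedRoPE` is always bound at module scope: either to the fused kernel class (on Gaudi2 with the kernel importable) or to `None`. Because the device check runs once at import time, there is no per-call overhead. Below is a minimal, self-contained sketch of how attention code can then dispatch on that sentinel; the function names and the fused kernel's `apply` signature are illustrative assumptions, not the actual modeling code:

```python
from typing import Optional, Tuple

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard eager RoPE helper: (x1, x2) -> (-x2, x1) on the last dim.
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_rope_with_fallback(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    fused_rope: Optional[type] = None,  # pass the module-level FusedRoPE here
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch between the fused HPU kernel and an eager PyTorch fallback."""
    if fused_rope is not None:
        # Fused path: only reachable on Gaudi2 under the workaround above.
        # (Hypothetical call; the real helper's interface may differ.)
        return fused_rope.apply(q, cos, sin), fused_rope.apply(k, cos, sin)
    # Eager rotary position embedding, as in the stock Transformers code.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```

Passing the class explicitly keeps the sketch testable; the real modeling code reads the module-level `FusedRoPE` directly.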
optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py (10 additions, 4 deletions)
```diff
@@ -5,11 +5,17 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM, apply_rotary_pos_emb, logger
 
+from ....utils import get_device_name
+
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
+# TODO: remove this workaround when FusedRoPE properly works on Gaudi
+if get_device_name() == "gaudi2":
+    try:
+        from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
+    except ImportError:
+        print("Not using HPU fused kernel for apply_rotary_pos_emb")
+        FusedRoPE = None
+else:
     FusedRoPE = None
 
 
```
optimum/habana/transformers/models/llama/modeling_llama.py (10 additions, 4 deletions)
```diff
@@ -14,11 +14,17 @@
     logger,
 )
 
+from ....utils import get_device_name
+
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
+# TODO: remove this workaround when FusedRoPE properly works on Gaudi
+if get_device_name() == "gaudi2":
+    try:
+        from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
+    except ImportError:
+        print("Not using HPU fused kernel for apply_rotary_pos_emb")
+        FusedRoPE = None
+else:
     FusedRoPE = None
 
 try:
```
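The same guard now appears verbatim in three modeling files, differing only in which helper is imported (V1 for Falcon, V2 for GPT-NeoX and Llama). As a purely hypothetical refactor, not part of this commit, the pattern could be factored into a single utility:

```python
# Hypothetical consolidation of the repeated guard (not in this commit).
from typing import Optional

from optimum.habana.utils import get_device_name


def load_fused_rope(helper_name: str) -> Optional[type]:
    """Return the fused RoPE kernel class, or None if it cannot or should not be used."""
    # TODO: remove this workaround when FusedRoPE properly works on Gaudi
    if get_device_name() != "gaudi2":
        return None
    try:
        from habana_frameworks.torch.hpex import kernels
    except ImportError:
        print("Not using HPU fused kernel for apply_rotary_pos_emb")
        return None
    return getattr(kernels, helper_name, None)


# Each modeling file would then reduce to a single line, e.g.:
# FusedRoPE = load_fused_rope("RotaryPosEmbeddingHelperV1")  # Falcon
# FusedRoPE = load_fused_rope("RotaryPosEmbeddingHelperV2")  # GPT-NeoX, Llama
```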
optimum/habana/utils.py (18 additions, 0 deletions)
```diff
@@ -318,3 +318,21 @@ def check_optimum_habana_min_version(min_version):
         "`pip install git+https://github.com/huggingface/optimum-habana.git`."
     )
     raise ImportError(error_message)
+
+
+def get_device_name():
+    """
+    Returns the name of the current device: Gaudi or Gaudi2.
+
+    Inspired from: https://github.com/HabanaAI/Model-References/blob/a87c21f14f13b70ffc77617b9e80d1ec989a3442/PyTorch/computer_vision/classification/torchvision/utils.py#L274
+    """
+    import habana_frameworks.torch.utils.experimental as htexp
+
+    device_type = htexp._get_device_type()
+
+    if device_type == htexp.synDeviceType.synDeviceGaudi:
+        return "gaudi"
+    elif device_type == htexp.synDeviceType.synDeviceGaudi2:
+        return "gaudi2"
+    else:
+        raise ValueError(f"Unsupported device: the device type is {device_type}.")
```
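A quick way to exercise the new helper on an HPU machine (assumes `habana_frameworks` is installed; any device other than first-gen Gaudi or Gaudi2 raises `ValueError`):

```python
from optimum.habana.utils import get_device_name

device = get_device_name()  # "gaudi" or "gaudi2"
print(f"Detected {device}; fused RoPE workaround active: {device == 'gaudi2'}")
```

Note that the helper relies on `htexp._get_device_type()`, a private API, so it may need updating across SynapseAI releases.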
