Skip to content

Commit

Permalink
assert transformers version for ulysses
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterSH6 committed Jan 17, 2025
1 parent 83bb9cc commit 1c6fd40
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions verl/models/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def apply_monkey_patch_to_qwen2():


def apply_monkey_patch(config: PretrainedConfig, verbose=True):
if not is_transformers_version_in_range("4.45.0", "4.47.1"):
raise AssertionError(
"The installed `transformers` version doesn't support ulysses patch. "
"Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature."
)
success_apply_monkey_patch = False
if config.model_type in _PATCH_NAME_TO_FUNC:
_PATCH_NAME_TO_FUNC[config.model_type]()
Expand All @@ -52,3 +57,18 @@ def apply_monkey_patch(config: PretrainedConfig, verbose=True):
please set `ulysses_sequence_parallel_size=1`')

return success_apply_monkey_patch

from functools import lru_cache
from packaging import version
import importlib.metadata

@lru_cache()
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
try:
# Get the installed version of the transformers library
transformers_version = importlib.metadata.version("transformers")
except importlib.metadata.PackageNotFoundError:
raise ModuleNotFoundError("The `transformers` package is not installed.")

# Check if the version is within the specified range
return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)

0 comments on commit 1c6fd40

Please sign in to comment.