From 1c6fd400d4227ffbf836cb079539994dbdf64173 Mon Sep 17 00:00:00 2001 From: shengguangming Date: Fri, 17 Jan 2025 21:26:17 +0800 Subject: [PATCH] assert transformers version for ulysses --- verl/models/transformers/monkey_patch.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py index cf86e42..528ac3c 100644 --- a/verl/models/transformers/monkey_patch.py +++ b/verl/models/transformers/monkey_patch.py @@ -40,6 +40,11 @@ def apply_monkey_patch_to_qwen2(): def apply_monkey_patch(config: PretrainedConfig, verbose=True): + if not is_transformers_version_in_range("4.45.0", "4.47.1"): + raise AssertionError( + "The installed `transformers` version doesn't support ulysses patch. " + "Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature." + ) success_apply_monkey_patch = False if config.model_type in _PATCH_NAME_TO_FUNC: _PATCH_NAME_TO_FUNC[config.model_type]() @@ -52,3 +57,18 @@ def apply_monkey_patch(config: PretrainedConfig, verbose=True): please set `ulysses_sequence_parallel_size=1`') return success_apply_monkey_patch + +from functools import lru_cache +from packaging import version +import importlib.metadata + +@lru_cache() +def is_transformers_version_in_range(min_version: str, max_version: str) -> bool: + try: + # Get the installed version of the transformers library + transformers_version = importlib.metadata.version("transformers") + except importlib.metadata.PackageNotFoundError: + raise ModuleNotFoundError("The `transformers` package is not installed.") + + # Check if the version is within the specified range + return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version) \ No newline at end of file