Skip to content

Commit

Permalink
Add fallback for is_compiling (#6663)
Browse files Browse the repository at this point in the history
Importing `torch.compiler.is_compiling` causes an error with an older
version of PyTorch.
This PR adds a fallback for `is_compiling` to use an equivalent function
of older PyTorch versions.

This will resolve #6656.
  • Loading branch information
tohtana authored Oct 25, 2024
1 parent 3d5cf73 commit d5da746
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
13 changes: 13 additions & 0 deletions deepspeed/runtime/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@

import torch

try:
from torch.compiler import is_compiling as torch_is_compiling
except ImportError:
try:
from torch._dynamo.external_utils import is_compiling as torch_is_compiling
except ImportError:
# Torch does not have compiler support
torch_is_compiling = lambda: False


def is_compile_supported():
return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")
Expand All @@ -14,3 +23,7 @@ def disable(func):
if is_compile_supported():
return torch.compiler.disable(func)
return func


def is_compiling():
return torch_is_compiling()
7 changes: 3 additions & 4 deletions deepspeed/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import logging
import sys
import os
import torch
from deepspeed.runtime.compiler import is_compile_supported
from deepspeed.runtime.compiler import is_compile_supported, is_compiling

log_levels = {
"debug": logging.DEBUG,
Expand All @@ -26,7 +25,7 @@ def create_warning_filter(logger):

def warn_once(record):
nonlocal warn
if is_compile_supported() and torch.compiler.is_compiling() and not warn:
if is_compile_supported() and is_compiling() and not warn:
warn = True
logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to"
" disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1")
Expand All @@ -39,7 +38,7 @@ def logging_decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if torch.compiler.is_compiling():
if is_compiling():
return
else:
return func(*args, **kwargs)
Expand Down

0 comments on commit d5da746

Please sign in to comment.