Support of distributed inference/training with torchrun and tensor parallelism with Deepspeed (#87)
IlyasMoutawwakil authored Nov 23, 2023
1 parent 6650bcb commit 48a4410
Showing 44 changed files with 778 additions and 567 deletions.
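
For context on the title: distributed runs of this kind are typically launched with one process per GPU (for example via torchrun), and DeepSpeed's inference engine then shards the model weights across the tensor-parallel group. A minimal, hypothetical sketch of that pattern, not taken from this commit (model id, dtype and launcher flags are placeholders):

# assumed launch command: torchrun --nproc_per_node=2 tp_inference.py
import os

import deepspeed
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder checkpoint
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# shard the model across the tensor-parallel group (one shard per process)
engine = deepspeed.init_inference(model, mp_size=world_size, dtype=torch.float16)

inputs = tokenizer("Hello world", return_tensors="pt").to(f"cuda:{local_rank}")
outputs = engine.module.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
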
2 changes: 1 addition & 1 deletion .github/workflows/test_cuda_pytorch.yaml
@@ -50,4 +50,4 @@ jobs:
--gpus '"device=0,1"'
--entrypoint /bin/bash
opt-bench-cuda:${{ matrix.image.cuda_version }}
-c "pip install -e .[test,peft,diffusers] && pytest -k 'cuda and pytorch' -x"
-c "pip install -e .[test,peft,diffusers,deepspeed] && pytest -k 'cuda and pytorch' -x"
2 changes: 1 addition & 1 deletion .github/workflows/test_rocm_pytorch.yaml
@@ -53,4 +53,4 @@ jobs:
--device /dev/dri/renderD129
--entrypoint /bin/bash
opt-bench-rocm:${{ matrix.image.rocm_version }}
-c "pip install -e .[test,peft,diffusers] && pytest -k 'cuda and pytorch' -x"
-c "pip install -e .[test,peft,diffusers,deepspeed] && pytest -k 'cuda and pytorch' -x"
3 changes: 2 additions & 1 deletion .gitignore
@@ -169,4 +169,5 @@ version.txt

actions-runner/
experiments/
examples/
examples/
results/
104 changes: 42 additions & 62 deletions optimum_benchmark/backends/base.py
@@ -12,23 +12,19 @@
ClassVar,
Dict,
Generic,
List,
Optional,
Union,
)

import numpy as np
import torch
from optimum.exporters import TasksManager
from transformers import AutoConfig, AutoProcessor

if TYPE_CHECKING:
from datasets import Dataset
from transformers import (
Pipeline,
PretrainedConfig,
PreTrainedModel,
TrainerCallback,
TrainerState,
)
from transformers.utils import ModelOutput
@@ -37,58 +33,48 @@

from ..task_utils import DIFFUSION_TASKS, TEXT_GENERATION_TASKS
from .config import BackendConfigT
from .isolation_utils import (
only_this_process_is_running_on_cuda_devices,
only_this_process_will_run_on_cuda_devices,
)
from .isolation_utils import check_cuda_continuous_isolation
from .utils import (
extract_shapes_from_diffusion_pipeline,
extract_shapes_from_model_artifacts,
)

LOGGER = getLogger("backend")

CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if CUDA_VISIBLE_DEVICES is not None:
CUDA_DEVICES = list(map(int, CUDA_VISIBLE_DEVICES.split(",")))
elif torch.cuda.is_available():
CUDA_DEVICES = list(range(torch.cuda.device_count()))
else:
CUDA_DEVICES = []


class Backend(Generic[BackendConfigT], ABC):
NAME: ClassVar[str]

# instance variables without default values https://stackoverflow.com/a/44962662
library: str
model_type: str
config: BackendConfigT
isolation_thread: Optional[Process]
pretrained_model: Union["PreTrainedModel", "Pipeline"]
pretrained_processor: Optional["PreTrainedProcessor"]
pretrained_config: Optional["PretrainedConfig"]
automodel_class: Callable[..., "PreTrainedModel"]

def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any]):
self.task = task
self.model = model
self.device = device
self.hub_kwargs = hub_kwargs
self.device = torch.device(device)

if self.is_diffusion_pipeline():
# for pipelines
self.library = "diffusers"
self.model_type = self.task
self.pretrained_config = None
self.pretrained_processor = None
else:
# for models
self.library = "transformers"
self.pretrained_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.model, **self.hub_kwargs
)
self.model_type = self.pretrained_config.model_type

try:
# the processor sometimes contains information about the model's
# input shapes that's not available in the config
# sometimes contains information about the model's
# input shapes that're not available in the config
self.pretrained_processor = AutoProcessor.from_pretrained(
pretrained_model_name_or_path=self.model, **self.hub_kwargs
)
@@ -98,7 +84,10 @@ def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any
self.pretrained_processor = None

self.automodel_class = TasksManager.get_model_class_for_task(
task=self.task, library=self.library, model_type=self.model_type
framework="pt", # TODO: make this configurable to add support for other frameworks
task=self.task,
library=self.library,
model_type=self.model_type,
)

def is_text_generation_model(self) -> bool:
@@ -112,71 +101,50 @@ def configure(self, config: BackendConfigT) -> None:
self.config = config

# isolation options
if self.config.initial_isolation_check:
self.check_initial_isolation()
if self.config.continous_isolation_check:
if self.config.continuous_isolation:
LOGGER.info("\t+ Running continuous isolation check")
self.check_continuous_isolation()

# seeding backend
LOGGER.info(f"\t+ Seeding backend with seed {self.config.seed}")
self.seed()

# clean up options
if self.config.delete_cache:
LOGGER.info("\t+ Model cache will be deleted after benchmark")

def check_initial_isolation(self) -> None:
if self.device.type == "cuda":
LOGGER.info(f"\t+ Checking initial device(s) isolation of CUDA device(s): {CUDA_DEVICES}")
only_this_process_is_running_on_cuda_devices(cuda_devices=CUDA_DEVICES, benchmark_pid=os.getpid())

def check_continuous_isolation(self) -> None:
if self.device.type == "cuda":
LOGGER.info(f"\t+ Checking continuous device(s) isolation of CUDA device(s): {CUDA_DEVICES}")
self.isolation_thread = Process(
target=only_this_process_will_run_on_cuda_devices,
kwargs={"cuda_devices": CUDA_DEVICES, "benchmark_pid": os.getpid()},
if self.device == "cuda":
self.isolation_process = Process(
target=check_cuda_continuous_isolation,
kwargs={
"isolated_pid": os.getpid(),
"isolation_check_interval": self.config.isolation_check_interval,
},
daemon=True,
)
self.isolation_thread.start()
self.isolation_process.start()
LOGGER.info(f"\t+ Started isolation process with PID {self.isolation_process.pid}")

def seed(self) -> None:
# https://pytorch.org/docs/stable/notes/randomness.html
random.seed(self.config.seed)
np.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)

def prepare_input(self, input: Dict[str, Any]) -> Dict[str, Any]:
if self.is_diffusion_pipeline():
# diffusion pipelines takes a list of strings
return input
return input # diffusion pipelines takes a list of strings
else:
# models expect tensors on the target device
for key, value in input.items():
input[key] = value.to(self.device)
input[key] = value.to(self.device) # models expect tensors on the target device

return input

def prepare_for_inference(self, **kwargs) -> None:
pass

# # symbolic tracing in transformers requires input names
# def prepare_for_profiling(self, input_names: List[str]) -> Dict[str, Any]:
# pass

def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
return self.pretrained_model(**input, **kwargs)

def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
return self.pretrained_model.generate(**input, **kwargs)

def train(
self,
training_dataset: "Dataset",
training_arguments: Dict[str, Any],
training_callbacks: List["TrainerCallback"],
training_data_collator: Callable,
) -> "TrainerState":
def train(self, **kwargs) -> "TrainerState":
raise NotImplementedError("Backend must implement train method")

@property
@@ -194,9 +162,8 @@ def model_shapes(self) -> Dict[str, int]:
return model_shapes

def delete_pretrained_model(self) -> None:
if hasattr(self, "pretrained_model"):
LOGGER.info("\t+ Deleting pretrained model")
del self.pretrained_model
LOGGER.info("\t+ Deleting pretrained model")
del self.pretrained_model
gc.collect()

def delete_model_cache(self) -> None:
@@ -205,9 +172,22 @@ def delete_model_cache(self) -> None:
model_cache_path = os.path.join(os.path.expanduser("~/.cache/huggingface/hub"), model_cache_folder)
shutil.rmtree(model_cache_path, ignore_errors=True)

def terminate_isolation_process(self) -> None:
LOGGER.info("\t+ Terminating isolation process")
self.isolation_process.kill()
self.isolation_process.join()
self.isolation_process.close()

def clean(self) -> None:
LOGGER.info(f"Cleaning {self.NAME} backend")
self.delete_pretrained_model()

if self.config.continuous_isolation:
self.terminate_isolation_process()

if hasattr(self, "pretrained_model"):
self.delete_pretrained_model()

if self.config.delete_cache:
self.delete_model_cache()

gc.collect()
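
The helper imported above, check_cuda_continuous_isolation, lives in isolation_utils and its body is not shown in this diff. A rough, hypothetical sketch of what such a continuous check can look like, assuming pynvml for device introspection (function name, arguments and error handling are illustrative, not the project's actual implementation):

import os
import time

import pynvml

def cuda_isolation_check_loop(isolated_pid: int, check_interval: int = 1) -> None:
    # illustrative only: poll the visible CUDA devices and raise if any process
    # other than the benchmark process is found running on them
    pynvml.nvmlInit()
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    device_ids = list(map(int, visible.split(","))) if visible else list(range(pynvml.nvmlDeviceGetCount()))
    try:
        while True:
            for device_id in device_ids:
                handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
                for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle):
                    if proc.pid != isolated_pid:
                        raise RuntimeError(f"Found foreign process {proc.pid} on CUDA device {device_id}")
            time.sleep(check_interval)
    finally:
        pynvml.nvmlShutdown()

The real helper presumably also tolerates child processes of the benchmark (such as torchrun or DeepSpeed workers) and reports a violation to the parent rather than raising inside the daemon.
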
9 changes: 6 additions & 3 deletions optimum_benchmark/backends/config.py
@@ -16,9 +16,9 @@ class BackendConfig(ABC):
inter_op_num_threads: Optional[int] = None
intra_op_num_threads: Optional[int] = None

# isolation options
initial_isolation_check: bool = True
continous_isolation_check: bool = True
# device isolation options
continuous_isolation: bool = True
isolation_check_interval: Optional[int] = None

# clean up options
delete_cache: bool = False
@@ -32,5 +32,8 @@ def __post_init__(self):
if self.intra_op_num_threads == -1:
self.intra_op_num_threads = cpu_count()

if self.isolation_check_interval is None:
self.isolation_check_interval = 1 # 1 second


BackendConfigT = TypeVar("BackendConfigT", bound=BackendConfig)
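
Grounded in the dataclass above: continuous_isolation decides whether the daemon isolation process is started in Backend.configure(), and isolation_check_interval, which __post_init__ defaults to 1 second, is how often it polls. A standalone mirror of just this defaulting behavior, for illustration only (the class name is made up, not the real BackendConfig):

from dataclasses import dataclass
from typing import Optional

@dataclass
class IsolationOptions:  # mirrors only the two new fields, not the full config
    continuous_isolation: bool = True
    isolation_check_interval: Optional[int] = None

    def __post_init__(self):
        if self.isolation_check_interval is None:
            self.isolation_check_interval = 1  # 1 second, as in the diff above

opts = IsolationOptions()
assert opts.continuous_isolation is True
assert opts.isolation_check_interval == 1
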
97 changes: 0 additions & 97 deletions optimum_benchmark/backends/ddp_utils.py

This file was deleted.
