Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Supervisor supports restarts #2611

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
8 changes: 6 additions & 2 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pip install -U soundfile
${{ env.SELF_HOST_PYTHON }} -m pip install -U sentence-transformers
${{ env.SELF_HOST_PYTHON }} -m pip install -U FlagEmbedding
${{ env.SELF_HOST_PYTHON }} -m pip install "llama-cpp-python>=0.2.82" -i https://abetlen.github.io/llama-cpp-python/whl/cu124
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
--disable-warnings \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_continuous_batching.py && \
Expand Down Expand Up @@ -232,7 +233,10 @@ jobs:
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_fish_speech.py && \
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/embedding/tests/test_integrated_embedding.py
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/embedding/tests/test_integrated_embedding.py && \
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_restart_supervisor.py
elif [ "$MODULE" == "metal" ]; then
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
Expand All @@ -246,6 +250,6 @@ jobs:
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/client/tests/test_client.py
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/core/tests/test_continuous_batching.py --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/image/tests/test_got_ocr2.py --ignore xinference/model/audio/tests --ignore xinference/model/embedding/tests/test_integrated_embedding.py xinference
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/core/tests/test_restart_supervisor.py --ignore xinference/core/tests/test_continuous_batching.py --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/image/tests/test_got_ocr2.py --ignore xinference/model/audio/tests --ignore xinference/model/embedding/tests/test_integrated_embedding.py xinference
fi
working-directory: .
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
files: xinference
repos:
- repo: https://github.com/psf/black
- repo: https://github.com/psf/black-pre-commit-mirror
paradin marked this conversation as resolved.
Show resolved Hide resolved
rev: 23.12.0
hooks:
- id: black
Expand Down
51 changes: 51 additions & 0 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,57 @@ async def list_models(self) -> Dict[str, Dict[str, Any]]:
v["replica"] = self._model_uid_to_replica_info[k].replica
return running_model_info

# Receive model infos of workers
@log_async(logger=logger)
async def sync_models(
    self, worker_address: str, model_desc: Dict[str, Dict[str, Any]]
):  # model_uid : ModelDescription{"address"}
    """Rebuild supervisor-side bookkeeping from the models a worker reports.

    For every replica in ``model_desc`` that is not yet tracked in
    ``self._replica_model_uid_to_worker`` this method:

    * maps the replica uid to the worker registered under ``worker_address``;
    * creates or enlarges the ``ReplicaInfo`` of the owning model;
    * pushes a READY ``InstanceInfo`` for the model to the status guard.

    Args:
        worker_address: Address of the reporting worker. It must already be
            present in ``self._worker_address_to_worker``.
        model_desc: Mapping of replica model uid to a description dict. The
            optional ``"model_name"`` and ``"model_version"`` keys are read;
            missing keys default to an empty string.

    Raises:
        AssertionError: If ``worker_address`` is not a registered worker and
            at least one untracked replica is reported.
    """
    for replica_model_uid, desc_dict in model_desc.items():
        # Replicas that are already tracked need no rebuilding.
        if replica_model_uid in self._replica_model_uid_to_worker:
            continue

        model_name = desc_dict.get("model_name", "")
        model_version = desc_dict.get("model_version", "")
        logger.debug(
            f"Receive model replica: {replica_model_uid} {worker_address} {model_name}"
        )

        # Kept as an assertion so the raised exception type stays unchanged.
        assert (
            worker_address in self._worker_address_to_worker
        ), f"Worker {worker_address} not exists when sync_models"

        # Rebuild self._replica_model_uid_to_worker
        self._replica_model_uid_to_worker[
            replica_model_uid
        ] = self._worker_address_to_worker[worker_address]

        # Rebuild self._model_uid_to_replica_info
        model_uid, rep_id = parse_replica_model_uid(replica_model_uid)
        # Replica ids are treated as zero-based, so seeing id N implies at
        # least N + 1 replicas exist.
        replica = rep_id + 1
        replica_info = self._model_uid_to_replica_info.get(model_uid)
        if replica_info is None or replica > replica_info.replica:
            replica_info = ReplicaInfo(
                replica=replica, scheduler=itertools.cycle(range(replica))
            )
            self._model_uid_to_replica_info[model_uid] = replica_info

        # Rebuild self._status_guard_ref
        # Report the largest replica count known so far: replicas may arrive
        # in any order (or from different workers), and a lower-index replica
        # must not shrink a count recorded earlier.
        instance_info = InstanceInfo(
            model_name=model_name,
            model_uid=model_uid,
            model_version=model_version,
            model_ability=[],
            replica=replica_info.replica,
            status=LaunchStatus.READY.name,
            instance_created_ts=int(time.time()),
        )
        await self._status_guard_ref.set_instance_info(model_uid, instance_info)

def is_local_deployment(self) -> bool:
# TODO: temporary.
return (
Expand Down
97 changes: 97 additions & 0 deletions xinference/core/tests/test_restart_supervisor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import time
from typing import Dict, Optional

import xoscar as xo

from ...api import restful_api
from ...client import Client


def test_restart_supervisor():
    """Check that replica info survives a supervisor restart.

    Starts a supervisor, a RESTful API process and a worker, launches a model,
    kills and restarts the supervisor on the same address, then asserts that
    ``describe_model`` reports the same replica count as before the restart.
    """
    from ...deploy.supervisor import run_in_subprocess as supervisor_run_in_subprocess
    from ...deploy.worker import main as _start_worker

    def worker_run_in_subprocess(
        address: str, supervisor_address: str, logging_conf: Optional[Dict] = None
    ) -> multiprocessing.Process:
        p = multiprocessing.Process(
            target=_start_worker,
            args=(address, supervisor_address, None, None, logging_conf),
        )
        p.start()
        return p

    # start supervisor
    web_port, supervisor_port = xo.utils.get_next_port(), xo.utils.get_next_port()
    supervisor_address = f"127.0.0.1:{supervisor_port}"
    proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
    rest_api_proc = multiprocessing.Process(
        target=restful_api.run,
        kwargs=dict(
            supervisor_address=supervisor_address, host="127.0.0.1", port=web_port
        ),
    )
    rest_api_proc.start()

    time.sleep(5)

    # start worker
    proc_worker = worker_run_in_subprocess(
        address=f"127.0.0.1:{xo.utils.get_next_port()}",
        supervisor_address=supervisor_address,
    )

    time.sleep(10)

    client = Client(f"http://127.0.0.1:{web_port}")

    try:
        model_uid = "qwen1.5-chat"
        client.launch_model(
            model_uid=model_uid,
            model_name="qwen1.5-chat",
            model_size_in_billions="0_5",
            quantization="q4_0",
            model_engine="llama.cpp",
        )

        # query replica info
        model_replica_info = client.describe_model(model_uid)
        assert model_replica_info is not None

        # kill supervisor
        proc_supervisor.terminate()
        proc_supervisor.join()

        time.sleep(5)

        # restart supervisor; keep the new handle so cleanup stops the live
        # process instead of the one that was already terminated above
        proc_supervisor = supervisor_run_in_subprocess(supervisor_address)

        time.sleep(5)

        # check replica info
        model_replica_info_check = client.describe_model(model_uid)
        assert model_replica_info["replica"] == model_replica_info_check["replica"]

    finally:
        processes = (rest_api_proc, proc_supervisor, proc_worker)
        try:
            client.abort_cluster()
        finally:
            # Always reap the child processes, even if abort_cluster() raised.
            for proc in processes:
                proc.terminate()
            for proc in processes:
                proc.join()
14 changes: 14 additions & 0 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,18 @@ async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
await self._supervisor_ref.add_worker(self.address)
logger.info("Connected to supervisor as a fresh worker")

# Reconnecting to a newly started supervisor while this worker still has running models
if add_worker and len(self._model_uid_to_model) > 0:
# Reconnect to Newly started supervisor, notify supervisor
await self._supervisor_ref.add_worker(self.address)
# Sync replica model infos
running_models = {}
running_models.update(await self.list_models())
await self._supervisor_ref.sync_models(self.address, running_models)
logger.info(
f"Connected to supervisor as a old worker with {len(running_models)} models"
)

self._status_guard_ref = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.default_uid()
)
Expand Down Expand Up @@ -1049,6 +1061,8 @@ async def _periodical_report_status(self):
except (
Exception
) as ex: # pragma: no cover # noqa: E722 # nosec # pylint: disable=bare-except
# Drop the supervisor reference: the supervisor may have restarted
self._supervisor_ref = None
logger.error(f"Failed to upload node info: {ex}")
try:
await asyncio.sleep(XINFERENCE_HEALTH_CHECK_INTERVAL)
Expand Down
3 changes: 2 additions & 1 deletion xinference/deploy/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@


async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
logging.config.dictConfig(logging_conf) # type: ignore
if logging_conf:
logging.config.dictConfig(logging_conf) # type: ignore

pool = None
try:
Expand Down
Loading