diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml
index 994de94484..951999c712 100644
--- a/.github/workflows/python.yaml
+++ b/.github/workflows/python.yaml
@@ -200,6 +200,7 @@ jobs:
           ${{ env.SELF_HOST_PYTHON }} -m pip install -U soundfile
           ${{ env.SELF_HOST_PYTHON }} -m pip install -U sentence-transformers
           ${{ env.SELF_HOST_PYTHON }} -m pip install -U FlagEmbedding
+          ${{ env.SELF_HOST_PYTHON }} -m pip install "llama-cpp-python>=0.2.82" -i https://abetlen.github.io/llama-cpp-python/whl/cu124
           ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
             --disable-warnings \
             --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_continuous_batching.py && \
@@ -232,7 +233,10 @@ jobs:
             --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_fish_speech.py && \
           ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
             -W ignore::PendingDeprecationWarning \
-            --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/embedding/tests/test_integrated_embedding.py
+            --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/embedding/tests/test_integrated_embedding.py && \
+          ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
+            -W ignore::PendingDeprecationWarning \
+            --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_restart_supervisor.py
         elif [ "$MODULE" == "metal" ]; then
           pytest --timeout=1500 \
             -W ignore::PendingDeprecationWarning \
@@ -246,6 +250,6 @@ jobs:
             --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/client/tests/test_client.py
           pytest --timeout=1500 \
             -W ignore::PendingDeprecationWarning \
-            --cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/core/tests/test_continuous_batching.py --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/image/tests/test_got_ocr2.py --ignore xinference/model/audio/tests --ignore xinference/model/embedding/tests/test_integrated_embedding.py xinference
+            --cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/core/tests/test_restart_supervisor.py --ignore xinference/core/tests/test_continuous_batching.py --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/image/tests/test_got_ocr2.py --ignore xinference/model/audio/tests --ignore xinference/model/embedding/tests/test_integrated_embedding.py xinference
         fi
       working-directory: .
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 61d19c86ae..895ea70dc6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 files: xinference
 repos:
-  - repo: https://github.com/psf/black
+  - repo: https://github.com/psf/black-pre-commit-mirror
     rev: 23.12.0
     hooks:
       - id: black
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index c8f2f59ff6..1ab08a54fc 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -1178,6 +1178,57 @@ async def list_models(self) -> Dict[str, Dict[str, Any]]:
                 v["replica"] = self._model_uid_to_replica_info[k].replica
         return running_model_info
 
+    # Receive model info reported by a worker
+    @log_async(logger=logger)
+    async def sync_models(
+        self, worker_address: str, model_desc: Dict[str, Dict[str, Any]]
+    ):  # replica_model_uid -> model description dict containing "address"
+        for replica_model_uid, desc_dict in model_desc.items():
+            # Rebuild self._replica_model_uid_to_worker
+            if replica_model_uid in self._replica_model_uid_to_worker:
+                continue
+
+            model_name = desc_dict["model_name"] if "model_name" in desc_dict else ""
+            model_version = (
+                desc_dict["model_version"] if "model_version" in desc_dict else ""
+            )
+            logger.debug(
+                f"Received model replica: {replica_model_uid} {worker_address} {model_name}"
+            )
+
+            assert (
+                worker_address in self._worker_address_to_worker
+            ), f"Worker {worker_address} does not exist when calling sync_models"
+
+            self._replica_model_uid_to_worker[
+                replica_model_uid
+            ] = self._worker_address_to_worker[worker_address]
+
+            # Rebuild self._model_uid_to_replica_info
+            model_uid, rep_id = parse_replica_model_uid(replica_model_uid)
+            replica = rep_id + 1
+            if model_uid not in self._model_uid_to_replica_info:
+                self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
+                    replica=replica, scheduler=itertools.cycle(range(replica))
+                )
+            else:
+                if replica > self._model_uid_to_replica_info[model_uid].replica:
+                    self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
+                        replica=replica, scheduler=itertools.cycle(range(replica))
+                    )
+
+            # Rebuild the instance info tracked by self._status_guard_ref
+            instance_info = InstanceInfo(
+                model_name=model_name,
+                model_uid=model_uid,
+                model_version=model_version,
+                model_ability=[],
+                replica=replica,
+                status=LaunchStatus.READY.name,
+                instance_created_ts=int(time.time()),
+            )
+            await self._status_guard_ref.set_instance_info(model_uid, instance_info)
+
     def is_local_deployment(self) -> bool:
         # TODO: temporary.
         return (
diff --git a/xinference/core/tests/test_restart_supervisor.py b/xinference/core/tests/test_restart_supervisor.py
new file mode 100644
index 0000000000..249dcfdbf4
--- /dev/null
+++ b/xinference/core/tests/test_restart_supervisor.py
@@ -0,0 +1,97 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing
+import time
+from typing import Dict, Optional
+
+import xoscar as xo
+
+from ...api import restful_api
+from ...client import Client
+
+
+def test_restart_supervisor():
+    from ...deploy.supervisor import run_in_subprocess as supervisor_run_in_subprocess
+    from ...deploy.worker import main as _start_worker
+
+    def worker_run_in_subprocess(
+        address: str, supervisor_address: str, logging_conf: Optional[Dict] = None
+    ) -> multiprocessing.Process:
+        p = multiprocessing.Process(
+            target=_start_worker,
+            args=(address, supervisor_address, None, None, logging_conf),
+        )
+        p.start()
+        return p
+
+    # start supervisor
+    web_port, supervisor_port = xo.utils.get_next_port(), xo.utils.get_next_port()
+    supervisor_address = f"127.0.0.1:{supervisor_port}"
+    proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
+    rest_api_proc = multiprocessing.Process(
+        target=restful_api.run,
+        kwargs=dict(
+            supervisor_address=supervisor_address, host="127.0.0.1", port=web_port
+        ),
+    )
+    rest_api_proc.start()
+
+    time.sleep(5)
+
+    # start worker
+    proc_worker = worker_run_in_subprocess(
+        address=f"127.0.0.1:{xo.utils.get_next_port()}",
+        supervisor_address=supervisor_address,
+    )
+
+    time.sleep(10)
+
+    client = Client(f"http://127.0.0.1:{web_port}")
+
+    try:
+        model_uid = "qwen1.5-chat"
+        client.launch_model(
+            model_uid=model_uid,
+            model_name="qwen1.5-chat",
+            model_size_in_billions="0_5",
+            quantization="q4_0",
+            model_engine="llama.cpp",
+        )
+
+        # query replica info
+        model_replica_info = client.describe_model(model_uid)
+        assert model_replica_info is not None
+
+        # kill supervisor
+        proc_supervisor.terminate()
+        proc_supervisor.join()
+
+        time.sleep(5)
+
+        # restart supervisor
+        supervisor_run_in_subprocess(supervisor_address)
+
+        time.sleep(5)
+
+        # check replica info
+        model_replica_info_check = client.describe_model(model_uid)
+        assert model_replica_info["replica"] == model_replica_info_check["replica"]
+
+    finally:
+        client.abort_cluster()
+        proc_supervisor.terminate()
+        proc_worker.terminate()
+        proc_supervisor.join()
+        proc_worker.join()
diff --git a/xinference/core/worker.py b/xinference/core/worker.py
index 2a380bdf41..368b1238ca 100644
--- a/xinference/core/worker.py
+++ b/xinference/core/worker.py
@@ -336,6 +336,18 @@ async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
             await self._supervisor_ref.add_worker(self.address)
             logger.info("Connected to supervisor as a fresh worker")
 
+        # Reconnect to a newly started supervisor when models are still running
+        if add_worker and len(self._model_uid_to_model) > 0:
+            # Notify the restarted supervisor about this worker
+            await self._supervisor_ref.add_worker(self.address)
+            # Sync replica model infos
+            running_models = {}
+            running_models.update(await self.list_models())
+            await self._supervisor_ref.sync_models(self.address, running_models)
+            logger.info(
+                f"Connected to supervisor as an old worker with {len(running_models)} models"
+            )
+
         self._status_guard_ref = await xo.actor_ref(
             address=self._supervisor_address, uid=StatusGuardActor.default_uid()
         )
@@ -1049,6 +1061,8 @@ async def _periodical_report_status(self):
             except (
                 Exception
             ) as ex:  # pragma: no cover # noqa: E722 # nosec # pylint: disable=bare-except
+                # Drop the supervisor ref; the supervisor may have restarted
+                self._supervisor_ref = None
                 logger.error(f"Failed to upload node info: {ex}")
             try:
                 await asyncio.sleep(XINFERENCE_HEALTH_CHECK_INTERVAL)
diff --git a/xinference/deploy/supervisor.py b/xinference/deploy/supervisor.py
index ed12a9f7c2..41a4381b54 100644
--- a/xinference/deploy/supervisor.py
+++ b/xinference/deploy/supervisor.py
@@ -33,7 +33,8 @@
 
 
 async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
-    logging.config.dictConfig(logging_conf)  # type: ignore
+    if logging_conf:
+        logging.config.dictConfig(logging_conf)  # type: ignore
 
     pool = None
     try:
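
For illustration only, not part of the patch above: a minimal, self-contained Python sketch of the replica bookkeeping that sync_models rebuilds after a supervisor restart. The class and function names here (ReplicaInfoSketch, rebuild_replica_info) are hypothetical stand-ins; the patch itself uses xinference's ReplicaInfo and parse_replica_model_uid. The idea is the same: infer each model's replica count as the largest reported rep_id plus one, and recreate a fresh round-robin scheduler with itertools.cycle.

# sketch_replica_scheduler.py -- illustrative only
import itertools
from typing import Dict, Iterator, List, Tuple


class ReplicaInfoSketch:
    """Hypothetical stand-in for ReplicaInfo: a replica count plus a
    round-robin scheduler over replica ids 0..replica-1."""

    def __init__(self, replica: int):
        self.replica = replica
        self.scheduler: Iterator[int] = itertools.cycle(range(replica))


def rebuild_replica_info(
    reported: List[Tuple[str, int]]
) -> Dict[str, ReplicaInfoSketch]:
    # Given (model_uid, rep_id) pairs recovered from workers, keep the largest
    # rep_id + 1 as the replica count and recreate the cycle-based scheduler,
    # mirroring what sync_models does per replica model uid.
    info: Dict[str, ReplicaInfoSketch] = {}
    for model_uid, rep_id in reported:
        replica = rep_id + 1
        if model_uid not in info or replica > info[model_uid].replica:
            info[model_uid] = ReplicaInfoSketch(replica)
    return info


if __name__ == "__main__":
    rebuilt = rebuild_replica_info([("qwen1.5-chat", 0), ("qwen1.5-chat", 1)])
    sched = rebuilt["qwen1.5-chat"].scheduler
    print([next(sched) for _ in range(4)])  # -> [0, 1, 0, 1]

The scheduler is an in-memory iterator, which is presumably why the patch recreates it from the replica count reported by workers rather than trying to transfer any supervisor state.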