diff --git a/xinference/core/tests/test_restart_supervisor.py b/xinference/core/tests/test_restart_supervisor.py index 7bb4b87249..58f8ba1b99 100644 --- a/xinference/core/tests/test_restart_supervisor.py +++ b/xinference/core/tests/test_restart_supervisor.py @@ -12,19 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio +import time import multiprocessing from typing import Dict, Optional -import pytest import xoscar as xo +from ...client import Client +from ...api import restful_api from ...core.supervisor import SupervisorActor -# test restart supervisor -@pytest.mark.asyncio -async def test_restart_supervisor(): +def test_restart_supervisor(): from ...deploy.supervisor import run_in_subprocess as supervisor_run_in_subprocess from ...deploy.worker import main as _start_worker @@ -39,50 +38,62 @@ def worker_run_in_subprocess( return p # start supervisor - supervisor_address = f"localhost:{xo.utils.get_next_port()}" + web_port, supervisor_port = xo.utils.get_next_port(), xo.utils.get_next_port() + supervisor_address = f"127.0.0.1:{supervisor_port}" proc_supervisor = supervisor_run_in_subprocess(supervisor_address) + rest_api_proc = multiprocessing.Process( + target=restful_api.run, + kwargs=dict( + supervisor_address=supervisor_address, + host="127.0.0.1", + port=web_port + ) + ) + rest_api_proc.start() - await asyncio.sleep(5) + time.sleep(5) # start worker - worker_run_in_subprocess( - address=f"localhost:{xo.utils.get_next_port()}", + proc_worker = worker_run_in_subprocess( + address=f"127.0.0.1:{xo.utils.get_next_port()}", supervisor_address=supervisor_address, ) - await asyncio.sleep(10) + time.sleep(10) - # load model - supervisor_ref = await xo.actor_ref( - supervisor_address, SupervisorActor.default_uid() - ) + client = Client(f"http://127.0.0.1:{web_port}") - model_uid = "qwen1.5-chat" - await supervisor_ref.launch_builtin_model( - model_uid=model_uid, - model_name="qwen1.5-chat", - model_size_in_billions="0_5", - quantization="q4_0", - model_engine="llama.cpp", - ) + try: + model_uid = "qwen1.5-chat" + client.launch_model( + model_uid=model_uid, + model_name="qwen1.5-chat", + model_size_in_billions="0_5", + quantization="q4_0", + model_engine="llama.cpp", + ) - # query replica info - model_replica_info = await supervisor_ref.describe_model(model_uid) + # query replica info + model_replica_info = client.describe_model(model_uid) + assert model_replica_info is not None - # kill supervisor - proc_supervisor.terminate() - proc_supervisor.join() + # kill supervisor + proc_supervisor.terminate() + proc_supervisor.join() - # restart supervisor - proc_supervisor = supervisor_run_in_subprocess(supervisor_address) + # restart supervisor + supervisor_run_in_subprocess(supervisor_address) - await asyncio.sleep(5) + time.sleep(5) - supervisor_ref = await xo.actor_ref( - supervisor_address, SupervisorActor.default_uid() - ) + # check replica info + model_replic_info_check = client.describe_model(model_uid) + assert model_replica_info["replica"] == model_replic_info_check["replica"] - # check replica info - model_replic_info_check = await supervisor_ref.describe_model(model_uid) + finally: + client.abort_cluster() + proc_supervisor.terminate() + proc_worker.terminate() + proc_supervisor.join() + proc_worker.join() - assert model_replica_info["replica"] == model_replic_info_check["replica"] diff --git a/xinference/deploy/supervisor.py b/xinference/deploy/supervisor.py index ed12a9f7c2..41a4381b54 100644 --- a/xinference/deploy/supervisor.py +++ b/xinference/deploy/supervisor.py @@ -33,7 +33,8 @@ async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None): - logging.config.dictConfig(logging_conf) # type: ignore + if logging_conf: + logging.config.dictConfig(logging_conf) # type: ignore pool = None try: