Skip to content

Commit

Permalink
fix tests
Browse files: browse the repository at this point in the history
  • Loading branch information
qinxuye committed Dec 12, 2024
1 parent 1ffb392 commit e2f184e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 39 deletions.
83 changes: 45 additions & 38 deletions xinference/core/tests/test_restart_supervisor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import multiprocessing
import time
from typing import Dict, Optional

import pytest
import xoscar as xo

from ...core.supervisor import SupervisorActor
from ...api import restful_api
from ...client import Client


# test restart supervisor
@pytest.mark.asyncio
async def test_restart_supervisor():
def test_restart_supervisor():
from ...deploy.supervisor import run_in_subprocess as supervisor_run_in_subprocess
from ...deploy.worker import main as _start_worker

Expand All @@ -39,50 +37,59 @@ def worker_run_in_subprocess(
return p

# start supervisor
supervisor_address = f"localhost:{xo.utils.get_next_port()}"
web_port, supervisor_port = xo.utils.get_next_port(), xo.utils.get_next_port()
supervisor_address = f"127.0.0.1:{supervisor_port}"
proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
rest_api_proc = multiprocessing.Process(
target=restful_api.run,
kwargs=dict(
supervisor_address=supervisor_address, host="127.0.0.1", port=web_port
),
)
rest_api_proc.start()

await asyncio.sleep(5)
time.sleep(5)

# start worker
worker_run_in_subprocess(
address=f"localhost:{xo.utils.get_next_port()}",
proc_worker = worker_run_in_subprocess(
address=f"127.0.0.1:{xo.utils.get_next_port()}",
supervisor_address=supervisor_address,
)

await asyncio.sleep(10)

# load model
supervisor_ref = await xo.actor_ref(
supervisor_address, SupervisorActor.default_uid()
)
time.sleep(10)

model_uid = "qwen1.5-chat"
await supervisor_ref.launch_builtin_model(
model_uid=model_uid,
model_name="qwen1.5-chat",
model_size_in_billions="0_5",
quantization="q4_0",
model_engine="llama.cpp",
)
client = Client(f"http://127.0.0.1:{web_port}")

# query replica info
model_replica_info = await supervisor_ref.describe_model(model_uid)
try:
model_uid = "qwen1.5-chat"
client.launch_model(
model_uid=model_uid,
model_name="qwen1.5-chat",
model_size_in_billions="0_5",
quantization="q4_0",
model_engine="llama.cpp",
)

# kill supervisor
proc_supervisor.terminate()
proc_supervisor.join()
# query replica info
model_replica_info = client.describe_model(model_uid)
assert model_replica_info is not None

# restart supervisor
proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
# kill supervisor
proc_supervisor.terminate()
proc_supervisor.join()

await asyncio.sleep(5)
# restart supervisor
supervisor_run_in_subprocess(supervisor_address)

supervisor_ref = await xo.actor_ref(
supervisor_address, SupervisorActor.default_uid()
)
time.sleep(5)

# check replica info
model_replic_info_check = await supervisor_ref.describe_model(model_uid)
# check replica info
model_replic_info_check = client.describe_model(model_uid)
assert model_replica_info["replica"] == model_replic_info_check["replica"]

assert model_replica_info["replica"] == model_replic_info_check["replica"]
finally:
client.abort_cluster()
proc_supervisor.terminate()
proc_worker.terminate()
proc_supervisor.join()
proc_worker.join()
3 changes: 2 additions & 1 deletion xinference/deploy/supervisor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,7 +33,8 @@


async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
logging.config.dictConfig(logging_conf) # type: ignore
if logging_conf:
logging.config.dictConfig(logging_conf) # type: ignore

pool = None
try:
Expand Down

0 comments on commit e2f184e

Please sign in to comment.