Skip to content

Commit

Permalink
fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
qinxuye committed Dec 12, 2024
1 parent 1ffb392 commit d3217f1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 37 deletions.
83 changes: 47 additions & 36 deletions xinference/core/tests/test_restart_supervisor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,19 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import time
import multiprocessing
from typing import Dict, Optional

import pytest
import xoscar as xo

from ...client import Client
from ...api import restful_api
from ...core.supervisor import SupervisorActor


# test restart supervisor
@pytest.mark.asyncio
async def test_restart_supervisor():
def test_restart_supervisor():
from ...deploy.supervisor import run_in_subprocess as supervisor_run_in_subprocess
from ...deploy.worker import main as _start_worker

Expand All @@ -39,50 +38,62 @@ def worker_run_in_subprocess(
return p

# start supervisor
supervisor_address = f"localhost:{xo.utils.get_next_port()}"
web_port, supervisor_port = xo.utils.get_next_port(), xo.utils.get_next_port()
supervisor_address = f"127.0.0.1:{supervisor_port}"
proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
rest_api_proc = multiprocessing.Process(
target=restful_api.run,
kwargs=dict(
supervisor_address=supervisor_address,
host="127.0.0.1",
port=web_port
)
)
rest_api_proc.start()

await asyncio.sleep(5)
time.sleep(5)

# start worker
worker_run_in_subprocess(
address=f"localhost:{xo.utils.get_next_port()}",
proc_worker = worker_run_in_subprocess(
address=f"127.0.0.1:{xo.utils.get_next_port()}",
supervisor_address=supervisor_address,
)

await asyncio.sleep(10)
time.sleep(10)

# load model
supervisor_ref = await xo.actor_ref(
supervisor_address, SupervisorActor.default_uid()
)
client = Client(f"http://127.0.0.1:{web_port}")

model_uid = "qwen1.5-chat"
await supervisor_ref.launch_builtin_model(
model_uid=model_uid,
model_name="qwen1.5-chat",
model_size_in_billions="0_5",
quantization="q4_0",
model_engine="llama.cpp",
)
try:
model_uid = "qwen1.5-chat"
client.launch_model(
model_uid=model_uid,
model_name="qwen1.5-chat",
model_size_in_billions="0_5",
quantization="q4_0",
model_engine="llama.cpp",
)

# query replica info
model_replica_info = await supervisor_ref.describe_model(model_uid)
# query replica info
model_replica_info = client.describe_model(model_uid)
assert model_replica_info is not None

# kill supervisor
proc_supervisor.terminate()
proc_supervisor.join()
# kill supervisor
proc_supervisor.terminate()
proc_supervisor.join()

# restart supervisor
proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
# restart supervisor
supervisor_run_in_subprocess(supervisor_address)

await asyncio.sleep(5)
time.sleep(5)

supervisor_ref = await xo.actor_ref(
supervisor_address, SupervisorActor.default_uid()
)
# check replica info
model_replic_info_check = client.describe_model(model_uid)
assert model_replica_info["replica"] == model_replic_info_check["replica"]

# check replica info
model_replic_info_check = await supervisor_ref.describe_model(model_uid)
finally:
client.abort_cluster()
proc_supervisor.terminate()
proc_worker.terminate()
proc_supervisor.join()
proc_worker.join()

assert model_replica_info["replica"] == model_replic_info_check["replica"]
3 changes: 2 additions & 1 deletion xinference/deploy/supervisor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,7 +33,8 @@


async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
logging.config.dictConfig(logging_conf) # type: ignore
if logging_conf:
logging.config.dictConfig(logging_conf) # type: ignore

pool = None
try:
Expand Down

0 comments on commit d3217f1

Please sign in to comment.