Skip to content

Commit

Permalink
add scale workers tab
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyda Cinarel committed Oct 15, 2020
1 parent 18da1be commit f99c036
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
16 changes: 16 additions & 0 deletions torchserve_dashboard/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import streamlit as st


def raise_on_not200(response):
if response.status_code != 200:
st.write("There was an error!")
Expand Down Expand Up @@ -98,3 +99,18 @@ def change_model_default(M_API, model_name, version):
req_url += "/set-default"
res = client.put(req_url)
return res.json()


def change_model_workers(M_API, model_name, version=None, min_worker=None, max_worker=None, number_gpu=None):
req_url = M_API + "/models/" + model_name
if version:
req_url += "/" + version
req_url += "?synchronous=false"
if min_worker:
req_url += "&min_worker=" + str(min_worker)
if max_worker:
req_url += "&max_worker=" + str(max_worker)
if number_gpu:
req_url += "&number_gpu=" + str(number_gpu)
res = client.put(req_url)
return res.json()
32 changes: 31 additions & 1 deletion torchserve_dashboard/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def get_model_store():
if proceed and model_name != default_key and version != default_key:
res = tsa.delete_model(M_API, model_name, version)
last_res()[0] = res
proceed=False
rerun()

with st.beta_expander(label="Get model details", expanded=False):
Expand All @@ -164,3 +163,34 @@ def get_model_store():
elif version != default_key:
res = tsa.get_model(M_API, model_name, version)
st.write(res)

with st.beta_expander(label="Scale workers", expanded=False):
st.markdown("# Scale workers [(docs)](https://pytorch.org/serve/management_api.html#scale-workers)")
model_name = st.selectbox("Pick model", [default_key] + loaded_models_names, index=0)
if model_name != default_key:
default_version = tsa.get_model(M_API,model_name)[0]["modelVersion"]
st.write(f"default version {default_version}")
versions = tsa.get_model(M_API,model_name, list_all=False)
versions = [m["modelVersion"] for m in versions]
version = st.selectbox("Choose version", ["All"] + versions, index=0)

col1, col2, col3 = st.beta_columns(3)
min_worker = col1.number_input(label="min_worker(optional)", value=-1, min_value=-1, step=1)
max_worker = col2.number_input(label="max_worker(optional)", value=-1, min_value=-1, step=1)
number_gpu = col3.number_input(label="number_gpu(optional)", value=-1, min_value=-1, step=1)
proceed = st.button("Apply")
if proceed and model_name != default_key:
# number_input can't be set to None
if version == "All":
version=None
if min_worker == -1:
min_worker=None
if max_worker == -1:
max_worker=None
if number_gpu == -1:
number_gpu=None

res = tsa.change_model_workers(M_API, model_name, version=version, min_worker=min_worker, max_worker=max_worker, number_gpu=number_gpu)
last_res()[0] = res
rerun()

0 comments on commit f99c036

Please sign in to comment.