Persistent Service #4564

Draft · wants to merge 7 commits into base: master
167 changes: 154 additions & 13 deletions sky/provision/kubernetes/instance.py
@@ -34,13 +34,13 @@
TAG_POD_INITIALIZED = 'skypilot-initialized'


def _is_head(pod) -> bool:
return pod.metadata.labels.get(constants.TAG_RAY_NODE_KIND) == 'head'


def _get_head_pod_name(pods: Dict[str, Any]) -> Optional[str]:
head_pod_name = None
for pod_name, pod in pods.items():
if pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head':
head_pod_name = pod_name
break
return head_pod_name
return next((pod_name for pod_name, pod in pods.items() if _is_head(pod)),
None)


def head_service_selector(cluster_name: str) -> Dict[str, str]:
@@ -650,6 +650,130 @@ def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict,
raise e


def _is_serve_controller(cluster_name_on_cloud: str) -> bool:
return cluster_name_on_cloud.startswith('sky-serve-controller-')


def _create_persistent_volume_claim(namespace: str, context: Optional[str],
pvc_name: str) -> None:
"""Creates a persistent volume claim for SkyServe controller."""
try:
kubernetes.core_api(context).read_namespaced_persistent_volume_claim(
name=pvc_name, namespace=namespace)
return
except kubernetes.api_exception() as e:
if e.status != 404: # Not found
raise

pvc_spec = {
'apiVersion': 'v1',
'kind': 'PersistentVolumeClaim',
'metadata': {
'name': pvc_name,
},
'spec': {
'accessModes': ['ReadWriteOnce'],
'resources': {
'requests': {
'storage': '10Gi' # TODO(andyl): use a constant here
}
}
}
}

kubernetes.core_api(context).create_namespaced_persistent_volume_claim(
namespace=namespace, body=pvc_spec)


def _create_serve_controller_deployment(
pod_spec: Dict[str, Any], cluster_name_on_cloud: str, namespace: str,
context: Optional[str]) -> Dict[str, Any]:
"""Creates a deployment for SkyServe controller with persistence."""
pvc_name = f'{cluster_name_on_cloud}-data'
_create_persistent_volume_claim(namespace, context, pvc_name)

# The reason we mount the whole /home/sky/.sky instead of just
# /home/sky/.sky/serve is that k8s changes the ownership of the
# mounted directory to root:root. If we only mount /home/sky/.sky/serve,
# the serve controller will not be able to create the serve directory.
# pylint: disable=line-too-long
# See https://stackoverflow.com/questions/50818029/mounted-folder-created-as-root-instead-of-current-user-in-docker/50820023#50820023.
mount_path = '/home/sky/.sky' # TODO(andyl): use a constant here
volume_mounts = [{'name': 'serve-data', 'mountPath': mount_path}]

volumes = [{
'name': 'serve-data',
'persistentVolumeClaim': {
'claimName': pvc_name
}
}]

if 'volumes' in pod_spec['spec']:
pod_spec['spec']['volumes'].extend(volumes)
else:
pod_spec['spec']['volumes'] = volumes

for container in pod_spec['spec']['containers']:
if 'volumeMounts' in container:
container['volumeMounts'].extend(volume_mounts)
else:
container['volumeMounts'] = volume_mounts

template_metadata = pod_spec.pop('metadata')

deployment_labels = {
'app': cluster_name_on_cloud,
}
template_metadata['labels'].update(deployment_labels)

# pod_spec['spec'] becomes the deployment's spec.template.spec, and the
# metadata popped above becomes spec.template.metadata.

deployment_spec = {
'apiVersion': 'apps/v1',
'kind': 'Deployment',
'metadata': {
'name': f'{cluster_name_on_cloud}-deployment',
'namespace': namespace,
},
'spec': {
'replicas': 1,
'selector': {
'matchLabels': deployment_labels
},
'template': {
'metadata': template_metadata,
'spec': {
**pod_spec['spec'], 'restartPolicy': 'Always'
}
}
}
}

return deployment_spec


@timeline.event
def _wait_for_deployment_pod(context, namespace, deployment, timeout=60):
"""Waits until the pods of a deployment are created and returns them."""
label_selector = ','.join([
f'{key}={value}'
for key, value in deployment.spec.selector.match_labels.items()
])
target_replicas = deployment.spec.replicas
start_time = time.time()
while time.time() - start_time < timeout:
pods = kubernetes.core_api(context).list_namespaced_pod(
namespace, label_selector=label_selector).items
# TODO(andyl): not sure if this check is necessary
if len(pods) == target_replicas:
return pods
time.sleep(2)

raise TimeoutError(
f'Timed out waiting for all pods of deployment '
f'{deployment.metadata.name!r} to be created.')


@timeline.event
def _create_pods(region: str, cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionRecord:
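Taken together, the new helpers above wrap the SkyServe controller pod in a single-replica Deployment backed by a PersistentVolumeClaim mounted at the controller's state directory, so a restarted pod re-attaches to the same data. Below is a minimal, self-contained sketch of the same pattern against the official Kubernetes Python client; the names, namespace, image, and storage size are illustrative placeholders, not the PR's actual values or constants.

# Illustrative sketch only: create a PVC and a one-replica Deployment that
# mounts it, mirroring the helpers above. All names here are placeholders.
from kubernetes import client, config

config.load_kube_config()  # use load_incluster_config() when run in-cluster

namespace = 'default'
pvc = {
    'apiVersion': 'v1',
    'kind': 'PersistentVolumeClaim',
    'metadata': {'name': 'demo-controller-data'},
    'spec': {
        'accessModes': ['ReadWriteOnce'],
        'resources': {'requests': {'storage': '1Gi'}},
    },
}
client.CoreV1Api().create_namespaced_persistent_volume_claim(
    namespace=namespace, body=pvc)

labels = {'app': 'demo-controller'}
deployment = {
    'apiVersion': 'apps/v1',
    'kind': 'Deployment',
    'metadata': {'name': 'demo-controller', 'namespace': namespace},
    'spec': {
        'replicas': 1,
        'selector': {'matchLabels': labels},
        'template': {
            'metadata': {'labels': labels},
            'spec': {
                'restartPolicy': 'Always',
                'containers': [{
                    'name': 'controller',
                    'image': 'python:3.10-slim',
                    'command': ['sleep', 'infinity'],
                    'volumeMounts': [{
                        'name': 'controller-data',
                        'mountPath': '/home/sky/.sky',
                    }],
                }],
                'volumes': [{
                    'name': 'controller-data',
                    'persistentVolumeClaim': {
                        'claimName': 'demo-controller-data'
                    },
                }],
            },
        },
    },
}
client.AppsV1Api().create_namespaced_deployment(
    namespace=namespace, body=deployment)

If the pod is evicted or its node dies, the Deployment recreates it and the claim re-attaches, which appears to be the motivation for running the controller as a Deployment rather than a bare pod.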
@@ -661,6 +785,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
tags = {
TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
}

pod_spec['metadata']['namespace'] = namespace
if 'labels' in pod_spec['metadata']:
pod_spec['metadata']['labels'].update(tags)
@@ -748,7 +873,8 @@ def _create_pod_thread(i: int):
pod_spec_copy['metadata']['labels'].update(constants.HEAD_NODE_TAGS)
head_selector = head_service_selector(cluster_name_on_cloud)
pod_spec_copy['metadata']['labels'].update(head_selector)
pod_spec_copy['metadata']['name'] = f'{cluster_name_on_cloud}-head'
pod_spec_copy['metadata'][
'name'] = f'{cluster_name_on_cloud}-head'
else:
# Worker pods
pod_spec_copy['metadata']['labels'].update(
@@ -799,18 +925,36 @@ def _create_pod_thread(i: int):
}
pod_spec_copy['spec']['tolerations'] = [tpu_toleration]

if _is_serve_controller(cluster_name_on_cloud):
deployment_spec = _create_serve_controller_deployment(
pod_spec_copy, cluster_name_on_cloud, namespace, context)
try:
return kubernetes.apps_api(
context).create_namespaced_deployment(
namespace, deployment_spec)
except Exception as e:  # pylint: disable=broad-except
logger.warning('Failed to create deployment for '
f'{cluster_name_on_cloud!r}: {e}')
raise

return _create_namespaced_pod_with_retries(namespace, pod_spec_copy,
context)

# Create pods in parallel
pods = subprocess_utils.run_in_parallel(_create_pod_thread,
range(to_start_count), _NUM_THREADS)

if _is_serve_controller(cluster_name_on_cloud):
deployments = copy.deepcopy(pods)
# The parallel creation above returned Deployment objects rather than
# pods; clear the list and repopulate it with the pods each deployment
# actually created.
pods.clear()
for deployment in deployments:
pods.extend(_wait_for_deployment_pod(context, namespace,
deployment))

# Process created pods
for pod in pods:
created_pods[pod.metadata.name] = pod
if head_pod_name is None and pod.metadata.labels.get(
constants.TAG_RAY_NODE_KIND) == 'head':
if head_pod_name is None and _is_head(pod):
head_pod_name = pod.metadata.name

networking_mode = network_utils.get_networking_mode(
@@ -964,9 +1108,6 @@ def terminate_instances(
logger.warning('terminate_instances: Error occurred when analyzing '
f'SSH Jump pod: {e}')

def _is_head(pod) -> bool:
return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head'

def _terminate_pod_thread(pod_info):
pod_name, pod = pod_info
if _is_head(pod) and worker_only:
@@ -1022,7 +1163,7 @@ def get_cluster_info(
tags=pod.metadata.labels,
)
]
if pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head':
if _is_head(pod):
head_pod_name = pod_name
head_spec = pod.spec
assert head_spec is not None, pod
22 changes: 22 additions & 0 deletions sky/serve/serve_state.py
@@ -548,3 +548,25 @@ def delete_all_versions(service_name: str) -> None:
"""\
DELETE FROM version_specs
WHERE service_name=(?)""", (service_name,))


def get_service_controller_port(service_name: str) -> int:
"""Gets the controller port of a service."""
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute('SELECT controller_port FROM services WHERE name = ?',
(service_name,))
row = cursor.fetchone()
if row is None:
raise ValueError(f'Service {service_name} does not exist.')
return row[0]


def get_service_load_balancer_port(service_name: str) -> int:
"""Gets the load balancer port of a service."""
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute('SELECT load_balancer_port FROM services WHERE name = ?',
(service_name,))
row = cursor.fetchone()
if row is None:
raise ValueError(f'Service {service_name} does not exist.')
return row[0]
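Both getters assume the services table already persists controller_port and load_balancer_port columns (written via set_service_controller_port / set_service_load_balancer_port elsewhere in serve_state). A minimal standalone sqlite3 sketch of the same lookup pattern, with a made-up schema purely for illustration:

# Standalone illustration of the port-lookup pattern; the schema below is a
# stand-in, not SkyServe's actual services table.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE services ('
             'name TEXT PRIMARY KEY, '
             'controller_port INTEGER, '
             'load_balancer_port INTEGER)')
conn.execute('INSERT INTO services VALUES (?, ?, ?)',
             ('demo-service', 20001, 30001))


def get_controller_port(name: str) -> int:
    row = conn.execute(
        'SELECT controller_port FROM services WHERE name = ?',
        (name,)).fetchone()
    if row is None:
        raise ValueError(f'Service {name} does not exist.')
    return row[0]


print(get_controller_port('demo-service'))  # -> 20001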
86 changes: 57 additions & 29 deletions sky/serve/service.py
@@ -130,6 +130,19 @@ def cleanup_version_storage(version: int) -> bool:
return failed


def is_recovery_mode(service_name: str) -> bool:
"""Check if service exists in database to determine recovery mode.

Args:
service_name: Name of the service to check

Returns:
True if service exists in database, indicating recovery mode
"""
service = serve_state.get_service_from_name(service_name)
return service is not None


def _start(service_name: str, tmp_task_yaml: str, job_id: int):
"""Starts the service."""
# Generate ssh key pair to avoid race condition when multiple sky.launch
@@ -141,27 +154,32 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
# Already checked before submit to controller.
assert task.service is not None, task
service_spec = task.service
if len(serve_state.get_services()) >= serve_utils.NUM_SERVICE_THRESHOLD:
cleanup_storage(tmp_task_yaml)
with ux_utils.print_exception_no_traceback():
raise RuntimeError('Max number of services reached.')
success = serve_state.add_service(
service_name,
controller_job_id=job_id,
policy=service_spec.autoscaling_policy_str(),
requested_resources_str=backend_utils.get_task_resources_str(task),
load_balancing_policy=service_spec.load_balancing_policy,
status=serve_state.ServiceStatus.CONTROLLER_INIT)
# Directly throw an error here. See sky/serve/api.py::up
# for more details.
if not success:
cleanup_storage(tmp_task_yaml)
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Service {service_name} already exists.')

# Add initial version information to the service state.
serve_state.add_or_update_version(service_name, constants.INITIAL_VERSION,
service_spec)

is_recovery = is_recovery_mode(service_name)

if not is_recovery:
if len(serve_state.get_services()) >= serve_utils.NUM_SERVICE_THRESHOLD:
cleanup_storage(tmp_task_yaml)
with ux_utils.print_exception_no_traceback():
raise RuntimeError('Max number of services reached.')
success = serve_state.add_service(
service_name,
controller_job_id=job_id,
policy=service_spec.autoscaling_policy_str(),
requested_resources_str=backend_utils.get_task_resources_str(task),
load_balancing_policy=service_spec.load_balancing_policy,
status=serve_state.ServiceStatus.CONTROLLER_INIT)
# Directly throw an error here. See sky/serve/api.py::up
# for more details.
if not success:
cleanup_storage(tmp_task_yaml)
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Service {service_name} already exists.')

# Add initial version information to the service state.
serve_state.add_or_update_version(service_name,
constants.INITIAL_VERSION,
service_spec)

# Create the service working directory.
service_dir = os.path.expanduser(
@@ -187,8 +205,15 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
try:
with filelock.FileLock(
os.path.expanduser(constants.PORT_SELECTION_FILE_LOCK_PATH)):
controller_port = common_utils.find_free_port(
constants.CONTROLLER_PORT_START)
if is_recovery:
# In recovery mode, use the ports from the database
controller_port = serve_state.get_service_controller_port(
service_name)
load_balancer_port = serve_state.get_service_load_balancer_port(
service_name)
else:
controller_port = common_utils.find_free_port(
constants.CONTROLLER_PORT_START)

# We expose the controller to the public network when running
# inside a kubernetes cluster to allow external load balancers
@@ -211,14 +236,16 @@ def _get_host():
args=(service_name, service_spec, task_yaml, controller_host,
controller_port))
controller_process.start()
serve_state.set_service_controller_port(service_name,
controller_port)
if not is_recovery:
serve_state.set_service_controller_port(service_name,
controller_port)

# TODO(tian): Support HTTPS.
controller_addr = f'http://{controller_host}:{controller_port}'

load_balancer_port = common_utils.find_free_port(
constants.LOAD_BALANCER_PORT_START)
if not is_recovery:
load_balancer_port = common_utils.find_free_port(
constants.LOAD_BALANCER_PORT_START)

# Extract the load balancing policy from the service spec
policy_name = service_spec.load_balancing_policy
@@ -233,8 +260,9 @@ def _get_host():
load_balancer_log_file).run,
args=(controller_addr, load_balancer_port, policy_name))
load_balancer_process.start()
serve_state.set_service_load_balancer_port(service_name,
load_balancer_port)
if not is_recovery:
serve_state.set_service_load_balancer_port(
service_name, load_balancer_port)

while True:
_handle_signal(service_name)
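The net effect of the recovery branch: when the controller process comes back up and finds its service row already in the database, it reuses the previously recorded ports instead of allocating fresh ones, so a Deployment-restarted controller stays reachable at the same addresses. A rough standalone sketch of that decision follows, with stand-in helpers rather than SkyServe's serve_state and common_utils:

# Rough sketch of recovery-aware port selection; all helpers are stand-ins.
import socket
from typing import Dict, Optional, Tuple

# Stand-in for the persistent services table: name -> (controller, lb) ports.
_PORTS: Dict[str, Tuple[int, int]] = {}


def _find_free_port() -> int:
    """Asks the OS for any free TCP port (stand-in for find_free_port)."""
    with socket.socket() as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def pick_ports(service_name: str) -> Tuple[int, int]:
    stored: Optional[Tuple[int, int]] = _PORTS.get(service_name)
    if stored is not None:
        # Recovery: the controller was restarted (e.g. by the Deployment),
        # so reuse the ports recorded before the restart.
        return stored
    ports = (_find_free_port(), _find_free_port())
    _PORTS[service_name] = ports  # persist for later recoveries
    return ports


first = pick_ports('demo-service')
assert pick_ports('demo-service') == first  # recovery reuses the same ports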