Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: support downloading files from the model repository #787

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api/base/v1alpha1/model_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ type ModelSpec struct {
// ModelScopeRepo defines the modelscope repo which hosts this model
ModelScopeRepo string `json:"modelScopeRepo,omitempty"`

// Revision is required when downloading the model file from modelscope.
0xff-dev marked this conversation as resolved.
Show resolved Hide resolved
// It can be a tag or a branch name.
Revision string `json:"revision,omitempty"`

// MaxContextLength defines the max context length allowed in this model
MaxContextLength int `json:"maxContextLength,omitempty"`
}
Expand Down
4 changes: 4 additions & 0 deletions config/crd/bases/arcadia.kubeagi.k8s.com.cn_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ spec:
description: ModelScopeRepo defines the modelscope repo which hosts
this model
type: string
revision:
description: Revision is required when downloading the model file from
  modelscope. It can be a tag or a branch name.
type: string
source:
description: Source define the source of the model file
properties:
Expand Down
2 changes: 1 addition & 1 deletion controllers/base/model_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (r *ModelReconciler) CheckModel(ctx context.Context, logger logr.Logger, in

// If source is empty, it means that the data is still sourced from the internal minio and a state check is required,
// otherwise we consider the model file for the trans-core service to be ready.
if instance.Spec.Source == nil {
if instance.Spec.Source == nil && (instance.Spec.HuggingFaceRepo == "" && instance.Spec.ModelScopeRepo == "") {
logger.V(5).Info(fmt.Sprintf("model %s source is empty, check minio status.", instance.Name))
system, err := config.GetSystemDatasource(ctx, r.Client)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ spec:
description: ModelScopeRepo defines the modelscope repo which hosts
this model
type: string
revision:
description: Revision is required when downloading the model file from
  modelscope. It can be a tag or a branch name.
type: string
source:
description: Source define the source of the model file
properties:
Expand Down
2 changes: 1 addition & 1 deletion deploy/llms/Dockerfile.fastchat-worker
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ RUN export DEBIAN_FRONTEND=noninteractive \
ARG PYTHON_INDEX_URL="https://pypi.mirrors.ustc.edu.cn/simple/"

# Install fastchat along with its dependencies
RUN apt-get install -y python3.9 python3.9-distutils curl python3-pip python3-dev
RUN apt-get install -y python3.9 python3.9-distutils curl python3-pip python3-dev gcc
RUN python3.9 -m pip install tomli setuptools_scm wavedrom transformers==4.37.0 -i ${PYTHON_INDEX_URL}
RUN python3.9 -m pip install --upgrade pip -i ${PYTHON_INDEX_URL}
RUN git clone https://github.com/lm-sys/FastChat.git \
Expand Down
5 changes: 5 additions & 0 deletions pkg/worker/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
Expand Down Expand Up @@ -78,6 +79,10 @@ func (loader *LoaderOSS) Build(ctx context.Context, model *arcadiav1alpha1.Typed
Object: fmt.Sprintf("model/%s/", model.Name),
})
if err != nil {
if err == datasource.ErrOSSNoSuchObject {
klog.Info("No object was found, So it could pull the model file from other places.")
return nil, nil
}
return nil, err
}

Expand Down
65 changes: 57 additions & 8 deletions pkg/worker/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strconv"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

Expand All @@ -31,8 +32,9 @@ import (
)

const (
	// Default worker images; the tag tracks the fastchat release version.
	defaultFastChatImage     = "kubeagi/arcadia-fastchat-worker:v0.2.36"
	defaultFastchatVLLMImage = "kubeagi/arcadia-fastchat-worker:vllm-v0.2.36"
)

// ModelRunner run a model service
Expand All @@ -49,12 +51,15 @@ var _ ModelRunner = (*RunnerFastchat)(nil)
// RunnerFastchat builds the container spec that runs a model with the
// plain fastchat worker.
type RunnerFastchat struct {
	c client.Client            // cluster client, used to fetch the Model when the file is remote
	w *arcadiav1alpha1.Worker  // the Worker resource this runner serves

	// modelFileFromRemote is true when the model file is pulled from a
	// remote repository (HuggingFace/ModelScope) instead of the internal OSS.
	modelFileFromRemote bool
}

func NewRunnerFastchat(c client.Client, w *arcadiav1alpha1.Worker) (ModelRunner, error) {
func NewRunnerFastchat(c client.Client, w *arcadiav1alpha1.Worker, modelFileFromRemote bool) (ModelRunner, error) {
return &RunnerFastchat{
c: c,
w: w,
c: c,
w: w,
modelFileFromRemote: modelFileFromRemote,
}, nil
}

Expand All @@ -77,6 +82,25 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
return nil, fmt.Errorf("failed to get arcadia config with %w", err)
}

modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
additionalEnvs := []corev1.EnvVar{}
extraArgs := fmt.Sprintf("--device %s", runner.Device().String())
if runner.modelFileFromRemote {
m := arcadiav1alpha1.Model{}
if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
return nil, err
}
if m.Spec.HuggingFaceRepo != "" {
modelFileDir = m.Spec.HuggingFaceRepo
}
if m.Spec.ModelScopeRepo != "" {
modelFileDir = m.Spec.ModelScopeRepo
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_USE_MODELSCOPE", Value: "True"})
extraArgs += fmt.Sprintf(" --revision %s ", m.Spec.Revision)
}
}

additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_MODEL_NAME_PATH", Value: modelFileDir})
img := defaultFastChatImage
if runner.w.Spec.Runner.Image != "" {
img = runner.w.Spec.Runner.Image
Expand All @@ -94,7 +118,7 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
{Name: "FASTCHAT_WORKER_ADDRESS", Value: fmt.Sprintf("http://%s.%s:21002", runner.w.Name+WokerCommonSuffix, runner.w.Namespace)},
{Name: "FASTCHAT_CONTROLLER_ADDRESS", Value: gw.Controller},
{Name: "NUMBER_GPUS", Value: runner.NumberOfGPUs()},
{Name: "EXTRA_ARGS", Value: fmt.Sprintf("--device %s", runner.Device().String())},
{Name: "EXTRA_ARGS", Value: extraArgs},
},
Ports: []corev1.ContainerPort{
{Name: "http", ContainerPort: 21002},
Expand All @@ -105,6 +129,7 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
Resources: runner.w.Spec.Resources,
}

container.Env = append(container.Env, additionalEnvs...)
return container, nil
}

Expand All @@ -114,12 +139,16 @@ var _ ModelRunner = (*RunnerFastchatVLLM)(nil)
// RunnerFastchatVLLM builds the container spec that runs a model with the
// fastchat vLLM worker.
type RunnerFastchatVLLM struct {
	c client.Client            // cluster client, used to fetch the Model when the file is remote
	w *arcadiav1alpha1.Worker  // the Worker resource this runner serves

	// modelFileFromRemote is true when the model file is pulled from a
	// remote repository (HuggingFace/ModelScope) instead of the internal OSS.
	modelFileFromRemote bool
}

func NewRunnerFastchatVLLM(c client.Client, w *arcadiav1alpha1.Worker) (ModelRunner, error) {
func NewRunnerFastchatVLLM(c client.Client, w *arcadiav1alpha1.Worker, modelFileFromRemote bool) (ModelRunner, error) {
return &RunnerFastchatVLLM{
c: c,
w: w,

modelFileFromRemote: modelFileFromRemote,
}, nil
}

Expand Down Expand Up @@ -175,6 +204,25 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
klog.Infof("run worker with %s GPU", runner.NumberOfGPUs())
}

modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
additionalEnvs := []corev1.EnvVar{}
extraAgrs := "--trust-remote-code"
if runner.modelFileFromRemote {
m := arcadiav1alpha1.Model{}
if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
return nil, err
}
if m.Spec.HuggingFaceRepo != "" {
modelFileDir = m.Spec.HuggingFaceRepo
}
if m.Spec.ModelScopeRepo != "" {
modelFileDir = m.Spec.ModelScopeRepo
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_USE_MODELSCOPE", Value: "True"}, corev1.EnvVar{Name: "VLLM_USE_MODELSCOPE", Value: "True"})
extraAgrs += fmt.Sprintf(" --revision %s", m.Spec.Revision)
}
}

additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_MODEL_NAME_PATH", Value: modelFileDir})
img := defaultFastchatVLLMImage
if runner.w.Spec.Runner.Image != "" {
img = runner.w.Spec.Runner.Image
Expand All @@ -190,7 +238,7 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
{Name: "FASTCHAT_MODEL_NAME", Value: model.Name},
{Name: "FASTCHAT_WORKER_ADDRESS", Value: fmt.Sprintf("http://%s.%s:21002", runner.w.Name+WokerCommonSuffix, runner.w.Namespace)},
{Name: "FASTCHAT_CONTROLLER_ADDRESS", Value: gw.Controller},
{Name: "EXTRA_ARGS", Value: "--trust-remote-code"},
{Name: "EXTRA_ARGS", Value: extraAgrs},
// Need python version and ray address for distributed inference
{Name: "PYTHON_VERSION", Value: pythonVersion},
{Name: "RAY_ADDRESS", Value: rayClusterAddress},
Expand All @@ -203,6 +251,7 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
},
Resources: runner.w.Spec.Resources,
}
container.Env = append(container.Env, additionalEnvs...)

return container, nil
}
45 changes: 23 additions & 22 deletions pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,23 +229,6 @@ func NewPodWorker(ctx context.Context, c client.Client, s *runtime.Scheme, w *ar
}

// init runner
switch w.Type() {
case arcadiav1alpha1.WorkerTypeFastchatVLLM:
r, err := NewRunnerFastchatVLLM(c, w.DeepCopy())
if err != nil {
return nil, fmt.Errorf("failed to new a runner with %w", err)
}
podWorker.r = r
case arcadiav1alpha1.WorkerTypeFastchatNormal:
r, err := NewRunnerFastchat(c, w.DeepCopy())
if err != nil {
return nil, fmt.Errorf("failed to new a runner with %w", err)
}
podWorker.r = r
default:
return nil, fmt.Errorf("worker %s with type %s not supported in worker", w.Name, w.Type())
}

return podWorker, nil
}

Expand Down Expand Up @@ -391,7 +374,22 @@ func (podWorker *PodWorker) Start(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to build loader with %w", err)
}
conLoader, _ := loader.(*corev1.Container)
switch podWorker.w.Type() {
case arcadiav1alpha1.WorkerTypeFastchatVLLM:
r, err := NewRunnerFastchatVLLM(podWorker.c, podWorker.w.DeepCopy(), loader == nil)
if err != nil {
return fmt.Errorf("failed to new a runner with %w", err)
}
podWorker.r = r
case arcadiav1alpha1.WorkerTypeFastchatNormal:
r, err := NewRunnerFastchat(podWorker.c, podWorker.w.DeepCopy(), loader == nil)
if err != nil {
return fmt.Errorf("failed to new a runner with %w", err)
}
podWorker.r = r
default:
return fmt.Errorf("worker %s with type %s not supported in worker", podWorker.w.Name, podWorker.w.Type())
}

// define the way to run model
runner, err := podWorker.r.Build(ctx, &arcadiav1alpha1.TypedObjectReference{Namespace: &podWorker.m.Namespace, Name: podWorker.m.Name})
Expand Down Expand Up @@ -422,12 +420,15 @@ func (podWorker *PodWorker) Start(ctx context.Context) error {
},
},
Spec: corev1.PodSpec{
RestartPolicy: corev1.RestartPolicyAlways,
InitContainers: []corev1.Container{*conLoader},
Containers: []corev1.Container{*conRunner},
Volumes: []corev1.Volume{podWorker.storage},
RestartPolicy: corev1.RestartPolicyAlways,
Containers: []corev1.Container{*conRunner},
Volumes: []corev1.Volume{podWorker.storage},
},
}
if loader != nil {
conLoader, _ := loader.(*corev1.Container)
podSpecTemplate.Spec.InitContainers = []corev1.Container{*conLoader}
}
if podWorker.storage.HostPath != nil {
podSpecTemplate.Spec.Affinity = &corev1.Affinity{
NodeAffinity: &corev1.NodeAffinity{
Expand Down
Loading