Skip to content

Commit

Permalink
chore: support downloading files from the model repository
Browse files Browse the repository at this point in the history
  • Loading branch information
0xff-dev committed Mar 4, 2024
1 parent a7df03c commit c24392b
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 32 deletions.
4 changes: 4 additions & 0 deletions api/base/v1alpha1/model_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ type ModelSpec struct {
// ModelScopeRepo defines the modelscope repo which hosts this model
ModelScopeRepo string `json:"modelScopeRepo,omitempty"`

// Revision it's required if download model file from modelscope
// It can be a tag, branch name.
Revision string `json:"revision,omitempty"`

// MaxContextLength defines the max context length allowed in this model
MaxContextLength int `json:"maxContextLength,omitempty"`
}
Expand Down
4 changes: 4 additions & 0 deletions config/crd/bases/arcadia.kubeagi.k8s.com.cn_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ spec:
description: ModelScopeRepo defines the modelscope repo which hosts
this model
type: string
revision:
description: Revision it's required if download model file from modelscope
It can be a tag, branch name.
type: string
source:
description: Source define the source of the model file
properties:
Expand Down
2 changes: 1 addition & 1 deletion controllers/base/model_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (r *ModelReconciler) CheckModel(ctx context.Context, logger logr.Logger, in

// If source is empty, it means that the data is still sourced from the internal minio and a state check is required,
// otherwise we consider the model file for the trans-core service to be ready.
if instance.Spec.Source == nil {
if instance.Spec.Source == nil && (instance.Spec.HuggingFaceRepo == "" && instance.Spec.ModelScopeRepo == "") {
logger.V(5).Info(fmt.Sprintf("model %s source is empty, check minio status.", instance.Name))
system, err := config.GetSystemDatasource(ctx, r.Client)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ spec:
description: ModelScopeRepo defines the modelscope repo which hosts
this model
type: string
revision:
description: Revision it's required if download model file from modelscope
It can be a tag, branch name.
type: string
source:
description: Source define the source of the model file
properties:
Expand Down
2 changes: 1 addition & 1 deletion deploy/llms/Dockerfile.fastchat-worker
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ RUN export DEBIAN_FRONTEND=noninteractive \
ARG PYTHON_INDEX_URL="https://pypi.mirrors.ustc.edu.cn/simple/"

# Install fastchat along with its dependencies
RUN apt-get install -y python3.9 python3.9-distutils curl python3-pip python3-dev
RUN apt-get install -y python3.9 python3.9-distutils curl python3-pip python3-dev gcc
RUN python3.9 -m pip install tomli setuptools_scm wavedrom transformers==4.37.0 -i ${PYTHON_INDEX_URL}
RUN python3.9 -m pip install --upgrade pip -i ${PYTHON_INDEX_URL}
RUN git clone https://github.com/lm-sys/FastChat.git \
Expand Down
5 changes: 5 additions & 0 deletions pkg/worker/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
Expand Down Expand Up @@ -78,6 +79,10 @@ func (loader *LoaderOSS) Build(ctx context.Context, model *arcadiav1alpha1.Typed
Object: fmt.Sprintf("model/%s/", model.Name),
})
if err != nil {
if err == datasource.ErrOSSNoSuchObject {
klog.Info("No object was found, So it could pull the model file from other places.")
return nil, nil
}
return nil, err
}

Expand Down
65 changes: 57 additions & 8 deletions pkg/worker/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strconv"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

Expand All @@ -31,8 +32,9 @@ import (
)

const (
defaultFastChatImage = "kubeagi/arcadia-fastchat-worker:v0.2.0"
defaultFastchatVLLMImage = "kubeagi/arcadia-fastchat-worker:vllm-v0.2.0"
// tag is the same version as fastchat
defaultFastChatImage = "kubeagi/arcadia-fastchat-worker:v0.2.36"
defaultFastchatVLLMImage = "kubeagi/arcadia-fastchat-worker:vllm-v0.2.36"
)

// ModelRunner run a model service
Expand All @@ -49,12 +51,15 @@ var _ ModelRunner = (*RunnerFastchat)(nil)
type RunnerFastchat struct {
c client.Client
w *arcadiav1alpha1.Worker

modelFileFromRemote bool
}

func NewRunnerFastchat(c client.Client, w *arcadiav1alpha1.Worker) (ModelRunner, error) {
func NewRunnerFastchat(c client.Client, w *arcadiav1alpha1.Worker, modelFileFromRemote bool) (ModelRunner, error) {
return &RunnerFastchat{
c: c,
w: w,
c: c,
w: w,
modelFileFromRemote: modelFileFromRemote,
}, nil
}

Expand All @@ -77,6 +82,25 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
return nil, fmt.Errorf("failed to get arcadia config with %w", err)
}

modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
additionalEnvs := []corev1.EnvVar{}
extraArgs := fmt.Sprintf("--device %s", runner.Device().String())
if runner.modelFileFromRemote {
m := arcadiav1alpha1.Model{}
if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
return nil, err
}
if m.Spec.HuggingFaceRepo != "" {
modelFileDir = m.Spec.HuggingFaceRepo
}
if m.Spec.ModelScopeRepo != "" {
modelFileDir = m.Spec.ModelScopeRepo
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_USE_MODELSCOPE", Value: "True"})
extraArgs += fmt.Sprintf(" --revision %s ", m.Spec.Revision)
}
}

additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_MODEL_NAME_PATH", Value: modelFileDir})
img := defaultFastChatImage
if runner.w.Spec.Runner.Image != "" {
img = runner.w.Spec.Runner.Image
Expand All @@ -94,7 +118,7 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
{Name: "FASTCHAT_WORKER_ADDRESS", Value: fmt.Sprintf("http://%s.%s:21002", runner.w.Name+WokerCommonSuffix, runner.w.Namespace)},
{Name: "FASTCHAT_CONTROLLER_ADDRESS", Value: gw.Controller},
{Name: "NUMBER_GPUS", Value: runner.NumberOfGPUs()},
{Name: "EXTRA_ARGS", Value: fmt.Sprintf("--device %s", runner.Device().String())},
{Name: "EXTRA_ARGS", Value: extraArgs},
},
Ports: []corev1.ContainerPort{
{Name: "http", ContainerPort: 21002},
Expand All @@ -105,6 +129,7 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
Resources: runner.w.Spec.Resources,
}

container.Env = append(container.Env, additionalEnvs...)
return container, nil
}

Expand All @@ -114,12 +139,16 @@ var _ ModelRunner = (*RunnerFastchatVLLM)(nil)
type RunnerFastchatVLLM struct {
c client.Client
w *arcadiav1alpha1.Worker

modelFileFromRemote bool
}

func NewRunnerFastchatVLLM(c client.Client, w *arcadiav1alpha1.Worker) (ModelRunner, error) {
func NewRunnerFastchatVLLM(c client.Client, w *arcadiav1alpha1.Worker, modelFileFromRemote bool) (ModelRunner, error) {
return &RunnerFastchatVLLM{
c: c,
w: w,

modelFileFromRemote: modelFileFromRemote,
}, nil
}

Expand Down Expand Up @@ -175,6 +204,25 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
klog.Infof("run worker with %s GPU", runner.NumberOfGPUs())
}

modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
additionalEnvs := []corev1.EnvVar{}
extraAgrs := "--trust-remote-code"
if runner.modelFileFromRemote {
m := arcadiav1alpha1.Model{}
if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
return nil, err
}
if m.Spec.HuggingFaceRepo != "" {
modelFileDir = m.Spec.HuggingFaceRepo
}
if m.Spec.ModelScopeRepo != "" {
modelFileDir = m.Spec.ModelScopeRepo
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_USE_MODELSCOPE", Value: "True"}, corev1.EnvVar{Name: "VLLM_USE_MODELSCOPE", Value: "True"})
extraAgrs += fmt.Sprintf(" --revision %s", m.Spec.Revision)
}
}

additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "FASTCHAT_MODEL_NAME_PATH", Value: modelFileDir})
img := defaultFastchatVLLMImage
if runner.w.Spec.Runner.Image != "" {
img = runner.w.Spec.Runner.Image
Expand All @@ -190,7 +238,7 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
{Name: "FASTCHAT_MODEL_NAME", Value: model.Name},
{Name: "FASTCHAT_WORKER_ADDRESS", Value: fmt.Sprintf("http://%s.%s:21002", runner.w.Name+WokerCommonSuffix, runner.w.Namespace)},
{Name: "FASTCHAT_CONTROLLER_ADDRESS", Value: gw.Controller},
{Name: "EXTRA_ARGS", Value: "--trust-remote-code"},
{Name: "EXTRA_ARGS", Value: extraAgrs},
// Need python version and ray address for distributed inference
{Name: "PYTHON_VERSION", Value: pythonVersion},
{Name: "RAY_ADDRESS", Value: rayClusterAddress},
Expand All @@ -203,6 +251,7 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
},
Resources: runner.w.Spec.Resources,
}
container.Env = append(container.Env, additionalEnvs...)

return container, nil
}
45 changes: 23 additions & 22 deletions pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,23 +229,6 @@ func NewPodWorker(ctx context.Context, c client.Client, s *runtime.Scheme, w *ar
}

// init runner
switch w.Type() {
case arcadiav1alpha1.WorkerTypeFastchatVLLM:
r, err := NewRunnerFastchatVLLM(c, w.DeepCopy())
if err != nil {
return nil, fmt.Errorf("failed to new a runner with %w", err)
}
podWorker.r = r
case arcadiav1alpha1.WorkerTypeFastchatNormal:
r, err := NewRunnerFastchat(c, w.DeepCopy())
if err != nil {
return nil, fmt.Errorf("failed to new a runner with %w", err)
}
podWorker.r = r
default:
return nil, fmt.Errorf("worker %s with type %s not supported in worker", w.Name, w.Type())
}

return podWorker, nil
}

Expand Down Expand Up @@ -391,7 +374,22 @@ func (podWorker *PodWorker) Start(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to build loader with %w", err)
}
conLoader, _ := loader.(*corev1.Container)
switch podWorker.w.Type() {
case arcadiav1alpha1.WorkerTypeFastchatVLLM:
r, err := NewRunnerFastchatVLLM(podWorker.c, podWorker.w.DeepCopy(), loader == nil)
if err != nil {
return fmt.Errorf("failed to new a runner with %w", err)
}
podWorker.r = r
case arcadiav1alpha1.WorkerTypeFastchatNormal:
r, err := NewRunnerFastchat(podWorker.c, podWorker.w.DeepCopy(), loader == nil)
if err != nil {
return fmt.Errorf("failed to new a runner with %w", err)
}
podWorker.r = r
default:
return fmt.Errorf("worker %s with type %s not supported in worker", podWorker.w.Name, podWorker.w.Type())
}

// define the way to run model
runner, err := podWorker.r.Build(ctx, &arcadiav1alpha1.TypedObjectReference{Namespace: &podWorker.m.Namespace, Name: podWorker.m.Name})
Expand Down Expand Up @@ -422,12 +420,15 @@ func (podWorker *PodWorker) Start(ctx context.Context) error {
},
},
Spec: corev1.PodSpec{
RestartPolicy: corev1.RestartPolicyAlways,
InitContainers: []corev1.Container{*conLoader},
Containers: []corev1.Container{*conRunner},
Volumes: []corev1.Volume{podWorker.storage},
RestartPolicy: corev1.RestartPolicyAlways,
Containers: []corev1.Container{*conRunner},
Volumes: []corev1.Volume{podWorker.storage},
},
}
if loader != nil {
conLoader, _ := loader.(*corev1.Container)
podSpecTemplate.Spec.InitContainers = []corev1.Container{*conLoader}
}
if podWorker.storage.HostPath != nil {
podSpecTemplate.Spec.Affinity = &corev1.Affinity{
NodeAffinity: &corev1.NodeAffinity{
Expand Down

0 comments on commit c24392b

Please sign in to comment.