From de5bb24852008288e049e5fb2b3c776de03c6e6c Mon Sep 17 00:00:00 2001 From: Sertac Ozercan Date: Wed, 27 Dec 2023 21:17:17 +0000 Subject: [PATCH] tests: add basic tests Signed-off-by: Sertac Ozercan --- pkg/aikit/config/specs_test.go | 69 ++++++++++++++++++ pkg/aikit2llb/convert_test.go | 40 +++++++++++ pkg/build/build.go | 7 ++ pkg/build/build_test.go | 126 +++++++++++++++++++++++++++++++++ 4 files changed, 242 insertions(+) create mode 100644 pkg/aikit/config/specs_test.go create mode 100644 pkg/aikit2llb/convert_test.go create mode 100644 pkg/build/build_test.go diff --git a/pkg/aikit/config/specs_test.go b/pkg/aikit/config/specs_test.go new file mode 100644 index 00000000..3c81cc7a --- /dev/null +++ b/pkg/aikit/config/specs_test.go @@ -0,0 +1,69 @@ +package config + +import ( + "reflect" + "testing" + + "github.com/sozercan/aikit/pkg/utils" +) + +func TestNewFromBytes(t *testing.T) { + type args struct { + b []byte + } + tests := []struct { + name string + args args + want *Config + wantErr bool + }{ + { + name: "valid yaml", + args: args{b: []byte(` +apiVersion: v1alpha1 +runtime: avx512 +backends: +- exllama +- stablediffusion +models: +- name: test + source: foo +`)}, + want: &Config{ + APIVersion: "v1alpha1", + Runtime: utils.RuntimeCPUAVX512, + Backends: []string{ + utils.BackendExllama, + utils.BackendStableDiffusion, + }, + Models: []Model{ + { + Name: "test", + Source: "foo", + }, + }, + }, + wantErr: false, + }, + { + name: "invalid yaml", + args: args{b: []byte(` +foo +`)}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewFromBytes(tt.args.b) + if (err != nil) != tt.wantErr { + t.Errorf("NewFromBytes() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewFromBytes() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/aikit2llb/convert_test.go b/pkg/aikit2llb/convert_test.go new file mode 100644 index 00000000..0cde7d3a --- /dev/null +++ b/pkg/aikit2llb/convert_test.go @@ -0,0 +1,40 @@ +package aikit2llb + +import ( + "testing" +) + +func Test_fileNameFromURL(t *testing.T) { + type args struct { + urlString string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "simple", + args: args{urlString: "http://foo.bar/baz"}, + want: "baz", + }, + { + name: "complex", + args: args{urlString: "http://foo.bar/baz.tar.gz"}, + want: "baz.tar.gz", + }, + { + name: "complex with path", + args: args{urlString: "http://foo.bar/baz.tar.gz?foo=bar"}, + want: "baz.tar.gz", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := fileNameFromURL(tt.args.urlString); got != tt.want { + t.Errorf("fileNameFromURL() = %v, want %v", got, tt.want) + } + }) + } +} + diff --git a/pkg/build/build.go b/pkg/build/build.go index 79a2b1a4..ee157936 100644 --- a/pkg/build/build.go +++ b/pkg/build/build.go @@ -140,6 +140,13 @@ func validateConfig(c *config.Config) error { return errors.New("exllama only supports nvidia cuda runtime. please add 'runtime: cuda' to your aikitfile.yaml") } + backends := []string{utils.BackendExllama, utils.BackendExllamaV2, utils.BackendStableDiffusion} + for _, b := range c.Backends { + if !slices.Contains(backends, b) { + return errors.Errorf("backend %s is not supported", b) + } + } + runtimes := []string{"", utils.RuntimeNVIDIA, utils.RuntimeCPUAVX, utils.RuntimeCPUAVX2, utils.RuntimeCPUAVX512} if !slices.Contains(runtimes, c.Runtime) { return errors.Errorf("runtime %s is not supported", c.Runtime) diff --git a/pkg/build/build_test.go b/pkg/build/build_test.go new file mode 100644 index 00000000..ec3900ec --- /dev/null +++ b/pkg/build/build_test.go @@ -0,0 +1,126 @@ +package build + +import ( + "testing" + + "github.com/sozercan/aikit/pkg/aikit/config" +) + +func Test_validateConfig(t *testing.T) { + type args struct { + c *config.Config + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "no config", + args: args{c: &config.Config{}}, + wantErr: true, + }, + { + name: "unsupported api version", + args: args{c: &config.Config{ + APIVersion: "v10", + }}, + wantErr: true, + }, + { + name: "invalid runtime", + args: args{c: &config.Config{ + APIVersion: "v1", + Runtime: "foo", + }}, + wantErr: true, + }, + { + name: "no models", + args: args{c: &config.Config{ + APIVersion: "v1alpha1", + }}, + wantErr: true, + }, + { + name: "valid backend", + args: args{c: &config.Config{ + APIVersion: "v1alpha1", + Runtime: "cuda", + Backends: []string{"exllama"}, + Models: []config.Model{ + { + Name: "test", + Source: "foo", + }, + }, + }}, + wantErr: false, + }, + { + name: "invalid backend", + args: args{c: &config.Config{ + APIVersion: "v1alpha1", + Backends: []string{"foo"}, + Models: []config.Model{ + { + Name: "test", + Source: "foo", + }, + }, + }}, + wantErr: true, + }, + { + name: "valid backend but no cuda runtime", + args: args{c: &config.Config{ + APIVersion: "v1alpha1", + Backends: []string{"exllama"}, + Models: []config.Model{ + { + Name: "test", + Source: "foo", + }, + }, + }}, + wantErr: true, + }, + { + name: "invalid backend combination 1", + args: args{c: &config.Config{ + APIVersion: "v1alpha1", + Runtime: "cuda", + Backends: []string{"exllama", "exllama2"}, + Models: []config.Model{ + { + Name: "test", + Source: "foo", + }, + }, + }}, + wantErr: true, + }, + { + name: "invalid backend combination 2", + args: args{c: &config.Config{ + APIVersion: "v1alpha1", + Runtime: "cuda", + Backends: []string{"exllama", "stablediffusion"}, + Models: []config.Model{ + { + Name: "test", + Source: "foo", + }, + }, + }}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := validateConfig(tt.args.c); (err != nil) != tt.wantErr { + t.Errorf("validateConfig() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}