diff --git a/go/plugins/dotprompt/genkit.go b/go/plugins/dotprompt/genkit.go index ce31434d4..f139b9df6 100644 --- a/go/plugins/dotprompt/genkit.go +++ b/go/plugins/dotprompt/genkit.go @@ -218,7 +218,7 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex return nil, errors.New("dotprompt model not in provider/name format") } - model := ai.LookupModel(provider, name) + model = ai.LookupModel(provider, name) if model == nil { return nil, fmt.Errorf("no model named %q for provider %q", name, provider) } diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go index 71bb3bb22..a724d28d8 100644 --- a/go/plugins/dotprompt/genkit_test.go +++ b/go/plugins/dotprompt/genkit_test.go @@ -43,14 +43,31 @@ func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context. func TestExecute(t *testing.T) { testModel := ai.DefineModel("test", "test", nil, testGenerate) - p, err := New("TestExecute", "TestExecute", Config{Model: testModel}) - if err != nil { - t.Fatal(err) - } - resp, err := p.Generate(context.Background(), &PromptRequest{}, nil) - if err != nil { - t.Fatal(err) - } + t.Run("Model", func(t *testing.T) { + p, err := New("TestExecute", "TestExecute", Config{Model: testModel}) + if err != nil { + t.Fatal(err) + } + resp, err := p.Generate(context.Background(), &PromptRequest{}, nil) + if err != nil { + t.Fatal(err) + } + assertResponse(t, resp) + }) + t.Run("ModelName", func(t *testing.T) { + p, err := New("TestExecute", "TestExecute", Config{ModelName: "test/test"}) + if err != nil { + t.Fatal(err) + } + resp, err := p.Generate(context.Background(), &PromptRequest{}, nil) + if err != nil { + t.Fatal(err) + } + assertResponse(t, resp) + }) +} + +func assertResponse(t *testing.T, resp *ai.GenerateResponse) { if len(resp.Candidates) != 1 { t.Errorf("got %d candidates, want 1", len(resp.Candidates)) if len(resp.Candidates) < 1 {