From bb597f72cc9ddfac05519b0037d139a1f43a67ab Mon Sep 17 00:00:00 2001 From: Nicolas Bonamy Date: Sun, 3 Nov 2024 12:10:48 -0600 Subject: [PATCH] multi llm upgrade --- defaults/settings.json | 2 +- package-lock.json | 8 +-- package.json | 2 +- src/automations/commander.ts | 2 +- src/llms/llm-worker.ts | 2 +- src/llms/llm.ts | 98 ++++++++++++++++++++++---------- src/llms/openai.ts | 7 ++- src/plugins/plugin.ts | 3 +- src/screens/DocRepoCreate.vue | 2 +- src/screens/PromptAnywhere.vue | 4 +- src/screens/ScratchPad.vue | 2 +- src/services/assistant-worker.ts | 12 ++-- src/services/assistant.ts | 23 ++++++-- src/settings/SettingsGoogle.vue | 2 +- src/types/config.d.ts | 13 +++-- src/types/index.d.ts | 4 +- tests/mocks/llm.ts | 10 ++-- tests/unit/embedder.test.ts | 12 ++-- tests/unit/llm1.test.ts | 3 +- tests/unit/llm2.test.ts | 50 ++++++++-------- tests/unit/store.test.ts | 16 +++--- 21 files changed, 167 insertions(+), 110 deletions(-) diff --git a/defaults/settings.json b/defaults/settings.json index da41d285..e81c840e 100644 --- a/defaults/settings.json +++ b/defaults/settings.json @@ -18,7 +18,7 @@ "retryGeneration": true } }, -"llm": { + "llm": { "engine": "openai", "autoVisionSwitch": true, "conversationLength": 20 diff --git a/package-lock.json b/package-lock.json index 1dc21827..f58785ff 100644 --- a/package-lock.json +++ b/package-lock.json @@ -37,7 +37,7 @@ "markdown-it-mark": "^4.0.0", "minimatch": "^10.0.1", "mitt": "^3.0.1", - "multi-llm-ts": "^1.3.4", + "multi-llm-ts": "^2.0.0-beta.1", "nestor-client": "^0.3.1", "officeparser": "^4.1.1", "ollama": "^0.5.9", @@ -15776,9 +15776,9 @@ "license": "MIT" }, "node_modules/multi-llm-ts": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/multi-llm-ts/-/multi-llm-ts-1.3.4.tgz", - "integrity": "sha512-fWHo2YEaNTxJ2baYUTiPkFhfVdxjjjAeQEd41MeSrGD2GQOyvB995QqbgfxAhzQlt4LF9zvF8pLDsCwAG2ovdw==", + "version": "2.0.0-beta.1", + "resolved": "https://registry.npmjs.org/multi-llm-ts/-/multi-llm-ts-2.0.0-beta.1.tgz", + "integrity": "sha512-Q+6nc6Gr3AXFRElJy/xHTqh1hhDFXI8XbFIet1xlPHnsS0a/5JdURDAvY5CkGm5acT/sXSm3mHNMgPOIXQrj+w==", "dependencies": { "@anthropic-ai/sdk": "^0.30.1", "@google/generative-ai": "^0.21.0", diff --git a/package.json b/package.json index 739fe453..9a451eb6 100644 --- a/package.json +++ b/package.json @@ -98,7 +98,7 @@ "markdown-it-mark": "^4.0.0", "minimatch": "^10.0.1", "mitt": "^3.0.1", - "multi-llm-ts": "^1.3.4", + "multi-llm-ts": "^2.0.0-beta.1", "nestor-client": "^0.3.1", "officeparser": "^4.1.1", "ollama": "^0.5.9", diff --git a/src/automations/commander.ts b/src/automations/commander.ts index bbdb7147..60542765 100644 --- a/src/automations/commander.ts +++ b/src/automations/commander.ts @@ -204,7 +204,7 @@ export default class Commander { ] // now get it - return llm.complete(messages, { model: model }) + return llm.complete(model, messages) } diff --git a/src/llms/llm-worker.ts b/src/llms/llm-worker.ts index 64d58c23..3cb53b3c 100644 --- a/src/llms/llm-worker.ts +++ b/src/llms/llm-worker.ts @@ -18,7 +18,7 @@ const initEngine = (engine: string, config: Configuration) => { const stream = async (messages: any[], opts: any) => { try { - const stream = await llm.generate(messages, opts) + const stream = await llm.generate(opts.model, messages, opts) for await (const msg of stream) { worker.postMessage(msg) } diff --git a/src/llms/llm.ts b/src/llms/llm.ts index f96deceb..51dd62f6 100644 --- a/src/llms/llm.ts +++ b/src/llms/llm.ts @@ -1,6 +1,6 @@ -import { Configuration } from 'types/config' -import { Anthropic, 
Ollama, MistralAI, Google, Groq, XAI, Cerebras, LlmEngine , loadAnthropicModels, loadCerebrasModels, loadGoogleModels, loadGroqModels, loadMistralAIModels, loadOllamaModels, loadOpenAIModels, loadXAIModels, hasVisionModels as _hasVisionModels, isVisionModel as _isVisionModel } from 'multi-llm-ts' +import { Configuration, EngineConfig } from 'types/config.d' +import { Anthropic, Ollama, MistralAI, Google, Groq, XAI, Cerebras, LlmEngine, loadAnthropicModels, loadCerebrasModels, loadGoogleModels, loadGroqModels, loadMistralAIModels, loadOllamaModels, loadOpenAIModels, loadXAIModels, hasVisionModels as _hasVisionModels, isVisionModel as _isVisionModel, ModelsList, Model } from 'multi-llm-ts' import { isSpecializedModel as isSpecialAnthropicModel, getFallbackModel as getAnthropicFallbackModel , getComputerInfo } from './anthropic' import { imageFormats, textFormats } from '../models/attachment' import { store } from '../services/store' @@ -60,18 +60,20 @@ export default class LlmFactory { } isEngineReady = (engine: string): boolean => { - if (engine === 'anthropic') return Anthropic.isReady(this.config.engines.anthropic) - if (engine === 'cerebras') return Cerebras.isReady(this.config.engines.cerebras) - if (engine === 'google') return Google.isReady(this.config.engines.google) - if (engine === 'groq') return Groq.isReady(this.config.engines.groq) - if (engine === 'mistralai') return MistralAI.isReady(this.config.engines.mistralai) - if (engine === 'ollama') return Ollama.isReady(this.config.engines.ollama) - if (engine === 'openai') return OpenAI.isReady(this.config.engines.openai) - if (engine === 'xai') return XAI.isReady(this.config.engines.xai) + if (engine === 'anthropic') return Anthropic.isReady(this.config.engines.anthropic, this.config.engines.anthropic?.models) + if (engine === 'cerebras') return Cerebras.isReady(this.config.engines.cerebras, this.config.engines.cerebras?.models) + if (engine === 'google') return Google.isReady(this.config.engines.google, this.config.engines.google?.models) + if (engine === 'groq') return Groq.isReady(this.config.engines.groq, this.config.engines.groq?.models) + if (engine === 'mistralai') return MistralAI.isReady(this.config.engines.mistralai, this.config.engines.mistralai?.models) + if (engine === 'ollama') return Ollama.isReady(this.config.engines.ollama, this.config.engines.ollama?.models) + if (engine === 'openai') return OpenAI.isReady(this.config.engines.openai, this.config.engines.openai?.models) + if (engine === 'xai') return XAI.isReady(this.config.engines.xai, this.config.engines.xai?.models) return false } - igniteEngine = (engine: string, fallback = 'openai'): LlmEngine => { + igniteEngine = (engine: string): LlmEngine => { + + // select if (engine === 'anthropic') return new Anthropic(this.config.engines.anthropic, getComputerInfo()) if (engine === 'cerebras') return new Cerebras(this.config.engines.cerebras) if (engine === 'google') return new Google(this.config.engines.google) @@ -80,15 +82,15 @@ export default class LlmFactory { if (engine === 'ollama') return new Ollama(this.config.engines.ollama) if (engine === 'openai') return new OpenAI(this.config.engines.openai) if (engine === 'xai') return new XAI(this.config.engines.xai) - if (this.isEngineReady(fallback)) { - console.warn(`Engine ${engine} unknown. Falling back to ${fallback}`) - return this.igniteEngine(fallback, this.config.engines[fallback]) - } - return null + + // fallback + console.warn(`Engine ${engine} unknown. 
Falling back to OpenAI`) + return new OpenAI(this.config.engines.openai) + } hasChatModels = (engine: string) => { - return this.config.engines[engine].models.chat.length > 0 + return this.config.engines[engine].models?.chat?.length > 0 } hasVisionModels = (engine: string) => { @@ -123,32 +125,70 @@ export default class LlmFactory { loadModels = async (engine: string): Promise => { console.log('Loading models for', engine) - let rc = false + let models: ModelsList|null = null if (engine === 'openai') { - rc = await loadOpenAIModels(this.config.engines.openai) + models = await loadOpenAIModels(this.config.engines.openai) } else if (engine === 'ollama') { - rc = await loadOllamaModels(this.config.engines.ollama) + models = await loadOllamaModels(this.config.engines.ollama) } else if (engine === 'mistralai') { - rc = await loadMistralAIModels(this.config.engines.mistralai) + models = await loadMistralAIModels(this.config.engines.mistralai) } else if (engine === 'anthropic') { - rc = await loadAnthropicModels(this.config.engines.anthropic, getComputerInfo()) + models = await loadAnthropicModels(this.config.engines.anthropic, getComputerInfo()) } else if (engine === 'google') { - rc = await loadGoogleModels(this.config.engines.google) + models = await loadGoogleModels(this.config.engines.google) } else if (engine === 'groq') { - rc = await loadGroqModels(this.config.engines.groq) + models = await loadGroqModels(this.config.engines.groq) } else if (engine === 'cerebras') { - rc = await loadCerebrasModels(this.config.engines.cerebras) + models = await loadCerebrasModels(this.config.engines.cerebras) } else if (engine === 'xai') { - rc = await loadXAIModels(this.config.engines.xai) + models = await loadXAIModels(this.config.engines.xai) } - + + // needed + const engineConfig = store.config.engines[engine] + + // check + if (typeof models !== 'object') { + engineConfig.models = { chat: [], image: [] } + return false + } + + // openai names are not great + if (engine === 'openai') { + models.chat = models.chat.map(m => { + let name = m.name + name = name.replace(/^gpt-([^-]*)(-?)([a-z]?)/i, (_, l1, __, l3) => `GPT-${l1} ${l3?.toUpperCase()}`) + name = name.replace('Mini', 'mini') + return { id: m.id, name, meta: m.meta } + }) + models.image = models.image.map(m => { + let name = m.name + name = name.replace(/^dall-e-/i, 'DALL-E ') + return { id: m.id, name, meta: m.meta } + }) + } + + // local function + const getValidModelId = (engineConfig: EngineConfig, type: string, modelId: string) => { + const models: Model[] = engineConfig?.models?.[type as keyof typeof engineConfig.models] + const m = models?.find(m => m.id == modelId) + return m ? 
modelId : (models?.[0]?.id || null) + } + + // save in store + engineConfig.models = models + engineConfig.model = { + chat: getValidModelId(engineConfig, 'chat', engineConfig.model?.chat), + image: getValidModelId(engineConfig, 'image', engineConfig.model?.image) + } + // save if (this.config == store.config) { store.saveSettings() } - + // done - return rc + return true } diff --git a/src/llms/openai.ts b/src/llms/openai.ts index f5e0c97f..6390edc8 100644 --- a/src/llms/openai.ts +++ b/src/llms/openai.ts @@ -1,14 +1,15 @@ -import { EngineConfig, OpenAI } from "multi-llm-ts"; +import { EngineCreateOpts, ModelsList, OpenAI } from "multi-llm-ts"; export default class extends OpenAI { // eslint-disable-next-line @typescript-eslint/no-unused-vars - static isConfigured = (engineConfig: EngineConfig): boolean => { + static isConfigured = (engineConfig: EngineCreateOpts): boolean => { return true } - static isReady = (engineConfig: EngineConfig): boolean => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + static isReady = (engineConfig: EngineCreateOpts, models: ModelsList): boolean => { return OpenAI.isConfigured(engineConfig) } diff --git a/src/plugins/plugin.ts b/src/plugins/plugin.ts index 58f4ba6e..636716e8 100644 --- a/src/plugins/plugin.ts +++ b/src/plugins/plugin.ts @@ -1,5 +1,6 @@ -import { anyDict, Plugin as PluginBase } from 'multi-llm-ts' +import { anyDict } from 'types/index.d' +import { Plugin as PluginBase } from 'multi-llm-ts' export type PluginConfig = anyDict diff --git a/src/screens/DocRepoCreate.vue b/src/screens/DocRepoCreate.vue index 5ac7c027..a83cdb87 100644 --- a/src/screens/DocRepoCreate.vue +++ b/src/screens/DocRepoCreate.vue @@ -127,7 +127,7 @@ const getModels = async () => { // load const llmFactory = new LlmFactory(store.config) - let success = await llmFactory.loadModels('ollama') + const success = await llmFactory.loadModels('ollama') if (!success) { setEphemeralRefreshLabel('Error!') return diff --git a/src/screens/PromptAnywhere.vue b/src/screens/PromptAnywhere.vue index bc91a9d8..309d081d 100644 --- a/src/screens/PromptAnywhere.vue +++ b/src/screens/PromptAnywhere.vue @@ -345,7 +345,7 @@ const onPrompt = async ({ prompt, attachment, docrepo }: { prompt: string, attac // now generate stopGeneration = false - const stream = await llm.generate(chat.value.messages.slice(0, -1), { model: chat.value.model }) + const stream = await llm.generate(chat.value.model, chat.value.messages.slice(0, -1)) for await (const msg of stream) { if (stopGeneration) { llm.stop(stream) @@ -380,7 +380,7 @@ const saveChat = async () => { // we need a title if (!chat.value.title) { - const title = await llm.complete([...chat.value.messages, new Message('user', store.config.instructions.titling_user)]) + const title = await llm.complete(chat.value.model, [...chat.value.messages, new Message('user', store.config.instructions.titling_user)]) chat.value.title = title.content } diff --git a/src/screens/ScratchPad.vue b/src/screens/ScratchPad.vue index d994379a..ed253887 100644 --- a/src/screens/ScratchPad.vue +++ b/src/screens/ScratchPad.vue @@ -476,7 +476,7 @@ const onSendPrompt = async ({ prompt, attachment, docrepo }: { prompt: string, a try { processing.value = true stopGeneration = false - const stream = await llm.generate(chat.value.messages.slice(0, -1), { model: chat.value.model }) + const stream = await llm.generate(chat.value.model, chat.value.messages.slice(0, -1)) for await (const msg of stream) { if (stopGeneration) { llm.stop(stream) diff --git 
a/src/services/assistant-worker.ts b/src/services/assistant-worker.ts index a3e813d2..2f272c79 100644 --- a/src/services/assistant-worker.ts +++ b/src/services/assistant-worker.ts @@ -1,4 +1,6 @@ -import { LlmCompletionOpts, LlmChunk } from 'multi-llm-ts' + +import { LlmChunk } from 'multi-llm-ts' +import { AssistantCompletionOpts } from './assistant' import { Configuration } from 'types/config.d' import { DocRepoQueryResponseItem } from 'types/rag.d' import Chat, { defaultTitle } from '../models/chat' @@ -21,7 +23,7 @@ export default class { llm: LlmWorker chat: Chat stream: any - opts: LlmCompletionOpts + opts: AssistantCompletionOpts callback: ChunkCallback sources: DocRepoQueryResponseItem[] @@ -82,7 +84,7 @@ export default class { return this.llm !== null } - async prompt(prompt: string, opts: LlmCompletionOpts, callback: ChunkCallback): Promise { + async prompt(prompt: string, opts: AssistantCompletionOpts, callback: ChunkCallback): Promise { // check prompt = prompt.trim() @@ -91,7 +93,7 @@ export default class { } // merge with defaults - const defaults: LlmCompletionOpts = { + const defaults: AssistantCompletionOpts = { save: true, titling: true, ... this.llmFactory.getChatEngineModel(), @@ -296,7 +298,7 @@ export default class { // now get it this.initLlm(this.chat.engine) - const response = await this.llm.complete(messages, { model: this.chat.model }) + const response = await this.llm.complete(this.chat.model, messages) let title = response.content.trim() if (title === '') { return this.chat.messages[1].content diff --git a/src/services/assistant.ts b/src/services/assistant.ts index c5bcf0c3..f4fef859 100644 --- a/src/services/assistant.ts +++ b/src/services/assistant.ts @@ -10,6 +10,17 @@ import { store } from './store' import { countryCodeToName } from './i18n' import { availablePlugins } from '../plugins/plugins' +export interface AssistantCompletionOpts extends LlmCompletionOpts { + engine?: string + model?: string + save?: boolean + titling?: boolean + attachment?: Attachment + docrepo?: string + overwriteEngineModel?: boolean + systemInstructions?: string +} + export default class { config: Configuration @@ -18,7 +29,7 @@ export default class { llm: LlmEngine chat: Chat stopGeneration: boolean - stream: AsyncGenerator + stream: AsyncIterable constructor(config: Configuration) { this.config = config @@ -59,7 +70,7 @@ export default class { return this.llm !== null } - async prompt(prompt: string, opts: LlmCompletionOpts, callback: (chunk: LlmChunk) => void): Promise { + async prompt(prompt: string, opts: AssistantCompletionOpts, callback: (chunk: LlmChunk) => void): Promise { // check prompt = prompt.trim() @@ -68,7 +79,7 @@ export default class { } // merge with defaults - const defaults: LlmCompletionOpts = { + const defaults: AssistantCompletionOpts = { save: true, titling: true, ... 
this.llmFactory.getChatEngineModel(), @@ -140,7 +151,7 @@ export default class { } - async generateText(opts: LlmCompletionOpts, callback: (chunk: LlmChunk) => void): Promise { + async generateText(opts: AssistantCompletionOpts, callback: (chunk: LlmChunk) => void): Promise { // we need this to be const during generation const llm = this.llm @@ -166,7 +177,7 @@ export default class { // now stream this.stopGeneration = false - this.stream = await llm.generate(messages, opts) + this.stream = await llm.generate(opts.model, messages, opts) for await (const msg of this.stream) { if (this.stopGeneration) { break @@ -264,7 +275,7 @@ export default class { // now get it this.initLlm(this.chat.engine) - const response = await this.llm.complete(messages, { model: this.chat.model }) + const response = await this.llm.complete(this.chat.model, messages) let title = response.content.trim() if (title === '') { return this.chat.messages[1].content diff --git a/src/settings/SettingsGoogle.vue b/src/settings/SettingsGoogle.vue index 299a401b..ab1a99c1 100644 --- a/src/settings/SettingsGoogle.vue +++ b/src/settings/SettingsGoogle.vue @@ -88,7 +88,7 @@ const save = () => { store.config.engines.google.apiKey = apiKey.value store.config.engines.google.model = { chat: chat_model.value, - //image: image_model.value + image: '' } store.saveSettings() } diff --git a/src/types/config.d.ts b/src/types/config.d.ts index 94e1a28f..5ef7072a 100644 --- a/src/types/config.d.ts +++ b/src/types/config.d.ts @@ -1,4 +1,5 @@ +import { EngineCreateOpts, Model } from 'multi-llm-ts' import { Shortcut, anyDict } from './index.d' export interface Configuration { @@ -18,6 +19,12 @@ export interface Configuration { gdrive: GDriveConfig } +export interface EngineConfig extends EngineCreateOpts{ + models: ModelsConfig + model: ModelConfig + tts: TTSConfig +} + export interface GeneralConfig { firstRun: boolean hideOnStartup: boolean @@ -90,12 +97,6 @@ export interface STTConfig { //silenceAction: SilenceAction } -export interface Model { - id: string - name: string - meta: any -} - export interface ModelsConfig { chat: Model[] image?: Model[] diff --git a/src/types/index.d.ts b/src/types/index.d.ts index 6c6bf4fd..27ae983f 100644 --- a/src/types/index.d.ts +++ b/src/types/index.d.ts @@ -1,9 +1,11 @@ -export { strDict, anyDict } from 'multi-llm-ts' import { LlmChunk, LlmChunkTool, LlmRole, anyDict } from 'multi-llm-ts' import { Configuration } from './config.d' import { ToolCallInfo } from 'models/message' +export type strDict = { [key: string]: string } +export type anyDict = { [key: string]: any } + export interface Attachment { contents: string mimeType: string diff --git a/tests/mocks/llm.ts b/tests/mocks/llm.ts index 70f033c9..503d5dc0 100644 --- a/tests/mocks/llm.ts +++ b/tests/mocks/llm.ts @@ -1,5 +1,5 @@ -import { LlmEngine, LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, EngineConfig } from 'multi-llm-ts' +import { LlmEngine, LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, EngineCreateOpts } from 'multi-llm-ts' import Message from '../../src/models/message' import RandomChunkStream from './stream' @@ -19,7 +19,7 @@ class LlmError extends Error { export default class LlmMock extends LlmEngine { - constructor(config: EngineConfig) { + constructor(config: EngineCreateOpts) { super(config) } @@ -39,8 +39,8 @@ export default class LlmMock extends LlmEngine { ] } - // eslint-disable-next-line @typescript-eslint/no-unused-vars - async complete(thread: Message[], opts: 
LlmCompletionOpts): Promise { + + async complete(model: string, thread: Message[]): Promise { return { type: 'text', content: JSON.stringify([ @@ -51,7 +51,7 @@ export default class LlmMock extends LlmEngine { } // eslint-disable-next-line @typescript-eslint/no-unused-vars - async stream(thread: Message[], opts: LlmCompletionOpts): Promise { + async stream(model: string, thread: Message[], opts: LlmCompletionOpts): Promise { // errors if (thread[thread.length-1].content.includes('no api key')) { diff --git a/tests/unit/embedder.test.ts b/tests/unit/embedder.test.ts index 4fd5fa58..eb83df95 100644 --- a/tests/unit/embedder.test.ts +++ b/tests/unit/embedder.test.ts @@ -3,8 +3,8 @@ import { test, expect, vi, beforeEach } from 'vitest' import { app } from 'electron' import Embedder from '../../src/rag/embedder' import defaultSettings from '../../defaults/settings.json' -import * as _ollama from 'ollama/dist/browser.mjs' -import * as _OpenAI from 'openai' +import { Ollama } from 'ollama/dist/browser.mjs' +import OpenAI from 'openai' vi.mock('openai', async () => { const OpenAI = vi.fn() @@ -55,8 +55,8 @@ test('Create OpenAI', async () => { test('Embed OpenAI', async () => { const embedder = await Embedder.init(app, defaultSettings, 'openai', 'text-embedding-ada-002') const embeddings = await embedder.embed(['hello']) - expect(_OpenAI.default.prototype.embeddings.create).toHaveBeenCalled() - expect(_ollama.Ollama.prototype.embed).not.toHaveBeenCalled() + expect(OpenAI.prototype.embeddings.create).toHaveBeenCalled() + expect(Ollama.prototype.embed).not.toHaveBeenCalled() expect(embeddings).toStrictEqual([[111, 108, 108, 101, 104]]) }) @@ -70,8 +70,8 @@ test('Create Ollama', async () => { test('Embed Ollama', async () => { const embedder = await Embedder.init(app, defaultSettings, 'ollama', 'all-minilm') const embeddings = await embedder.embed(['hello']) - expect(_ollama.Ollama.prototype.embed).toHaveBeenCalled() - expect(_OpenAI.default.prototype.embeddings.create).not.toHaveBeenCalled() + expect(Ollama.prototype.embed).toHaveBeenCalled() + expect(OpenAI.prototype.embeddings.create).not.toHaveBeenCalled() expect(embeddings).toStrictEqual([[111, 108, 108, 101, 104]]) }) diff --git a/tests/unit/llm1.test.ts b/tests/unit/llm1.test.ts index af90fe02..e788b795 100644 --- a/tests/unit/llm1.test.ts +++ b/tests/unit/llm1.test.ts @@ -27,7 +27,7 @@ store.config.engines.openai.apiKey = '123' const llmFactory = new LlmFactory(store.config) -const model = [{ id: 'llava:latest', name: 'llava:latest', meta: {} }] +const model = { id: 'model-id', name: 'model-name', meta: {} } test('Default Configuration', () => { expect(llmFactory.isEngineReady('openai')).toBe(true) @@ -138,7 +138,6 @@ test('Ignite Engine', async () => { expect(await llmFactory.igniteEngine('groq')).toBeInstanceOf(Groq) expect(await llmFactory.igniteEngine('cerebras')).toBeInstanceOf(Cerebras) expect(await llmFactory.igniteEngine('aws')).toBeInstanceOf(OpenAI) - expect(await llmFactory.igniteEngine('aws', 'aws')).toBeNull() }) test('Anthropic Computer Use', async () => { diff --git a/tests/unit/llm2.test.ts b/tests/unit/llm2.test.ts index 80cbe5b8..eeb25821 100644 --- a/tests/unit/llm2.test.ts +++ b/tests/unit/llm2.test.ts @@ -1,6 +1,6 @@ import { vi, expect, test } from 'vitest' -import * as _MultiLLM from 'multi-llm-ts' +import { ModelsList, loadAnthropicModels, loadCerebrasModels, loadGoogleModels, loadGroqModels, loadMistralAIModels, loadOllamaModels, loadOpenAIModels, loadXAIModels }from 'multi-llm-ts' import { store } from 
'../../src/services/store' import LlmFactory from '../../src/llms/llm' @@ -8,14 +8,14 @@ vi.mock('multi-llm-ts', async (importOriginal) => { const mod: any = await importOriginal() return { ...mod, - loadAnthropicModels: vi.fn(() => []), - loadCerebrasModels: vi.fn(() => []), - loadGoogleModels: vi.fn(() => []), - loadGroqModels: vi.fn(() => []), - loadMistralAIModels: vi.fn(() => []), - loadOllamaModels: vi.fn(() => []), - loadOpenAIModels: vi.fn(() => []), - loadXAIModels: vi.fn(() => []), + loadAnthropicModels: vi.fn((): ModelsList => ({ chat: [], image: [] })), + loadCerebrasModels: vi.fn((): ModelsList => ({ chat: [], image: [] })), + loadGoogleModels: vi.fn((): ModelsList => ({ chat: [], image: [] })), + loadGroqModels: vi.fn((): ModelsList => ({ chat: [], image: [] })), + loadMistralAIModels: vi.fn((): ModelsList => ({ chat: [], image: [] })), + loadOllamaModels: vi.fn((): ModelsList => ({ chat: [], image: [] })), + loadOpenAIModels: vi.fn((): ModelsList => ({ chat: [], image: [] })), + loadXAIModels: vi.fn((): ModelsList => ({ chat: [], image: [] })), } }) @@ -48,34 +48,34 @@ const llmFactory = new LlmFactory(store.config) test('Load models', async () => { await llmFactory.loadModels('anthropic') - expect(_MultiLLM.loadAnthropicModels).toHaveBeenCalledTimes(1) - expect(window.api.config.save).toHaveBeenCalledTimes(1) + expect(loadAnthropicModels).toHaveBeenCalledTimes(1) + expect(window.api.config?.save).toHaveBeenCalledTimes(1) await llmFactory.loadModels('cerebras') - expect(_MultiLLM.loadCerebrasModels).toHaveBeenCalledTimes(1) - expect(window.api.config.save).toHaveBeenCalledTimes(2) + expect(loadCerebrasModels).toHaveBeenCalledTimes(1) + expect(window.api.config?.save).toHaveBeenCalledTimes(2) await llmFactory.loadModels('google') - expect(_MultiLLM.loadGoogleModels).toHaveBeenCalledTimes(1) - expect(window.api.config.save).toHaveBeenCalledTimes(3) + expect(loadGoogleModels).toHaveBeenCalledTimes(1) + expect(window.api.config?.save).toHaveBeenCalledTimes(3) await llmFactory.loadModels('groq') - expect(_MultiLLM.loadGroqModels).toHaveBeenCalledTimes(1) - expect(window.api.config.save).toHaveBeenCalledTimes(4) + expect(loadGroqModels).toHaveBeenCalledTimes(1) + expect(window.api.config?.save).toHaveBeenCalledTimes(4) await llmFactory.loadModels('mistralai') - expect(_MultiLLM.loadMistralAIModels).toHaveBeenCalledTimes(1) - expect(window.api.config.save).toHaveBeenCalledTimes(5) + expect(loadMistralAIModels).toHaveBeenCalledTimes(1) + expect(window.api.config?.save).toHaveBeenCalledTimes(5) await llmFactory.loadModels('ollama') - expect(_MultiLLM.loadOllamaModels).toHaveBeenCalledTimes(1) - expect(window.api.config.save).toHaveBeenCalledTimes(6) + expect(loadOllamaModels).toHaveBeenCalledTimes(1) + expect(window.api.config?.save).toHaveBeenCalledTimes(6) await llmFactory.loadModels('openai') - expect(_MultiLLM.loadOpenAIModels).toHaveBeenCalledTimes(1) - expect(window.api.config.save).toHaveBeenCalledTimes(7) + expect(loadOpenAIModels).toHaveBeenCalledTimes(1) + expect(window.api.config?.save).toHaveBeenCalledTimes(7) await llmFactory.loadModels('xai') - expect(_MultiLLM.loadXAIModels).toHaveBeenCalledTimes(1) - expect(window.api.config.save).toHaveBeenCalledTimes(8) + expect(loadXAIModels).toHaveBeenCalledTimes(1) + expect(window.api.config?.save).toHaveBeenCalledTimes(8) }) diff --git a/tests/unit/store.test.ts b/tests/unit/store.test.ts index c306ba05..200d2c17 100644 --- a/tests/unit/store.test.ts +++ b/tests/unit/store.test.ts @@ -66,10 +66,10 @@ test('Check atributtes', 
async () => { test('Load', async () => { store.load() - expect(window.api.config.load).toHaveBeenCalled() - expect(window.api.experts.load).toHaveBeenCalled() - expect(window.api.commands.load).toHaveBeenCalled() - expect(window.api.history.load).toHaveBeenCalled() + expect(window.api.config?.load).toHaveBeenCalled() + expect(window.api.experts?.load).toHaveBeenCalled() + expect(window.api.commands?.load).toHaveBeenCalled() + expect(window.api.history?.load).toHaveBeenCalled() expect(store.config).toStrictEqual(defaultSettings) expect(store.commands).toStrictEqual(defaultCommands) expect(store.experts).toStrictEqual(defaultExperts) @@ -78,19 +78,19 @@ test('Load', async () => { test('Save settings', async () => { store.load() store.saveSettings() - expect(window.api.config.save).toHaveBeenCalled() + expect(window.api.config?.save).toHaveBeenCalled() }) test('Reload settings without changing reference', async () => { store.load() - expect(window.api.config.load).toHaveBeenCalledTimes(1) + expect(window.api.config?.load).toHaveBeenCalledTimes(1) const backup = store.config expect(store.config.llm.engine).toBe('openai') expect(store.config.plugins).toBeDefined() defaultSettings.llm.engine = 'xai' delete defaultSettings.plugins listeners[0]('settings') - expect(window.api.config.load).toHaveBeenCalledTimes(2) + expect(window.api.config?.load).toHaveBeenCalledTimes(2) expect(store.config).toBe(backup) expect(store.config.llm.engine).toBe('xai') expect(store.config.plugins).toBeUndefined() @@ -105,7 +105,7 @@ test('Load history', async () => { test('Save history', async () => { store.saveHistory() - expect(window.api.history.save).toHaveBeenCalledWith([ { + expect(window.api.history?.save).toHaveBeenCalledWith([ { uuid: '123', engine: 'engine', model: 'model',
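
For reference, the call-site changes above follow the multi-llm-ts 2.x convention of passing the model id as the first argument to complete() and generate(), with any remaining options moving to the third argument of generate(). A minimal sketch of the new shape, assuming it sits under src/ so the relative Message import resolves; the Ollama constructor options and the model id are placeholders, not values taken from this patch:

    import { Ollama, LlmChunk } from 'multi-llm-ts'
    import Message from '../models/message'

    async function demo() {
      // placeholder engine config; the app passes store.config.engines.ollama here
      const llm = new Ollama({ baseURL: 'http://127.0.0.1:11434' })
      const messages = [
        new Message('system', 'You are a helpful assistant.'),
        new Message('user', 'Say hello.'),
      ]

      // one-shot completion: model id first, thread second (previously { model } in opts)
      const response = await llm.complete('llama3.2:latest', messages)
      console.log(response.content)

      // streaming: same ordering; iterate the returned stream just like the worker does
      const stream: AsyncIterable<LlmChunk> = await llm.generate('llama3.2:latest', messages)
      for await (const chunk of stream) {
        console.log(chunk)
      }
    }

    demo().catch(console.error)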
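
In the same vein, the loadXxxModels() helpers now resolve to a ModelsList ({ chat: [], image: [] } in the mocked tests above) instead of a boolean, and LlmFactory.loadModels() persists that list and re-validates the currently selected model. A hedged sketch of that flow for a single engine; the function name refreshOllamaModels is hypothetical, while the store and engine key come from the patch:

    import { loadOllamaModels, ModelsList } from 'multi-llm-ts'
    import { store } from '../services/store'

    async function refreshOllamaModels(): Promise<boolean> {
      const models: ModelsList | null = await loadOllamaModels(store.config.engines.ollama)
      if (!models || typeof models !== 'object') {
        // mirror the patch: reset to an empty list and report failure
        store.config.engines.ollama.models = { chat: [], image: [] }
        return false
      }
      // keep the previously selected chat model only if it is still available
      const previous = store.config.engines.ollama.model?.chat
      store.config.engines.ollama.models = models
      store.config.engines.ollama.model = {
        chat: models.chat.find(m => m.id === previous)?.id ?? models.chat[0]?.id ?? null,
        image: models.image?.[0]?.id ?? null,
      }
      store.saveSettings()
      return true
    }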