Commit bb597f7

multi llm upgrade

nbonamy committed Nov 3, 2024
1 parent 5b64601 commit bb597f7

Showing 21 changed files with 167 additions and 110 deletions.
2 changes: 1 addition & 1 deletion defaults/settings.json
@@ -18,7 +18,7 @@
      "retryGeneration": true
    }
  },
- "llm": {
+ "llm": {
    "engine": "openai",
    "autoVisionSwitch": true,
    "conversationLength": 20

8 changes: 4 additions & 4 deletions package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
@@ -98,7 +98,7 @@
    "markdown-it-mark": "^4.0.0",
    "minimatch": "^10.0.1",
    "mitt": "^3.0.1",
-   "multi-llm-ts": "^1.3.4",
+   "multi-llm-ts": "^2.0.0-beta.1",
    "nestor-client": "^0.3.1",
    "officeparser": "^4.1.1",
    "ollama": "^0.5.9",

2 changes: 1 addition & 1 deletion src/automations/commander.ts
@@ -204,7 +204,7 @@ export default class Commander {
    ]

    // now get it
-   return llm.complete(messages, { model: model })
+   return llm.complete(model, messages)

  }

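In multi-llm-ts v2 the model id moves out of the options object and becomes the first positional argument of complete(). A minimal migration sketch, assuming an engine instance created elsewhere with igniteEngine(); the model id and messages are illustrative:

    import { Message } from 'multi-llm-ts'

    const messages = [
      new Message('system', 'You are a helpful assistant'),
      new Message('user', 'Summarize the selected text'),
    ]

    // v1.x: const response = await llm.complete(messages, { model: 'gpt-4o' })
    // v2.x: the model id comes first
    const response = await llm.complete('gpt-4o', messages)
    console.log(response.content)
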
2 changes: 1 addition & 1 deletion src/llms/llm-worker.ts
@@ -18,7 +18,7 @@ const initEngine = (engine: string, config: Configuration) => {
const stream = async (messages: any[], opts: any) => {

  try {
-   const stream = await llm.generate(messages, opts)
+   const stream = await llm.generate(opts.model, messages, opts)
    for await (const msg of stream) {
      worker.postMessage(msg)
    }

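generate() follows the same convention, which is why the worker now passes opts.model explicitly as the first argument. A short streaming sketch under the v2 signature, reusing the worker's llm and messages; the cancelled flag is hypothetical and mirrors the stopGeneration flags used in the screens below:

    const stream = await llm.generate(opts.model, messages, opts)
    for await (const chunk of stream) {
      if (cancelled) {            // hypothetical cancellation flag
        llm.stop(stream)
        break
      }
      worker.postMessage(chunk)   // forward each LlmChunk as it arrives
    }
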
98 changes: 69 additions & 29 deletions src/llms/llm.ts
@@ -1,6 +1,6 @@

-import { Configuration } from 'types/config'
-import { Anthropic, Ollama, MistralAI, Google, Groq, XAI, Cerebras, LlmEngine, loadAnthropicModels, loadCerebrasModels, loadGoogleModels, loadGroqModels, loadMistralAIModels, loadOllamaModels, loadOpenAIModels, loadXAIModels, hasVisionModels as _hasVisionModels, isVisionModel as _isVisionModel } from 'multi-llm-ts'
+import { Configuration, EngineConfig } from 'types/config.d'
+import { Anthropic, Ollama, MistralAI, Google, Groq, XAI, Cerebras, LlmEngine, loadAnthropicModels, loadCerebrasModels, loadGoogleModels, loadGroqModels, loadMistralAIModels, loadOllamaModels, loadOpenAIModels, loadXAIModels, hasVisionModels as _hasVisionModels, isVisionModel as _isVisionModel, ModelsList, Model } from 'multi-llm-ts'
import { isSpecializedModel as isSpecialAnthropicModel, getFallbackModel as getAnthropicFallbackModel, getComputerInfo } from './anthropic'
import { imageFormats, textFormats } from '../models/attachment'
import { store } from '../services/store'
@@ -60,18 +60,20 @@ export default class LlmFactory {
}

isEngineReady = (engine: string): boolean => {
-   if (engine === 'anthropic') return Anthropic.isReady(this.config.engines.anthropic)
-   if (engine === 'cerebras') return Cerebras.isReady(this.config.engines.cerebras)
-   if (engine === 'google') return Google.isReady(this.config.engines.google)
-   if (engine === 'groq') return Groq.isReady(this.config.engines.groq)
-   if (engine === 'mistralai') return MistralAI.isReady(this.config.engines.mistralai)
-   if (engine === 'ollama') return Ollama.isReady(this.config.engines.ollama)
-   if (engine === 'openai') return OpenAI.isReady(this.config.engines.openai)
-   if (engine === 'xai') return XAI.isReady(this.config.engines.xai)
+   if (engine === 'anthropic') return Anthropic.isReady(this.config.engines.anthropic, this.config.engines.anthropic?.models)
+   if (engine === 'cerebras') return Cerebras.isReady(this.config.engines.cerebras, this.config.engines.cerebras?.models)
+   if (engine === 'google') return Google.isReady(this.config.engines.google, this.config.engines.google?.models)
+   if (engine === 'groq') return Groq.isReady(this.config.engines.groq, this.config.engines.groq?.models)
+   if (engine === 'mistralai') return MistralAI.isReady(this.config.engines.mistralai, this.config.engines.mistralai?.models)
+   if (engine === 'ollama') return Ollama.isReady(this.config.engines.ollama, this.config.engines.ollama?.models)
+   if (engine === 'openai') return OpenAI.isReady(this.config.engines.openai, this.config.engines.openai?.models)
+   if (engine === 'xai') return XAI.isReady(this.config.engines.xai, this.config.engines.xai?.models)
return false
}

- igniteEngine = (engine: string, fallback = 'openai'): LlmEngine => {
+ igniteEngine = (engine: string): LlmEngine => {
+
+   // select
    if (engine === 'anthropic') return new Anthropic(this.config.engines.anthropic, getComputerInfo())
    if (engine === 'cerebras') return new Cerebras(this.config.engines.cerebras)
    if (engine === 'google') return new Google(this.config.engines.google)
@@ -80,15 +82,15 @@
    if (engine === 'ollama') return new Ollama(this.config.engines.ollama)
    if (engine === 'openai') return new OpenAI(this.config.engines.openai)
    if (engine === 'xai') return new XAI(this.config.engines.xai)
-   if (this.isEngineReady(fallback)) {
-     console.warn(`Engine ${engine} unknown. Falling back to ${fallback}`)
-     return this.igniteEngine(fallback, this.config.engines[fallback])
-   }
-   return null
+
+   // fallback
+   console.warn(`Engine ${engine} unknown. Falling back to OpenAI`)
+   return new OpenAI(this.config.engines.openai)

  }

  hasChatModels = (engine: string) => {
-   return this.config.engines[engine].models.chat.length > 0
+   return this.config.engines[engine].models?.chat?.length > 0
  }

hasVisionModels = (engine: string) => {
@@ -123,32 +125,70 @@
  loadModels = async (engine: string): Promise<boolean> => {

    console.log('Loading models for', engine)
-   let rc = false
+   let models: ModelsList|null = null
    if (engine === 'openai') {
-     rc = await loadOpenAIModels(this.config.engines.openai)
+     models = await loadOpenAIModels(this.config.engines.openai)
    } else if (engine === 'ollama') {
-     rc = await loadOllamaModels(this.config.engines.ollama)
+     models = await loadOllamaModels(this.config.engines.ollama)
    } else if (engine === 'mistralai') {
-     rc = await loadMistralAIModels(this.config.engines.mistralai)
+     models = await loadMistralAIModels(this.config.engines.mistralai)
    } else if (engine === 'anthropic') {
-     rc = await loadAnthropicModels(this.config.engines.anthropic, getComputerInfo())
+     models = await loadAnthropicModels(this.config.engines.anthropic, getComputerInfo())
    } else if (engine === 'google') {
-     rc = await loadGoogleModels(this.config.engines.google)
+     models = await loadGoogleModels(this.config.engines.google)
    } else if (engine === 'groq') {
-     rc = await loadGroqModels(this.config.engines.groq)
+     models = await loadGroqModels(this.config.engines.groq)
    } else if (engine === 'cerebras') {
-     rc = await loadCerebrasModels(this.config.engines.cerebras)
+     models = await loadCerebrasModels(this.config.engines.cerebras)
    } else if (engine === 'xai') {
-     rc = await loadXAIModels(this.config.engines.xai)
+     models = await loadXAIModels(this.config.engines.xai)
    }

+   // needed
+   const engineConfig = store.config.engines[engine]
+
+   // check
+   if (typeof models !== 'object') {
+     engineConfig.models = { chat: [], image: [] }
+     return false
+   }
+
+   // openai names are not great
+   if (engine === 'openai') {
+     models.chat = models.chat.map(m => {
+       let name = m.name
+       name = name.replace(/^gpt-([^-]*)(-?)([a-z]?)/i, (_, l1, __, l3) => `GPT-${l1} ${l3?.toUpperCase()}`)
+       name = name.replace('Mini', 'mini')
+       return { id: m.id, name, meta: m.meta }
+     })
+     models.image = models.image.map(m => {
+       let name = m.name
+       name = name.replace(/^dall-e-/i, 'DALL-E ')
+       return { id: m.id, name, meta: m.meta }
+     })
+   }
+
+   // local function
+   const getValidModelId = (engineConfig: EngineConfig, type: string, modelId: string) => {
+     const models: Model[] = engineConfig?.models?.[type as keyof typeof engineConfig.models]
+     const m = models?.find(m => m.id == modelId)
+     return m ? modelId : (models?.[0]?.id || null)
+   }
+
+   // save in store
+   engineConfig.models = models
+   engineConfig.model = {
+     chat: getValidModelId(engineConfig, 'chat', engineConfig.model?.chat),
+     image: getValidModelId(engineConfig, 'image', engineConfig.model?.image)
+   }
+
+   // save
+   if (this.config == store.config) {
+     store.saveSettings()
+   }
+
    // done
-   return rc
+   return true

  }

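Because the v2 loaders resolve to a ModelsList (or a non-object on failure) instead of a boolean, callers now own normalization and persistence. A condensed sketch of that flow for a single engine; `config` is assumed to have the same shape as store.config above:

    import { loadOpenAIModels, ModelsList } from 'multi-llm-ts'

    const models: ModelsList|null = await loadOpenAIModels(config.engines.openai)
    if (typeof models !== 'object') {
      // load failed: reset to empty lists so the UI stays consistent
      config.engines.openai.models = { chat: [], image: [] }
    } else {
      // persist the lists, keeping the current selection only if it still exists
      const current = config.engines.openai.model?.chat
      config.engines.openai.models = models
      config.engines.openai.model.chat =
        models.chat.find(m => m.id === current)?.id ?? models.chat[0]?.id ?? null
    }
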
7 changes: 4 additions & 3 deletions src/llms/openai.ts
@@ -1,14 +1,15 @@

-import { EngineConfig, OpenAI } from "multi-llm-ts";
+import { EngineCreateOpts, ModelsList, OpenAI } from "multi-llm-ts";

export default class extends OpenAI {

  // eslint-disable-next-line @typescript-eslint/no-unused-vars
- static isConfigured = (engineConfig: EngineConfig): boolean => {
+ static isConfigured = (engineConfig: EngineCreateOpts): boolean => {
    return true
  }

- static isReady = (engineConfig: EngineConfig): boolean => {
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ static isReady = (engineConfig: EngineCreateOpts, models: ModelsList): boolean => {
    return OpenAI.isConfigured(engineConfig)
  }

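isReady() now takes the cached ModelsList as a second parameter, so an engine can refuse to report ready until its models are loaded; this wrapper deliberately ignores it. A hypothetical stricter variant, only to illustrate what the extra parameter enables (not the library's actual check):

    import { EngineCreateOpts, ModelsList, OpenAI } from 'multi-llm-ts'

    class StrictOpenAI extends OpenAI {
      // ready only when configured AND at least one chat model has been loaded
      static isReady = (config: EngineCreateOpts, models: ModelsList): boolean => {
        return OpenAI.isConfigured(config) && (models?.chat?.length ?? 0) > 0
      }
    }
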
3 changes: 2 additions & 1 deletion src/plugins/plugin.ts
@@ -1,5 +1,6 @@

-import { anyDict, Plugin as PluginBase } from 'multi-llm-ts'
+import { anyDict } from 'types/index.d'
+import { Plugin as PluginBase } from 'multi-llm-ts'

export type PluginConfig = anyDict

2 changes: 1 addition & 1 deletion src/screens/DocRepoCreate.vue
@@ -127,7 +127,7 @@ const getModels = async () => {

  // load
  const llmFactory = new LlmFactory(store.config)
- let success = await llmFactory.loadModels('ollama')
+ const success = await llmFactory.loadModels('ollama')
  if (!success) {
    setEphemeralRefreshLabel('Error!')
    return

4 changes: 2 additions & 2 deletions src/screens/PromptAnywhere.vue
@@ -345,7 +345,7 @@ const onPrompt = async ({ prompt, attachment, docrepo }: { prompt: string, attac

  // now generate
  stopGeneration = false
- const stream = await llm.generate(chat.value.messages.slice(0, -1), { model: chat.value.model })
+ const stream = await llm.generate(chat.value.model, chat.value.messages.slice(0, -1))
  for await (const msg of stream) {
    if (stopGeneration) {
      llm.stop(stream)
@@ -380,7 +380,7 @@ const saveChat = async () => {

  // we need a title
  if (!chat.value.title) {
-   const title = await llm.complete([...chat.value.messages, new Message('user', store.config.instructions.titling_user)])
+   const title = await llm.complete(chat.value.model, [...chat.value.messages, new Message('user', store.config.instructions.titling_user)])
    chat.value.title = title.content
  }

2 changes: 1 addition & 1 deletion src/screens/ScratchPad.vue
@@ -476,7 +476,7 @@ const onSendPrompt = async ({ prompt, attachment, docrepo }: { prompt: string, a

  try {
    processing.value = true
    stopGeneration = false
-   const stream = await llm.generate(chat.value.messages.slice(0, -1), { model: chat.value.model })
+   const stream = await llm.generate(chat.value.model, chat.value.messages.slice(0, -1))
    for await (const msg of stream) {
      if (stopGeneration) {
        llm.stop(stream)

12 changes: 7 additions & 5 deletions src/services/assistant-worker.ts
@@ -1,4 +1,6 @@
-import { LlmCompletionOpts, LlmChunk } from 'multi-llm-ts'
+
+import { LlmChunk } from 'multi-llm-ts'
+import { AssistantCompletionOpts } from './assistant'
import { Configuration } from 'types/config.d'
import { DocRepoQueryResponseItem } from 'types/rag.d'
import Chat, { defaultTitle } from '../models/chat'
@@ -21,7 +23,7 @@ export default class {
  llm: LlmWorker
  chat: Chat
  stream: any
- opts: LlmCompletionOpts
+ opts: AssistantCompletionOpts
  callback: ChunkCallback
  sources: DocRepoQueryResponseItem[]

@@ -82,7 +84,7 @@
    return this.llm !== null
  }

- async prompt(prompt: string, opts: LlmCompletionOpts, callback: ChunkCallback): Promise<void> {
+ async prompt(prompt: string, opts: AssistantCompletionOpts, callback: ChunkCallback): Promise<void> {

    // check
    prompt = prompt.trim()
@@ -91,7 +93,7 @@
    }

    // merge with defaults
-   const defaults: LlmCompletionOpts = {
+   const defaults: AssistantCompletionOpts = {
      save: true,
      titling: true,
      ... this.llmFactory.getChatEngineModel(),
@@ -296,7 +298,7 @@

    // now get it
    this.initLlm(this.chat.engine)
-   const response = await this.llm.complete(messages, { model: this.chat.model })
+   const response = await this.llm.complete(this.chat.model, messages)
    let title = response.content.trim()
    if (title === '') {
      return this.chat.messages[1].content

23 changes: 17 additions & 6 deletions src/services/assistant.ts
@@ -10,6 +10,17 @@ import { store } from './store'
import { countryCodeToName } from './i18n'
import { availablePlugins } from '../plugins/plugins'

+export interface AssistantCompletionOpts extends LlmCompletionOpts {
+  engine?: string
+  model?: string
+  save?: boolean
+  titling?: boolean
+  attachment?: Attachment
+  docrepo?: string
+  overwriteEngineModel?: boolean
+  systemInstructions?: string
+}
+
export default class {

  config: Configuration
@@ -18,7 +29,7 @@ export default class {
  llm: LlmEngine
  chat: Chat
  stopGeneration: boolean
- stream: AsyncGenerator<LlmChunk, void, void>
+ stream: AsyncIterable<LlmChunk>

  constructor(config: Configuration) {
    this.config = config
@@ -59,7 +70,7 @@
    return this.llm !== null
  }

- async prompt(prompt: string, opts: LlmCompletionOpts, callback: (chunk: LlmChunk) => void): Promise<void> {
+ async prompt(prompt: string, opts: AssistantCompletionOpts, callback: (chunk: LlmChunk) => void): Promise<void> {

    // check
    prompt = prompt.trim()
@@ -68,7 +79,7 @@
    }

    // merge with defaults
-   const defaults: LlmCompletionOpts = {
+   const defaults: AssistantCompletionOpts = {
      save: true,
      titling: true,
      ... this.llmFactory.getChatEngineModel(),
@@ -140,7 +151,7 @@

  }

- async generateText(opts: LlmCompletionOpts, callback: (chunk: LlmChunk) => void): Promise<void> {
+ async generateText(opts: AssistantCompletionOpts, callback: (chunk: LlmChunk) => void): Promise<void> {

    // we need this to be const during generation
    const llm = this.llm
@@ -166,7 +177,7 @@

    // now stream
    this.stopGeneration = false
-   this.stream = await llm.generate(messages, opts)
+   this.stream = await llm.generate(opts.model, messages, opts)
    for await (const msg of this.stream) {
      if (this.stopGeneration) {
        break
@@ -264,7 +275,7 @@

    // now get it
    this.initLlm(this.chat.engine)
-   const response = await this.llm.complete(messages, { model: this.chat.model })
+   const response = await this.llm.complete(this.chat.model, messages)
    let title = response.content.trim()
    if (title === '') {
      return this.chat.messages[1].content

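AssistantCompletionOpts layers the app-level flags (save, titling, docrepo, ...) on top of the library's LlmCompletionOpts, so prompt() keeps a single options bag. A usage sketch; the engine/model ids are illustrative and Assistant simply names this file's anonymous default export:

    import { LlmChunk } from 'multi-llm-ts'
    import Assistant, { AssistantCompletionOpts } from '../services/assistant'
    import { store } from '../services/store'

    const assistant = new Assistant(store.config)
    const opts: AssistantCompletionOpts = {
      engine: 'openai',
      model: 'gpt-4o',   // illustrative
      save: true,        // persist the chat
      titling: true,     // auto-generate a title after the first exchange
    }
    await assistant.prompt('Summarize this document', opts, (chunk: LlmChunk) => {
      console.log(chunk)
    })
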
2 changes: 1 addition & 1 deletion src/settings/SettingsGoogle.vue
@@ -88,7 +88,7 @@ const save = () => {
  store.config.engines.google.apiKey = apiKey.value
  store.config.engines.google.model = {
    chat: chat_model.value,
-   //image: image_model.value
+   image: ''
  }
  store.saveSettings()
}