feat: support model for cohere
rxliuli committed Aug 30, 2024
1 parent e724090 commit 1364082
Showing 8 changed files with 1,624 additions and 3 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -14,6 +14,11 @@ Supported models
- [x] Groq
- [x] Cerebras
- [x] Azure OpenAI
- [x] Cohere
- [x] Aliyun Bailian
- [ ] Ollama
- [ ] Cloudflare Workers AI
- [ ] Coze

## Deployment

@@ -46,6 +51,11 @@ Environment variables
- `AZURE_OPENAI_ENDPOINT`: Azure OpenAI Endpoint
- `AZURE_API_VERSION`: Azure OpenAI API Version
- `AZURE_DEPLOYMENT_MODELS`: Azure OpenAI Deployment Models, such as `gpt-4o-mini:gpt-4o-mini-dev,gpt-35-turbo:gpt-35-dev`, which maps the two models `gpt-4o-mini` and `gpt-35-turbo` to the deployments `gpt-4o-mini-dev` and `gpt-35-dev`, respectively.
- Cohere: Supports Cohere models, e.g. `command-r`
- `COHERE_API_KEY`: Cohere API Key
- Aliyun Bailian: Supports Aliyun Bailian models, e.g. `qwen-max`
- `ALIYUN_BAILIAN_API_KEY`: Aliyun Bailian API Key
- `ALIYUN_BAILIAN_MODELS`: Custom list of supported Aliyun Bailian models, e.g. `qwen-max,qwen-7b-chat`; defaults to `qwen-max`
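
With `COHERE_API_KEY` set, Cohere models should be usable through the proxy's OpenAI-compatible API. Below is a minimal sketch using the `openai` SDK; the base URL and the proxy-side `API_KEY` are placeholders for your own deployment, not values defined by this commit.

import OpenAI from 'openai'

// Hypothetical deployment URL and proxy key; substitute your own values.
const client = new OpenAI({
  baseURL: 'https://your-proxy.example.com/v1',
  apiKey: process.env.API_KEY!, // the proxy's API_KEY, not COHERE_API_KEY
})

const res = await client.chat.completions.create({
  model: 'command-r',
  messages: [{ role: 'user', content: 'Hello' }],
})
console.log(res.choices[0].message.content)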

## Usage

9 changes: 7 additions & 2 deletions README.zh-CN.md
@@ -16,10 +16,10 @@
- [x] 零一万物
- [x] Cerebras
- [x] Azure OpenAI
- [ ] Aliyun Bailian
- [x] Cohere
- [x] Aliyun Bailian
- [ ] Ollama
- [ ] Cloudflare Workers AI
- [ ] Cohere
- [ ] Coze
- [ ] 豆包

@@ -58,6 +58,11 @@
- `AZURE_OPENAI_ENDPOINT`: Azure OpenAI Endpoint
- `AZURE_API_VERSION`: Azure OpenAI API Version
- `AZURE_DEPLOYMENT_MODELS`: Azure OpenAI Deployment Models, e.g. `gpt-4o-mini:gpt-4o-mini-dev,gpt-35-turbo:gpt-35-dev`, which maps the two models `gpt-4o-mini` and `gpt-35-turbo` to the deployments `gpt-4o-mini-dev` and `gpt-35-dev`, respectively.
- Cohere: Supports Cohere models, e.g. `command-r`
- `COHERE_API_KEY`: Cohere API Key
- Aliyun Bailian: Supports Aliyun Bailian models, e.g. `qwen-max`
- `ALIYUN_BAILIAN_API_KEY`: Aliyun Bailian API Key
- `ALIYUN_BAILIAN_MODELS`: Custom list of supported Aliyun Bailian models, e.g. `qwen-max,qwen-7b-chat`; defaults to `qwen-max`

## Usage

1 change: 1 addition & 0 deletions package.json
@@ -11,6 +11,7 @@
"@anthropic-ai/vertex-sdk": "^0.4.1",
"@azure/identity": "^4.4.1",
"@google/generative-ai": "^0.17.1",
"cohere-ai": "^7.13.0",
"hono": "^4.5.7",
"jose": "^5.7.0",
"langchain-anthropic-edge": "workspace:*",
1,373 changes: 1,373 additions & 0 deletions pnpm-lock.yaml

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion src/index.ts
@@ -3,7 +3,6 @@ import { streamSSE } from 'hono/streaming'
import { openai } from './llm/openai'
import { anthropic, anthropicVertex } from './llm/anthropic'
import OpenAI from 'openai'
import { uniq, uniqBy } from 'lodash-es'
import { google } from './llm/google'
import { deepseek } from './llm/deepseek'
import { serializeError } from 'serialize-error'
@@ -13,6 +12,8 @@ import { moonshot } from './llm/moonshot'
import { lingyiwanwu } from './llm/lingyiwanwu'
import { groq } from './llm/groq'
import { auzreOpenAI } from './llm/azure'
import { bailian } from './llm/bailian'
import { cohere } from './llm/cohere'

interface Bindings {
API_KEY: string
@@ -30,6 +31,8 @@ function getModels(env: Record<string, string>) {
lingyiwanwu(env),
groq(env),
auzreOpenAI(env),
cohere(env),
bailian(env),
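// Only expose providers whose required env vars are all present in `env`.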
].filter((it) => it.requiredEnv.every((it) => it in env))
}

41 changes: 41 additions & 0 deletions src/llm/__tests__/cohere.test.ts
@@ -0,0 +1,41 @@
import { beforeEach, it, expect } from 'vitest'
import { IChat } from '../base'
import { cohere } from '../cohere'
import OpenAI from 'openai'
import { last } from 'lodash-es'

let client: IChat
beforeEach(() => {
client = cohere({
COHERE_API_KEY: import.meta.env.VITE_COHERE_API_KEY,
})
})
it('invoke', async () => {
const res = await client.invoke({
model: 'command-r-plus',
messages: [{ role: 'user', content: '你是?' }],
temperature: 0,
})
console.log('converted: ', res)
expect(res.choices[0].message.content).not.undefined
})

it('stream', async () => {
const stream = client.stream(
{
model: 'command-r-plus',
messages: [{ role: 'user', content: '你是?' }],
temperature: 0,
stream: true,
},
new AbortController().signal,
)
const chunks: OpenAI.ChatCompletionChunk[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
expect(chunks).not.empty
const content = chunks.map((it) => it.choices[0]?.delta.content).join('')
expect(content).not.empty
expect(last(chunks)!.usage).not.undefined
})
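
Note: these tests call the live Cohere API, so `VITE_COHERE_API_KEY` must be set in the environment (e.g. via a local `.env` file that Vitest/Vite loads into `import.meta.env`).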
4 changes: 4 additions & 0 deletions src/llm/base.ts
@@ -15,3 +15,7 @@ export interface IChat {
signal: AbortSignal,
): AsyncGenerator<ChatCompletionChunk>
}

export interface ChatEnv {
[key: string]: string | undefined
}
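
`ChatEnv` is deliberately loose (just a string map), so each provider factory can read whatever keys it needs, while `getModels` in `src/index.ts` enables a provider only when every name in its `requiredEnv` is present. A small hypothetical sketch of that contract:

import { ChatEnv } from './base'
import { cohere } from './cohere'

// Hypothetical env; only providers whose requiredEnv keys all exist survive the filter.
const env: ChatEnv = { COHERE_API_KEY: 'placeholder-key' }
const enabled = [cohere(env)].filter((p) => p.requiredEnv.every((k) => k in env))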
184 changes: 184 additions & 0 deletions src/llm/cohere.ts
@@ -0,0 +1,184 @@
import { last } from 'lodash-es'
import { ChatEnv, IChat } from './base'
import { Cohere, CohereClient } from 'cohere-ai'
import OpenAI from 'openai'

function roleOpenaiToCohere(
role: OpenAI.Chat.Completions.ChatCompletionMessageParam['role'],
): Cohere.Message['role'] {
switch (role) {
case 'user':
return 'USER'
case 'assistant':
return 'CHATBOT'
case 'system':
return 'SYSTEM'
case 'tool':
return 'TOOL'
case 'function':
default:
throw new Error(`Unsupported role: ${role}`)
}
}

function finishReasonCohereToOpenAI(
reason: Cohere.NonStreamedChatResponse['finishReason'],
): OpenAI.ChatCompletion['choices'][number]['finish_reason'] {
switch (reason) {
case 'COMPLETE':
return 'stop'
case 'ERROR_LIMIT':
return 'stop'
case 'STOP_SEQUENCE':
return 'stop'
case 'USER_CANCEL':
return 'stop'
case 'MAX_TOKENS':
return 'length'
case 'ERROR_TOXIC':
return 'content_filter'
case 'ERROR':
default:
return 'stop'
}
}

export function cohere(env: ChatEnv): IChat {
const createClient = () =>
new CohereClient({
token: env.COHERE_API_KEY,
})

function parseRequest(
req: OpenAI.ChatCompletionCreateParams,
): Cohere.ChatRequest {
return {
model: req.model,
chatHistory: req.messages.slice(0, -1).map(
(it) =>
({
role: roleOpenaiToCohere(it.role),
message: it.content as string,
} satisfies Cohere.Message),
),
message: last(req.messages)?.content as string,
temperature: req.temperature!,
maxTokens: req.max_tokens!,
}
}
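// Illustrative example of the mapping above (hypothetical conversation):
//   messages = [
//     { role: 'system', content: 'Be brief.' },
//     { role: 'user', content: 'Hi' },
//     { role: 'assistant', content: 'Hello!' },
//     { role: 'user', content: 'Who are you?' },
//   ]
// becomes chatHistory = [SYSTEM 'Be brief.', USER 'Hi', CHATBOT 'Hello!']
// and message = 'Who are you?' (the content of the last message).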

function parseResponse(
resp: Cohere.NonStreamedChatResponse,
req: OpenAI.ChatCompletionCreateParams,
): OpenAI.ChatCompletion {
return {
id: resp.generationId ?? 'chatcmpl-' + crypto.randomUUID(),
object: 'chat.completion',
created: Math.floor(Date.now() / 1000),
model: req.model,
choices: [
{
message: {
role: 'assistant',
content: resp.text,
refusal: null,
},
finish_reason: finishReasonCohereToOpenAI(resp.finishReason),
index: 0,
logprobs: null,
},
],
usage:
resp.meta?.billedUnits?.inputTokens &&
resp.meta?.billedUnits?.outputTokens
? {
prompt_tokens: resp.meta.billedUnits.inputTokens,
completion_tokens: resp.meta.billedUnits.outputTokens,
total_tokens:
resp.meta.billedUnits.inputTokens +
resp.meta.billedUnits.outputTokens,
}
: undefined,
}
}

return {
name: 'cohere',
supportModels: [
'command-r',
'command-r-plus',
'command-nightly',
'command-light-nightly',
'command',
'command-r-08-2024',
'command-r-plus-08-2024',
'command-light',
],
requiredEnv: ['COHERE_API_KEY'],
async invoke(req) {
const client = createClient()
const resp = await client.chat(parseRequest(req))
return parseResponse(resp, req)
},
async *stream(req, signal) {
const client = createClient()
const stream = await client.chatStream(parseRequest(req), {
abortSignal: signal,
})
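// Cohere sends the generation id in the stream-start event; it is captured
// below and reused as the id of every converted OpenAI-style chunk.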
let id: string = 'chatcmpl-' + crypto.randomUUID()
const fields = () =>
({
id: id,
model: req.model,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
} as const)
for await (const chunk of stream) {
if (chunk.eventType === 'stream-start') {
id = chunk.generationId
yield {
...fields(),
choices: [
{ delta: { content: '' }, index: 0, finish_reason: null },
],
} satisfies OpenAI.ChatCompletionChunk
} else if (chunk.eventType === 'text-generation') {
yield {
...fields(),
choices: [
{
delta: { content: chunk.text },
index: 0,
finish_reason: null,
},
],
} satisfies OpenAI.ChatCompletionChunk
} else if (chunk.eventType === 'stream-end') {
if (
!chunk.response.meta?.billedUnits?.inputTokens ||
!chunk.response.meta?.billedUnits?.outputTokens
) {
throw new Error('Billed units not found')
}
yield {
...fields(),
choices: [
{
delta: { content: '' },
index: 0,
finish_reason: finishReasonCohereToOpenAI(chunk.finishReason),
},
],
usage: {
prompt_tokens: chunk.response.meta.billedUnits.inputTokens,
completion_tokens: chunk.response.meta.billedUnits.outputTokens,
total_tokens:
chunk.response.meta.billedUnits.inputTokens +
chunk.response.meta.billedUnits.outputTokens,
},
} satisfies OpenAI.ChatCompletionChunk
}
}
},
}
}
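
For reference, the stream adapter above always yields chunks in a fixed order: an empty delta on stream-start, one delta per text-generation event, and a final empty delta carrying finish_reason and usage on stream-end. A sketch of the sequence for a one-word reply (token counts hypothetical):

const sequence = [
  { choices: [{ delta: { content: '' }, index: 0, finish_reason: null }] },
  { choices: [{ delta: { content: 'Hello' }, index: 0, finish_reason: null }] },
  {
    choices: [{ delta: { content: '' }, index: 0, finish_reason: 'stop' }],
    usage: { prompt_tokens: 3, completion_tokens: 1, total_tokens: 4 },
  },
]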
