Skip to content

Commit

Permalink
feat: support to use valibot to define tools
Browse files Browse the repository at this point in the history
Signed-off-by: Neko Ayaka <[email protected]>
  • Loading branch information
nekomeowww committed Oct 19, 2024
1 parent 0de0cf3 commit 8fe325d
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 6 deletions.
4 changes: 3 additions & 1 deletion examples/neuri/weather-query/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
"keywords": [],
"main": "main.ts",
"scripts": {
"run": "tsx src/main.ts"
"run:zod": "tsx src/zod.ts",
"run:valibot": "tsx src/valibot.ts"
},
"dependencies": {
"neuri": "workspace:^",
"openai": "^4.54.0",
"valibot": "^0.42.1",
"zod": "^3.23.8"
},
"devDependencies": {
Expand Down
89 changes: 89 additions & 0 deletions examples/neuri/weather-query/src/valibot.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { env } from 'node:process'
import OpenAI from 'openai'
import * as v from 'valibot'

import {

Check failure on line 5 in examples/neuri/weather-query/src/valibot.ts

View workflow job for this annotation

GitHub Actions / Lint - 20.x

Expected "neuri/openai" to come before "valibot"

Check failure on line 5 in examples/neuri/weather-query/src/valibot.ts

View workflow job for this annotation

GitHub Actions / Lint - 22.x

Expected "neuri/openai" to come before "valibot"
composeAgent,
defineToolFunction,

resolveFirstTextMessageFromCompletion,
system,
toolFunction,
user,
} from 'neuri/openai'

async function main() {
const o = new OpenAI({
baseURL: env.OPENAI_API_BASEURL,
apiKey: env.OPENAI_API_KEY,
})

const { call } = composeAgent({
openAI: o,
tools: [
defineToolFunction(
toolFunction('getCity', 'Get the user\'s city', {}),
async () => {
return 'New York City'
},
{
hooks: {
preInvoke: async () => {
// eslint-disable-next-line no-console
console.log('getCity called')
},
},
},
),
defineToolFunction<{ location: string }, string>(
toolFunction('getCityCode', 'Get the user\'s city code with search', v.object({
location: v.pipe(v.string(), v.minLength(1), v.description('Get the user\'s city code with search')),
})),
async () => {
return 'NYC'
},
{
hooks: {
preInvoke: async () => {
// eslint-disable-next-line no-console
console.log('getCityCode called')
},
},
},
),
defineToolFunction<{ cityCode: string }, { city: string, cityCode: string, weather: string, degreesCelsius: number }>(
toolFunction('getWeather', 'Get the current weather', v.object({
cityCode: v.pipe(v.string(), v.minLength(1), v.description('Get the user\'s city code with search')),
})),
async ({ parameters: { cityCode } }) => {
return {
city: `New York city`,
cityCode,
weather: 'sunny',
degreesCelsius: 26,
}
},
{
hooks: {
preInvoke: async () => {
// eslint-disable-next-line no-console
console.log('getWeather called')
},
},
},
),
],
})

const res = await call([
system('I am a helpful assistant here to provide information of user, user may ask you anything. Please identify the user\'s need, and pick up the right tool to obtain the necessary information.'),
user('What is the weather like today?'),
], {
model: 'openai/gpt-3.5-turbo',
})

return resolveFirstTextMessageFromCompletion(res)
}

// eslint-disable-next-line no-console
main().then(console.log).catch(console.error)
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async function main() {
),
defineToolFunction<{ location: string }, string>(
toolFunction('getCityCode', 'Get the user\'s city code with search', z.object({
city: z.string().min(1).describe('Get the user\'s city code with search'),
location: z.string().min(1).describe('Get the user\'s city code with search'),
})),
async () => {
return 'NYC'
Expand Down
2 changes: 2 additions & 0 deletions packages/neuri/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"@guiiai/logg": "^1.0.3",
"@shikijs/core": "^1.14.1",
"@shikijs/vscode-textmate": "^9.2.2",
"@valibot/to-json-schema": "^0.2.1",
"ajv": "^8.17.1",
"defu": "^6.1.4",
"execa": "^9.3.1",
Expand All @@ -81,6 +82,7 @@
"unist-util-is": "^6.0.0",
"unist-util-remove": "^4.0.0",
"unist-util-visit": "^5.0.0",
"valibot": "^0.42.1",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.23.3"
},
Expand Down
19 changes: 15 additions & 4 deletions packages/neuri/src/openai.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import type { JSONSchema7, JSONSchema7Object } from 'json-schema'
import type OpenAI from 'openai'
import type { JSONSchema7, JSONSchema7Object, JSONSchema7Type } from 'json-schema'
import type { BaseIssue, BaseSchema } from 'valibot'
import { toJsonSchema } from '@valibot/to-json-schema'
import { ZodSchema } from 'zod'
import * as z from 'zod'
import { zodToJsonSchema } from 'zod-to-json-schema'

export function system(message: string): OpenAI.ChatCompletionSystemMessageParam {
Expand Down Expand Up @@ -173,11 +174,21 @@ export async function invokeFunctionWithTools<P, R>(chatCompletion: OpenAI.Chat.
}
}

function isValibotObjectSchema(schema: any): schema is BaseSchema<unknown, unknown, BaseIssue<unknown>> {
if (typeof schema !== 'object')
return false

return 'type' in schema && schema.type === 'string' && 'reference' in schema && 'expects' in schema && 'entries' in schema && 'message' in schema
}

export function toolFunction(name: string, description: string, parameters: JSONSchema7Object): OpenAI.Chat.ChatCompletionTool
export function toolFunction(name: string, description: string, parameters: ZodSchema): OpenAI.Chat.ChatCompletionTool
export function toolFunction(name: string, description: string, parameters: JSONSchema7Object | ZodSchema): OpenAI.Chat.ChatCompletionTool {
export function toolFunction(name: string, description: string, parameters: BaseSchema<unknown, unknown, BaseIssue<unknown>>): OpenAI.Chat.ChatCompletionTool
export function toolFunction(name: string, description: string, parameters: JSONSchema7 | ZodSchema | BaseSchema<unknown, unknown, BaseIssue<unknown>>): OpenAI.Chat.ChatCompletionTool {
if (parameters instanceof ZodSchema)
parameters = zodToJsonSchema(parameters) as JSONSchema7Object
parameters = zodToJsonSchema(parameters) as JSONSchema7
else if (isValibotObjectSchema(parameters))
parameters = toJsonSchema(parameters) as JSONSchema7

return {
type: 'function',
Expand Down
32 changes: 32 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8fe325d

Please sign in to comment.