Skip to content

Commit

Permalink
fix(js/ai): skip response schema validation on tool calls (#1597)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Jan 16, 2025
1 parent 6b541a3 commit 51ebf4e
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 5 deletions.
5 changes: 4 additions & 1 deletion js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,14 @@ async function generate(
);

// Throw an error if the response is not usable.
response.assertValid(request);
response.assertValid();
const message = response.message!; // would have thrown if no message

const toolCalls = message.content.filter((part) => !!part.toolRequest);
if (rawRequest.returnToolRequests || toolCalls.length === 0) {
if (toolCalls.length === 0) {
response.assertValidSchema(request);
}
return response.toJSON();
}
const maxIterations = rawRequest.maxTurns ?? 5;
Expand Down
10 changes: 8 additions & 2 deletions js/ai/src/generate/response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export class GenerateResponse<O = unknown> implements ModelResponseData {
/**
* Throws an error if the response does not contain valid output.
*/
assertValid(request?: GenerateRequest): void {
assertValid(): void {
if (this.finishReason === 'blocked') {
throw new GenerationBlockedError(
this,
Expand All @@ -90,7 +90,12 @@ export class GenerateResponse<O = unknown> implements ModelResponseData {
`Model did not generate a message. Finish reason: '${this.finishReason}': ${this.finishMessage}`
);
}
}

/**
* Throws an error if the response does not conform to expected schema.
*/
assertValidSchema(request?: GenerateRequest): void {
if (request?.output?.schema || this.request?.output?.schema) {
const o = this.output;
parseSchema(o, {
Expand All @@ -101,7 +106,8 @@ export class GenerateResponse<O = unknown> implements ModelResponseData {

isValid(request?: GenerateRequest): boolean {
try {
this.assertValid(request);
this.assertValid();
this.assertValidSchema(request);
return true;
} catch (e) {
return false;
Expand Down
4 changes: 2 additions & 2 deletions js/ai/tests/generate/response_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ describe('GenerateResponse', () => {

assert.throws(
() => {
response.assertValid(request);
response.assertValidSchema(request);
},
(err: unknown) => {
return err instanceof Error && err.message.includes('must be number');
Expand Down Expand Up @@ -186,7 +186,7 @@ describe('GenerateResponse', () => {
};

assert.doesNotThrow(() => {
response.assertValid(request);
response.assertValidSchema(request);
});
});
});
Expand Down
55 changes: 55 additions & 0 deletions js/genkit/tests/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

import { z } from '@genkit-ai/core';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { modelRef } from '../../ai/src/model';
Expand Down Expand Up @@ -359,6 +360,60 @@ describe('generate', () => {
);
});

it('call the tool with output schema', async () => {
const schema = z.object({
foo: z.string(),
});

ai.defineTool(
{
name: 'testTool',
description: 'description',
inputSchema: schema,
outputSchema: schema,
},
async () => {
return {
foo: 'bar',
};
}
);

// first response be tools call, the subsequent just text response from agent b.
let reqCounter = 0;
pm.handleResponse = async (req, sc) => {
return {
message: {
role: 'model',
content: [
reqCounter++ === 0
? {
toolRequest: {
name: 'testTool',
input: { foo: 'fromTool' },
ref: 'ref123',
},
}
: {
text: "```\n{foo: 'fromModel'}\n```",
},
],
},
};
};

const { text, output } = await ai.generate({
output: { schema },
prompt: 'call the tool',
tools: ['testTool'],
});

assert.strictEqual(text, "```\n{foo: 'fromModel'}\n```");
assert.deepStrictEqual(output, {
foo: 'fromModel',
});
});

it('throws when exceeding max tool call iterations', async () => {
ai.defineTool(
{ name: 'testTool', description: 'description' },
Expand Down

0 comments on commit 51ebf4e

Please sign in to comment.