diff --git a/src/resources/cached_contents.ts b/src/resources/cached_contents.ts new file mode 100644 index 00000000..7f7bb516 --- /dev/null +++ b/src/resources/cached_contents.ts @@ -0,0 +1,228 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClientError} from '../types'; +import {CachedContent, ListCachedContentsResponse} from '../types'; +import {ApiClient} from './shared/api_client'; + +export function camelToSnake(str: string): string { + return str.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`); +} + +export class CachedContentsClient { + constructor(readonly apiClient: ApiClient) {} + + create(cachedContent: CachedContent): Promise { + return this.apiClient.unaryApiCall( + new URL( + this.apiClient.getBaseUrl() + + '/' + + this.apiClient.getBaseResourePath() + + '/cachedContents' + ), + { + body: JSON.stringify(cachedContent), + }, + 'POST' + ); + } + + update( + cachedContent: CachedContent, + updateMask: string[] + ): Promise { + const url = new URL(this.apiClient.getBaseUrl() + '/' + cachedContent.name); + url.searchParams.append( + 'updateMask', + updateMask.map(e => camelToSnake(e)).join(',') + ); + return this.apiClient.unaryApiCall( + url, + { + body: JSON.stringify(cachedContent), + }, + 'PATCH' + ); + } + + delete(name: string): Promise { + return this.apiClient.unaryApiCall( + new URL(this.apiClient.getBaseUrl() + '/' + name), + {}, + 'DELETE' + ); + } + + list( + pageSize?: number, + pageToken?: string + ): Promise { + const url = new URL( + this.apiClient.getBaseUrl() + '/' + this.apiClient.getBaseResourePath() + ); + if (pageSize) url.searchParams.append('pageSize', String(pageSize)); + if (pageToken) url.searchParams.append('pageToken', pageToken); + return this.apiClient.unaryApiCall(url, {}, 'GET'); + } + + get(name: string): Promise { + return this.apiClient.unaryApiCall( + new URL(this.apiClient.getBaseUrl() + '/' + name), + {}, + 'GET' + ); + } +} + +export function inferFullResourceName( + project: string, + location: string, + cachedContentId: string +): string { + if (cachedContentId.startsWith('projects/')) { + return cachedContentId; + } + if (cachedContentId.startsWith('locations/')) { + return `projects/${project}/${cachedContentId}`; + } + if (cachedContentId.startsWith('cachedContents/')) { + return `projects/${project}/locations/${location}/${cachedContentId}`; + } + if (!cachedContentId.includes('/')) { + return `projects/${project}/locations/${location}/cachedContents/${cachedContentId}`; + } + throw new ClientError( + `Invalid CachedContent.name: ${cachedContentId}. CachedContent.name should start with 'projects/', 'locations/', 'cachedContents/' or is a number type.` + ); +} + +export function inferModelName( + project: string, + location: string, + model?: string +) { + if (!model) { + throw new ClientError('Model name is required.'); + } + if (model.startsWith('publishers/')) { + return `projects/${project}/locations/${location}/${model}`; + } + if (!model.startsWith('projects/')) { + return `projects/${project}/locations/${location}/publishers/google/models/${model}`; + } + return model; +} + +/** + * This class is for managing Vertex AI's CachedContent resource. + * @public + */ +export class CachedContents { + private readonly client: CachedContentsClient; + constructor(client: ApiClient) { + this.client = new CachedContentsClient(client); + } + + /** + * Creates cached content, this call will initialize the cached content in the data storage, and users need to pay for the cache data storage. + * @param cachedContent + * @param parent - Required. The parent resource where the cached content will be created. + */ + create(cachedContent: CachedContent): Promise { + const curatedCachedContent = { + ...cachedContent, + model: inferModelName( + this.client.apiClient.project, + this.client.apiClient.location, + cachedContent.model + ), + } as CachedContent; + return this.client.create(curatedCachedContent); + } + + /** + * Updates cached content configurations + * + * @param updateMask - Required. The list of fields to update. Format: google-fieldmask. See {@link https://cloud.google.com/docs/discovery/type-format} + * @param name - Immutable. Identifier. The server-generated resource name of the cached content Format: projects/{project}/locations/{location}/cachedContents/{cached_content}. + */ + update( + cachedContent: CachedContent, + updateMask: string[] + ): Promise { + if (!cachedContent.name) { + throw new ClientError('Cached content name is required for update.'); + } + if (!updateMask || updateMask.length === 0) { + throw new ClientError( + 'Update mask is required for update. Fields set in cachedContent but not in updateMask will be ignored. Examples: ["ttl"] or ["expireTime"].' + ); + } + const curatedCachedContent = { + ...cachedContent, + name: inferFullResourceName( + this.client.apiClient.project, + this.client.apiClient.location, + cachedContent.name + ), + }; + return this.client.update(curatedCachedContent, updateMask); + } + + /** + * Deletes cached content. + * + * @param name - Required. The resource name referring to the cached content. + */ + delete(name: string): Promise { + return this.client.delete( + inferFullResourceName( + this.client.apiClient.project, + this.client.apiClient.location, + name + ) + ); + } + + /** + * Lists cached contents in a project. + * + * @param pageSize - Optional. The maximum number of cached contents to return. The service may return fewer than this value. If unspecified, some default (under maximum) number of items will be returned. The maximum value is 1000; values above 1000 will be coerced to 1000. + * @param pageToken - Optional. A page token, received from a previous `ListCachedContents` call. Provide this to retrieve the subsequent page. When paginating, all other parameters provided to `ListCachedContents` must match the call that provided the page token. + */ + list( + pageSize?: number, + pageToken?: string + ): Promise { + return this.client.list(pageSize, pageToken); + } + + /** + * Gets cached content configurations. + * + * @param name - Required. The resource name referring to the cached content. + */ + get(name: string): Promise { + return this.client.get( + inferFullResourceName( + this.client.apiClient.project, + this.client.apiClient.location, + name + ) + ); + } +} diff --git a/src/resources/index.ts b/src/resources/index.ts new file mode 100644 index 00000000..c3c2bb01 --- /dev/null +++ b/src/resources/index.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export { CachedContents } from './cached_contents'; +export { ApiClient } from './shared/api_client'; diff --git a/src/resources/shared/api_client.ts b/src/resources/shared/api_client.ts new file mode 100644 index 00000000..59d48add --- /dev/null +++ b/src/resources/shared/api_client.ts @@ -0,0 +1,142 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {GoogleAuth} from 'google-auth-library'; +import {constants} from '../../util'; +import { + ClientError, + GoogleApiError, + GoogleAuthError, + GoogleGenerativeAIError, +} from '../../types'; + +const AUTHORIZATION_HEADER = 'Authorization'; +const CONTENT_TYPE_HEADER = 'Content-Type'; +const USER_AGENT_HEADER = 'User-Agent'; + +export class ApiClient { + constructor( + readonly project: string, + readonly location: string, + readonly apiVersion: 'v1' | 'v1beta1', + private readonly googleAuth: GoogleAuth + ) {} + + /** + * Gets access token from GoogleAuth. Throws {@link GoogleAuthError} when + * fails. + * @returns Promise of token string. + */ + public fetchToken(): Promise { + const tokenPromise = this.googleAuth.getAccessToken().catch(e => { + throw new GoogleAuthError(constants.CREDENTIAL_ERROR_MESSAGE, e); + }); + return tokenPromise; + } + + getBaseUrl() { + return `https://${this.location}-aiplatform.googleapis.com/${this.apiVersion}`; + } + + getBaseResourePath() { + return `projects/${this.project}/locations/${this.location}`; + } + + async unaryApiCall( + url: URL, + requestInit: RequestInit, + httpMethod: 'GET' | 'POST' | 'PATCH' | 'DELETE' + ): Promise { + const token = await this.getHeaders(); + return this.apiCall(url.toString(), { + ...requestInit, + method: httpMethod, + headers: token, + }); + } + + private async apiCall( + url: string, + requestInit: RequestInit + ): Promise { + const response = await fetch(url, requestInit).catch(e => { + throw new GoogleGenerativeAIError( + `exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`, + e + ); + }); + await throwErrorIfNotOK(response, url, requestInit).catch(e => { + throw e; + }); + try { + return await response.json(); + } catch (e) { + throw new GoogleGenerativeAIError(JSON.stringify(response), e as Error); + } + } + + private async getHeaders(): Promise { + const token = await this.fetchToken(); + return new Headers({ + [AUTHORIZATION_HEADER]: `Bearer ${token}`, + [CONTENT_TYPE_HEADER]: 'application/json', + [USER_AGENT_HEADER]: constants.USER_AGENT, + }); + } +} + +async function throwErrorIfNotOK( + response: Response | undefined, + url: string, + requestInit: RequestInit +) { + if (response === undefined) { + throw new GoogleGenerativeAIError('response is undefined'); + } + if (!response.ok) { + const status: number = response.status; + const statusText: string = response.statusText; + let errorBody; + if (response.headers.get('content-type')?.includes('application/json')) { + errorBody = await response.json(); + } else { + errorBody = { + error: { + message: `exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`, + code: response.status, + status: response.statusText, + }, + }; + } + const errorMessage = `got status: ${status} ${statusText}. ${JSON.stringify( + errorBody + )}`; + if (status >= 400 && status < 500) { + const error = new ClientError( + errorMessage, + new GoogleApiError( + errorBody.error.message, + errorBody.error.code, + errorBody.error.status, + errorBody.error.details + ) + ); + throw error; + } + throw new GoogleGenerativeAIError(errorMessage); + } +} diff --git a/src/types/content.ts b/src/types/content.ts index 1babe1c7..0595bfe8 100644 --- a/src/types/content.ts +++ b/src/types/content.ts @@ -1061,3 +1061,91 @@ export interface RequestOptions { */ customHeaders?: Headers; } + + +/** + * A resource used in LLM queries for users to explicitly specify + * what to cache and how to cache. + */ +export interface CachedContent { + /** + * Immutable. Identifier. The server-generated resource name of the cached content. + * Format: projects/{project}/locations/{location}/cachedContents/{cached_content} + */ + name?: string; + + /** Optional. Immutable. The user-generated meaningful display name of the cached content. */ + displayName?: string; + + /** + * Immutable. The name of the publisher model to use for cached content. + * Format: projects/{project}/locations/{location}/publishers/{publisher}/models/{model} + */ + model?: string; + + /** Developer set system instruction. Currently, text only. */ + systemInstruction?: Content; + + /** Optional. Input only. Immutable. The content to cache. */ + contents?: Content[]; + + /** Optional. Input only. Immutable. A list of `Tools` the model may use to generate the next response. */ + tools?: Tool[]; + + /** Optional. Input only. Immutable. Tool config. This config is shared for all tools. */ + toolConfig?: ToolConfig; + + /** + * Output only. Creatation time of the cache entry. + * Format: google-datetime. See {@link https://cloud.google.com/docs/discovery/type-format} + */ + createTime?: string; + + /** + * Output only. When the cache entry was last updated in UTC time. + * Format: google-datetime. See {@link https://cloud.google.com/docs/discovery/type-format} + */ + updateTime?: string; + + /** Output only. Metadata on the usage of the cached content. */ + usageMetadata?: CachedContentUsageMetadata; + + /** + * Timestamp of when this resource is considered expired. + * This is *always* provided on output, regardless of what was sent on input. + */ + expireTime?: string; + + /** + * Input only. The TTL seconds for this resource. The expiration time + * is computed: now + TTL. + * Format: google-duration. See {@link https://cloud.google.com/docs/discovery/type-format} + */ + ttl?: string; +} + +/** Metadata on the usage of the cached content. */ +export interface CachedContentUsageMetadata { + /** Total number of tokens that the cached content consumes. */ + totalTokenCount?: number; + + /** Number of text characters. */ + textCount?: number; + + /** Number of images. */ + imageCount?: number; + + /** Duration of video in seconds. */ + videoDurationSeconds?: number; + + /** Duration of audio in seconds. */ + audioDurationSeconds?: number; +} + +/** Response with a list of CachedContents. */ +export interface ListCachedContentsResponse { + /** List of cached contents. */ + cachedContents?: CachedContent[]; + /** A token, which can be sent as `page_token` to retrieve the next page. If this field is omitted, there are no subsequent pages. */ + nextPageToken?: string; +} diff --git a/src/vertex_ai.ts b/src/vertex_ai.ts index 27194ee0..cb2d82ac 100644 --- a/src/vertex_ai.ts +++ b/src/vertex_ai.ts @@ -26,7 +26,7 @@ import { VertexInit, } from './types/content'; import {GoogleAuthError, IllegalArgumentError} from './types/errors'; - +import * as Resources from './resources'; /** * The `VertexAI` class is the base class for authenticating to Vertex AI. * To use Vertex AI's generative AI models, use the `getGenerativeModel` method. @@ -149,6 +149,9 @@ class VertexAIPreview { private readonly googleAuth: GoogleAuth; private readonly apiEndpoint?: string; + readonly apiClient: Resources.ApiClient; + readonly cachedContents: Resources.CachedContents; + /** * @constructor * @param project - The Google Cloud project to use for the request @@ -174,6 +177,14 @@ class VertexAIPreview { this.location = location; this.googleAuth = googleAuth; this.apiEndpoint = apiEndpoint; + + this.apiClient = new Resources.ApiClient( + this.project, + this.location, + 'v1beta1', + this.googleAuth + ); + this.cachedContents = new Resources.CachedContents(this.apiClient); } /**