Skip to content

Commit

Permalink
feat: Add CachedContent resource to Vertex AI client library.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673208074
  • Loading branch information
happy-qiao authored and copybara-github committed Sep 11, 2024
1 parent 3e5e1bf commit 663a977
Show file tree
Hide file tree
Showing 5 changed files with 489 additions and 1 deletion.
228 changes: 228 additions & 0 deletions src/resources/cached_contents.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {ClientError} from '../types';
import {CachedContent, ListCachedContentsResponse} from '../types';
import {ApiClient} from './shared/api_client';

export function camelToSnake(str: string): string {
return str.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`);
}

export class CachedContentsClient {
constructor(readonly apiClient: ApiClient) {}

create(cachedContent: CachedContent): Promise<CachedContent> {
return this.apiClient.unaryApiCall(
new URL(
this.apiClient.getBaseUrl() +
'/' +
this.apiClient.getBaseResourePath() +
'/cachedContents'
),
{
body: JSON.stringify(cachedContent),
},
'POST'
);
}

update(
cachedContent: CachedContent,
updateMask: string[]
): Promise<CachedContent> {
const url = new URL(this.apiClient.getBaseUrl() + '/' + cachedContent.name);
url.searchParams.append(
'updateMask',
updateMask.map(e => camelToSnake(e)).join(',')
);
return this.apiClient.unaryApiCall(
url,
{
body: JSON.stringify(cachedContent),
},
'PATCH'
);
}

delete(name: string): Promise<void> {
return this.apiClient.unaryApiCall(
new URL(this.apiClient.getBaseUrl() + '/' + name),
{},
'DELETE'
);
}

list(
pageSize?: number,
pageToken?: string
): Promise<ListCachedContentsResponse> {
const url = new URL(
this.apiClient.getBaseUrl() + '/' + this.apiClient.getBaseResourePath()
);
if (pageSize) url.searchParams.append('pageSize', String(pageSize));
if (pageToken) url.searchParams.append('pageToken', pageToken);
return this.apiClient.unaryApiCall(url, {}, 'GET');
}

get(name: string): Promise<CachedContent> {
return this.apiClient.unaryApiCall(
new URL(this.apiClient.getBaseUrl() + '/' + name),
{},
'GET'
);
}
}

export function inferFullResourceName(
project: string,
location: string,
cachedContentId: string
): string {
if (cachedContentId.startsWith('projects/')) {
return cachedContentId;
}
if (cachedContentId.startsWith('locations/')) {
return `projects/${project}/${cachedContentId}`;
}
if (cachedContentId.startsWith('cachedContents/')) {
return `projects/${project}/locations/${location}/${cachedContentId}`;
}
if (!cachedContentId.includes('/')) {
return `projects/${project}/locations/${location}/cachedContents/${cachedContentId}`;
}
throw new ClientError(
`Invalid CachedContent.name: ${cachedContentId}. CachedContent.name should start with 'projects/', 'locations/', 'cachedContents/' or is a number type.`
);
}

export function inferModelName(
project: string,
location: string,
model?: string
) {
if (!model) {
throw new ClientError('Model name is required.');
}
if (model.startsWith('publishers/')) {
return `projects/${project}/locations/${location}/${model}`;
}
if (!model.startsWith('projects/')) {
return `projects/${project}/locations/${location}/publishers/google/models/${model}`;
}
return model;
}

/**
* This class is for managing Vertex AI's CachedContent resource.
* @public
*/
export class CachedContents {
private readonly client: CachedContentsClient;
constructor(client: ApiClient) {
this.client = new CachedContentsClient(client);
}

/**
* Creates cached content, this call will initialize the cached content in the data storage, and users need to pay for the cache data storage.
* @param cachedContent
* @param parent - Required. The parent resource where the cached content will be created.
*/
create(cachedContent: CachedContent): Promise<CachedContent> {
const curatedCachedContent = {
...cachedContent,
model: inferModelName(
this.client.apiClient.project,
this.client.apiClient.location,
cachedContent.model
),
} as CachedContent;
return this.client.create(curatedCachedContent);
}

/**
* Updates cached content configurations
*
* @param updateMask - Required. The list of fields to update. Format: google-fieldmask. See {@link https://cloud.google.com/docs/discovery/type-format}
* @param name - Immutable. Identifier. The server-generated resource name of the cached content Format: projects/{project}/locations/{location}/cachedContents/{cached_content}.
*/
update(
cachedContent: CachedContent,
updateMask: string[]
): Promise<CachedContent> {
if (!cachedContent.name) {
throw new ClientError('Cached content name is required for update.');
}
if (!updateMask || updateMask.length === 0) {
throw new ClientError(
'Update mask is required for update. Fields set in cachedContent but not in updateMask will be ignored. Examples: ["ttl"] or ["expireTime"].'
);
}
const curatedCachedContent = {
...cachedContent,
name: inferFullResourceName(
this.client.apiClient.project,
this.client.apiClient.location,
cachedContent.name
),
};
return this.client.update(curatedCachedContent, updateMask);
}

/**
* Deletes cached content.
*
* @param name - Required. The resource name referring to the cached content.
*/
delete(name: string): Promise<void> {
return this.client.delete(
inferFullResourceName(
this.client.apiClient.project,
this.client.apiClient.location,
name
)
);
}

/**
* Lists cached contents in a project.
*
* @param pageSize - Optional. The maximum number of cached contents to return. The service may return fewer than this value. If unspecified, some default (under maximum) number of items will be returned. The maximum value is 1000; values above 1000 will be coerced to 1000.
* @param pageToken - Optional. A page token, received from a previous `ListCachedContents` call. Provide this to retrieve the subsequent page. When paginating, all other parameters provided to `ListCachedContents` must match the call that provided the page token.
*/
list(
pageSize?: number,
pageToken?: string
): Promise<ListCachedContentsResponse> {
return this.client.list(pageSize, pageToken);
}

/**
* Gets cached content configurations.
*
* @param name - Required. The resource name referring to the cached content.
*/
get(name: string): Promise<CachedContent> {
return this.client.get(
inferFullResourceName(
this.client.apiClient.project,
this.client.apiClient.location,
name
)
);
}
}
19 changes: 19 additions & 0 deletions src/resources/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

export { CachedContents } from './cached_contents';
export { ApiClient } from './shared/api_client';
142 changes: 142 additions & 0 deletions src/resources/shared/api_client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {GoogleAuth} from 'google-auth-library';
import {constants} from '../../util';
import {
ClientError,
GoogleApiError,
GoogleAuthError,
GoogleGenerativeAIError,
} from '../../types';

const AUTHORIZATION_HEADER = 'Authorization';
const CONTENT_TYPE_HEADER = 'Content-Type';
const USER_AGENT_HEADER = 'User-Agent';

export class ApiClient {
constructor(
readonly project: string,
readonly location: string,
readonly apiVersion: 'v1' | 'v1beta1',
private readonly googleAuth: GoogleAuth
) {}

/**
* Gets access token from GoogleAuth. Throws {@link GoogleAuthError} when
* fails.
* @returns Promise of token string.
*/
public fetchToken(): Promise<string | null | undefined> {
const tokenPromise = this.googleAuth.getAccessToken().catch(e => {
throw new GoogleAuthError(constants.CREDENTIAL_ERROR_MESSAGE, e);
});
return tokenPromise;
}

getBaseUrl() {
return `https://${this.location}-aiplatform.googleapis.com/${this.apiVersion}`;
}

getBaseResourePath() {
return `projects/${this.project}/locations/${this.location}`;
}

async unaryApiCall(
url: URL,
requestInit: RequestInit,
httpMethod: 'GET' | 'POST' | 'PATCH' | 'DELETE'
): Promise<any> {
const token = await this.getHeaders();
return this.apiCall(url.toString(), {
...requestInit,
method: httpMethod,
headers: token,
});
}

private async apiCall(
url: string,
requestInit: RequestInit
): Promise<Response> {
const response = await fetch(url, requestInit).catch(e => {
throw new GoogleGenerativeAIError(
`exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`,
e
);
});
await throwErrorIfNotOK(response, url, requestInit).catch(e => {
throw e;
});
try {
return await response.json();
} catch (e) {
throw new GoogleGenerativeAIError(JSON.stringify(response), e as Error);
}
}

private async getHeaders(): Promise<Headers> {
const token = await this.fetchToken();
return new Headers({
[AUTHORIZATION_HEADER]: `Bearer ${token}`,
[CONTENT_TYPE_HEADER]: 'application/json',
[USER_AGENT_HEADER]: constants.USER_AGENT,
});
}
}

async function throwErrorIfNotOK(
response: Response | undefined,
url: string,
requestInit: RequestInit
) {
if (response === undefined) {
throw new GoogleGenerativeAIError('response is undefined');
}
if (!response.ok) {
const status: number = response.status;
const statusText: string = response.statusText;
let errorBody;
if (response.headers.get('content-type')?.includes('application/json')) {
errorBody = await response.json();
} else {
errorBody = {
error: {
message: `exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`,
code: response.status,
status: response.statusText,
},
};
}
const errorMessage = `got status: ${status} ${statusText}. ${JSON.stringify(
errorBody
)}`;
if (status >= 400 && status < 500) {
const error = new ClientError(
errorMessage,
new GoogleApiError(
errorBody.error.message,
errorBody.error.code,
errorBody.error.status,
errorBody.error.details
)
);
throw error;
}
throw new GoogleGenerativeAIError(errorMessage);
}
}
Loading

0 comments on commit 663a977

Please sign in to comment.