Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add CachedContent resource to Vertex AI client library. #424

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 228 additions & 0 deletions src/resources/cached_contents.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {ClientError} from '../types';
import {CachedContent, ListCachedContentsResponse} from '../types';
import {ApiClient} from './shared/api_client';

export function camelToSnake(str: string): string {
return str.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`);
}

export class CachedContentsClient {
constructor(readonly apiClient: ApiClient) {}

create(cachedContent: CachedContent): Promise<CachedContent> {
return this.apiClient.unaryApiCall(
new URL(
this.apiClient.getBaseUrl() +
'/' +
this.apiClient.getBaseResourePath() +
'/cachedContents'
),
{
body: JSON.stringify(cachedContent),
},
'POST'
);
}

update(
cachedContent: CachedContent,
updateMask: string[]
): Promise<CachedContent> {
const url = new URL(this.apiClient.getBaseUrl() + '/' + cachedContent.name);
url.searchParams.append(
'updateMask',
updateMask.map(e => camelToSnake(e)).join(',')
);
return this.apiClient.unaryApiCall(
url,
{
body: JSON.stringify(cachedContent),
},
'PATCH'
);
}

delete(name: string): Promise<void> {
return this.apiClient.unaryApiCall(
new URL(this.apiClient.getBaseUrl() + '/' + name),
{},
'DELETE'
);
}

list(
pageSize?: number,
pageToken?: string
): Promise<ListCachedContentsResponse> {
const url = new URL(
this.apiClient.getBaseUrl() + '/' + this.apiClient.getBaseResourePath()
);
if (pageSize) url.searchParams.append('pageSize', String(pageSize));
if (pageToken) url.searchParams.append('pageToken', pageToken);
return this.apiClient.unaryApiCall(url, {}, 'GET');
}

get(name: string): Promise<CachedContent> {
return this.apiClient.unaryApiCall(
new URL(this.apiClient.getBaseUrl() + '/' + name),
{},
'GET'
);
}
}

export function inferFullResourceName(
project: string,
location: string,
cachedContentId: string
): string {
if (cachedContentId.startsWith('projects/')) {
return cachedContentId;
}
if (cachedContentId.startsWith('locations/')) {
return `projects/${project}/${cachedContentId}`;
}
if (cachedContentId.startsWith('cachedContents/')) {
return `projects/${project}/locations/${location}/${cachedContentId}`;
}
if (!cachedContentId.includes('/')) {
return `projects/${project}/locations/${location}/cachedContents/${cachedContentId}`;
}
throw new ClientError(
`Invalid CachedContent.name: ${cachedContentId}. CachedContent.name should start with 'projects/', 'locations/', 'cachedContents/' or is a number type.`
);
}

export function inferModelName(
project: string,
location: string,
model?: string
) {
if (!model) {
throw new ClientError('Model name is required.');
}
if (model.startsWith('publishers/')) {
return `projects/${project}/locations/${location}/${model}`;
}
if (!model.startsWith('projects/')) {
return `projects/${project}/locations/${location}/publishers/google/models/${model}`;
}
return model;
}

/**
* This class is for managing Vertex AI's CachedContent resource.
* @public
*/
export class CachedContents {
private readonly client: CachedContentsClient;
constructor(client: ApiClient) {
this.client = new CachedContentsClient(client);
}

/**
* Creates cached content, this call will initialize the cached content in the data storage, and users need to pay for the cache data storage.
* @param cachedContent
* @param parent - Required. The parent resource where the cached content will be created.
*/
create(cachedContent: CachedContent): Promise<CachedContent> {
const curatedCachedContent = {
...cachedContent,
model: inferModelName(
this.client.apiClient.project,
this.client.apiClient.location,
cachedContent.model
),
} as CachedContent;
return this.client.create(curatedCachedContent);
}

/**
* Updates cached content configurations
*
* @param updateMask - Required. The list of fields to update. Format: google-fieldmask. See {@link https://cloud.google.com/docs/discovery/type-format}
* @param name - Immutable. Identifier. The server-generated resource name of the cached content Format: projects/{project}/locations/{location}/cachedContents/{cached_content}.
*/
update(
cachedContent: CachedContent,
updateMask: string[]
): Promise<CachedContent> {
if (!cachedContent.name) {
throw new ClientError('Cached content name is required for update.');
}
if (!updateMask || updateMask.length === 0) {
throw new ClientError(
'Update mask is required for update. Fields set in cachedContent but not in updateMask will be ignored. Examples: ["ttl"] or ["expireTime"].'
);
}
const curatedCachedContent = {
...cachedContent,
name: inferFullResourceName(
this.client.apiClient.project,
this.client.apiClient.location,
cachedContent.name
),
};
return this.client.update(curatedCachedContent, updateMask);
}

/**
* Deletes cached content.
*
* @param name - Required. The resource name referring to the cached content.
*/
delete(name: string): Promise<void> {
return this.client.delete(
inferFullResourceName(
this.client.apiClient.project,
this.client.apiClient.location,
name
)
);
}

/**
* Lists cached contents in a project.
*
* @param pageSize - Optional. The maximum number of cached contents to return. The service may return fewer than this value. If unspecified, some default (under maximum) number of items will be returned. The maximum value is 1000; values above 1000 will be coerced to 1000.
* @param pageToken - Optional. A page token, received from a previous `ListCachedContents` call. Provide this to retrieve the subsequent page. When paginating, all other parameters provided to `ListCachedContents` must match the call that provided the page token.
*/
list(
pageSize?: number,
pageToken?: string
): Promise<ListCachedContentsResponse> {
return this.client.list(pageSize, pageToken);
}

/**
* Gets cached content configurations.
*
* @param name - Required. The resource name referring to the cached content.
*/
get(name: string): Promise<CachedContent> {
return this.client.get(
inferFullResourceName(
this.client.apiClient.project,
this.client.apiClient.location,
name
)
);
}
}
19 changes: 19 additions & 0 deletions src/resources/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

export {CachedContents} from './cached_contents';
export {ApiClient} from './shared/api_client';
142 changes: 142 additions & 0 deletions src/resources/shared/api_client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {GoogleAuth} from 'google-auth-library';
import {constants} from '../../util';
import {
ClientError,
GoogleApiError,
GoogleAuthError,
GoogleGenerativeAIError,
} from '../../types';

const AUTHORIZATION_HEADER = 'Authorization';
const CONTENT_TYPE_HEADER = 'Content-Type';
const USER_AGENT_HEADER = 'User-Agent';

export class ApiClient {
constructor(
readonly project: string,
readonly location: string,
readonly apiVersion: 'v1' | 'v1beta1',
private readonly googleAuth: GoogleAuth
) {}

/**
* Gets access token from GoogleAuth. Throws {@link GoogleAuthError} when
* fails.
* @returns Promise of token string.
*/
private fetchToken(): Promise<string | null | undefined> {
const tokenPromise = this.googleAuth.getAccessToken().catch(e => {
throw new GoogleAuthError(constants.CREDENTIAL_ERROR_MESSAGE, e);
});
return tokenPromise;
}

getBaseUrl() {
return `https://${this.location}-aiplatform.googleapis.com/${this.apiVersion}`;
}

getBaseResourePath() {
return `projects/${this.project}/locations/${this.location}`;
}

async unaryApiCall(
url: URL,
requestInit: RequestInit,
httpMethod: 'GET' | 'POST' | 'PATCH' | 'DELETE'
): Promise<any> {
const token = await this.getHeaders();
return this.apiCall(url.toString(), {
...requestInit,
method: httpMethod,
headers: token,
});
}

private async apiCall(
url: string,
requestInit: RequestInit
): Promise<Response> {
const response = await fetch(url, requestInit).catch(e => {
throw new GoogleGenerativeAIError(
`exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`,
e
);
});
await throwErrorIfNotOK(response, url, requestInit).catch(e => {
throw e;
});
try {
return await response.json();
} catch (e) {
throw new GoogleGenerativeAIError(JSON.stringify(response), e as Error);
}
}

private async getHeaders(): Promise<Headers> {
const token = await this.fetchToken();
return new Headers({
[AUTHORIZATION_HEADER]: `Bearer ${token}`,
[CONTENT_TYPE_HEADER]: 'application/json',
[USER_AGENT_HEADER]: constants.USER_AGENT,
});
}
}

async function throwErrorIfNotOK(
response: Response | undefined,
url: string,
requestInit: RequestInit
) {
if (response === undefined) {
throw new GoogleGenerativeAIError('response is undefined');
}
if (!response.ok) {
const status: number = response.status;
const statusText: string = response.statusText;
let errorBody;
if (response.headers.get('content-type')?.includes('application/json')) {
errorBody = await response.json();
} else {
errorBody = {
error: {
message: `exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`,
code: response.status,
status: response.statusText,
},
};
}
const errorMessage = `got status: ${status} ${statusText}. ${JSON.stringify(
errorBody
)}`;
if (status >= 400 && status < 500) {
const error = new ClientError(
errorMessage,
new GoogleApiError(
errorBody.error.message,
errorBody.error.code,
errorBody.error.status,
errorBody.error.details
)
);
throw error;
}
throw new GoogleGenerativeAIError(errorMessage);
}
}
Loading
Loading