-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add CachedContent resource to Vertex AI client library.
PiperOrigin-RevId: 673208074
- Loading branch information
1 parent
3e5e1bf
commit 663a977
Showing
5 changed files
with
489 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
/** | ||
* @license | ||
* Copyright 2024 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
import {ClientError} from '../types'; | ||
import {CachedContent, ListCachedContentsResponse} from '../types'; | ||
import {ApiClient} from './shared/api_client'; | ||
|
||
export function camelToSnake(str: string): string { | ||
return str.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`); | ||
} | ||
|
||
export class CachedContentsClient { | ||
constructor(readonly apiClient: ApiClient) {} | ||
|
||
create(cachedContent: CachedContent): Promise<CachedContent> { | ||
return this.apiClient.unaryApiCall( | ||
new URL( | ||
this.apiClient.getBaseUrl() + | ||
'/' + | ||
this.apiClient.getBaseResourePath() + | ||
'/cachedContents' | ||
), | ||
{ | ||
body: JSON.stringify(cachedContent), | ||
}, | ||
'POST' | ||
); | ||
} | ||
|
||
update( | ||
cachedContent: CachedContent, | ||
updateMask: string[] | ||
): Promise<CachedContent> { | ||
const url = new URL(this.apiClient.getBaseUrl() + '/' + cachedContent.name); | ||
url.searchParams.append( | ||
'updateMask', | ||
updateMask.map(e => camelToSnake(e)).join(',') | ||
); | ||
return this.apiClient.unaryApiCall( | ||
url, | ||
{ | ||
body: JSON.stringify(cachedContent), | ||
}, | ||
'PATCH' | ||
); | ||
} | ||
|
||
delete(name: string): Promise<void> { | ||
return this.apiClient.unaryApiCall( | ||
new URL(this.apiClient.getBaseUrl() + '/' + name), | ||
{}, | ||
'DELETE' | ||
); | ||
} | ||
|
||
list( | ||
pageSize?: number, | ||
pageToken?: string | ||
): Promise<ListCachedContentsResponse> { | ||
const url = new URL( | ||
this.apiClient.getBaseUrl() + '/' + this.apiClient.getBaseResourePath() | ||
); | ||
if (pageSize) url.searchParams.append('pageSize', String(pageSize)); | ||
if (pageToken) url.searchParams.append('pageToken', pageToken); | ||
return this.apiClient.unaryApiCall(url, {}, 'GET'); | ||
} | ||
|
||
get(name: string): Promise<CachedContent> { | ||
return this.apiClient.unaryApiCall( | ||
new URL(this.apiClient.getBaseUrl() + '/' + name), | ||
{}, | ||
'GET' | ||
); | ||
} | ||
} | ||
|
||
export function inferFullResourceName( | ||
project: string, | ||
location: string, | ||
cachedContentId: string | ||
): string { | ||
if (cachedContentId.startsWith('projects/')) { | ||
return cachedContentId; | ||
} | ||
if (cachedContentId.startsWith('locations/')) { | ||
return `projects/${project}/${cachedContentId}`; | ||
} | ||
if (cachedContentId.startsWith('cachedContents/')) { | ||
return `projects/${project}/locations/${location}/${cachedContentId}`; | ||
} | ||
if (!cachedContentId.includes('/')) { | ||
return `projects/${project}/locations/${location}/cachedContents/${cachedContentId}`; | ||
} | ||
throw new ClientError( | ||
`Invalid CachedContent.name: ${cachedContentId}. CachedContent.name should start with 'projects/', 'locations/', 'cachedContents/' or is a number type.` | ||
); | ||
} | ||
|
||
export function inferModelName( | ||
project: string, | ||
location: string, | ||
model?: string | ||
) { | ||
if (!model) { | ||
throw new ClientError('Model name is required.'); | ||
} | ||
if (model.startsWith('publishers/')) { | ||
return `projects/${project}/locations/${location}/${model}`; | ||
} | ||
if (!model.startsWith('projects/')) { | ||
return `projects/${project}/locations/${location}/publishers/google/models/${model}`; | ||
} | ||
return model; | ||
} | ||
|
||
/** | ||
* This class is for managing Vertex AI's CachedContent resource. | ||
* @public | ||
*/ | ||
export class CachedContents { | ||
private readonly client: CachedContentsClient; | ||
constructor(client: ApiClient) { | ||
this.client = new CachedContentsClient(client); | ||
} | ||
|
||
/** | ||
* Creates cached content, this call will initialize the cached content in the data storage, and users need to pay for the cache data storage. | ||
* @param cachedContent | ||
* @param parent - Required. The parent resource where the cached content will be created. | ||
*/ | ||
create(cachedContent: CachedContent): Promise<CachedContent> { | ||
const curatedCachedContent = { | ||
...cachedContent, | ||
model: inferModelName( | ||
this.client.apiClient.project, | ||
this.client.apiClient.location, | ||
cachedContent.model | ||
), | ||
} as CachedContent; | ||
return this.client.create(curatedCachedContent); | ||
} | ||
|
||
/** | ||
* Updates cached content configurations | ||
* | ||
* @param updateMask - Required. The list of fields to update. Format: google-fieldmask. See {@link https://cloud.google.com/docs/discovery/type-format} | ||
* @param name - Immutable. Identifier. The server-generated resource name of the cached content Format: projects/{project}/locations/{location}/cachedContents/{cached_content}. | ||
*/ | ||
update( | ||
cachedContent: CachedContent, | ||
updateMask: string[] | ||
): Promise<CachedContent> { | ||
if (!cachedContent.name) { | ||
throw new ClientError('Cached content name is required for update.'); | ||
} | ||
if (!updateMask || updateMask.length === 0) { | ||
throw new ClientError( | ||
'Update mask is required for update. Fields set in cachedContent but not in updateMask will be ignored. Examples: ["ttl"] or ["expireTime"].' | ||
); | ||
} | ||
const curatedCachedContent = { | ||
...cachedContent, | ||
name: inferFullResourceName( | ||
this.client.apiClient.project, | ||
this.client.apiClient.location, | ||
cachedContent.name | ||
), | ||
}; | ||
return this.client.update(curatedCachedContent, updateMask); | ||
} | ||
|
||
/** | ||
* Deletes cached content. | ||
* | ||
* @param name - Required. The resource name referring to the cached content. | ||
*/ | ||
delete(name: string): Promise<void> { | ||
return this.client.delete( | ||
inferFullResourceName( | ||
this.client.apiClient.project, | ||
this.client.apiClient.location, | ||
name | ||
) | ||
); | ||
} | ||
|
||
/** | ||
* Lists cached contents in a project. | ||
* | ||
* @param pageSize - Optional. The maximum number of cached contents to return. The service may return fewer than this value. If unspecified, some default (under maximum) number of items will be returned. The maximum value is 1000; values above 1000 will be coerced to 1000. | ||
* @param pageToken - Optional. A page token, received from a previous `ListCachedContents` call. Provide this to retrieve the subsequent page. When paginating, all other parameters provided to `ListCachedContents` must match the call that provided the page token. | ||
*/ | ||
list( | ||
pageSize?: number, | ||
pageToken?: string | ||
): Promise<ListCachedContentsResponse> { | ||
return this.client.list(pageSize, pageToken); | ||
} | ||
|
||
/** | ||
* Gets cached content configurations. | ||
* | ||
* @param name - Required. The resource name referring to the cached content. | ||
*/ | ||
get(name: string): Promise<CachedContent> { | ||
return this.client.get( | ||
inferFullResourceName( | ||
this.client.apiClient.project, | ||
this.client.apiClient.location, | ||
name | ||
) | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/** | ||
* @license | ||
* Copyright 2024 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
export { CachedContents } from './cached_contents'; | ||
export { ApiClient } from './shared/api_client'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
/** | ||
* @license | ||
* Copyright 2024 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
import {GoogleAuth} from 'google-auth-library'; | ||
import {constants} from '../../util'; | ||
import { | ||
ClientError, | ||
GoogleApiError, | ||
GoogleAuthError, | ||
GoogleGenerativeAIError, | ||
} from '../../types'; | ||
|
||
const AUTHORIZATION_HEADER = 'Authorization'; | ||
const CONTENT_TYPE_HEADER = 'Content-Type'; | ||
const USER_AGENT_HEADER = 'User-Agent'; | ||
|
||
export class ApiClient { | ||
constructor( | ||
readonly project: string, | ||
readonly location: string, | ||
readonly apiVersion: 'v1' | 'v1beta1', | ||
private readonly googleAuth: GoogleAuth | ||
) {} | ||
|
||
/** | ||
* Gets access token from GoogleAuth. Throws {@link GoogleAuthError} when | ||
* fails. | ||
* @returns Promise of token string. | ||
*/ | ||
public fetchToken(): Promise<string | null | undefined> { | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new GoogleAuthError(constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
}); | ||
return tokenPromise; | ||
} | ||
|
||
getBaseUrl() { | ||
return `https://${this.location}-aiplatform.googleapis.com/${this.apiVersion}`; | ||
} | ||
|
||
getBaseResourePath() { | ||
return `projects/${this.project}/locations/${this.location}`; | ||
} | ||
|
||
async unaryApiCall( | ||
url: URL, | ||
requestInit: RequestInit, | ||
httpMethod: 'GET' | 'POST' | 'PATCH' | 'DELETE' | ||
): Promise<any> { | ||
const token = await this.getHeaders(); | ||
return this.apiCall(url.toString(), { | ||
...requestInit, | ||
method: httpMethod, | ||
headers: token, | ||
}); | ||
} | ||
|
||
private async apiCall( | ||
url: string, | ||
requestInit: RequestInit | ||
): Promise<Response> { | ||
const response = await fetch(url, requestInit).catch(e => { | ||
throw new GoogleGenerativeAIError( | ||
`exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`, | ||
e | ||
); | ||
}); | ||
await throwErrorIfNotOK(response, url, requestInit).catch(e => { | ||
throw e; | ||
}); | ||
try { | ||
return await response.json(); | ||
} catch (e) { | ||
throw new GoogleGenerativeAIError(JSON.stringify(response), e as Error); | ||
} | ||
} | ||
|
||
private async getHeaders(): Promise<Headers> { | ||
const token = await this.fetchToken(); | ||
return new Headers({ | ||
[AUTHORIZATION_HEADER]: `Bearer ${token}`, | ||
[CONTENT_TYPE_HEADER]: 'application/json', | ||
[USER_AGENT_HEADER]: constants.USER_AGENT, | ||
}); | ||
} | ||
} | ||
|
||
async function throwErrorIfNotOK( | ||
response: Response | undefined, | ||
url: string, | ||
requestInit: RequestInit | ||
) { | ||
if (response === undefined) { | ||
throw new GoogleGenerativeAIError('response is undefined'); | ||
} | ||
if (!response.ok) { | ||
const status: number = response.status; | ||
const statusText: string = response.statusText; | ||
let errorBody; | ||
if (response.headers.get('content-type')?.includes('application/json')) { | ||
errorBody = await response.json(); | ||
} else { | ||
errorBody = { | ||
error: { | ||
message: `exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`, | ||
code: response.status, | ||
status: response.statusText, | ||
}, | ||
}; | ||
} | ||
const errorMessage = `got status: ${status} ${statusText}. ${JSON.stringify( | ||
errorBody | ||
)}`; | ||
if (status >= 400 && status < 500) { | ||
const error = new ClientError( | ||
errorMessage, | ||
new GoogleApiError( | ||
errorBody.error.message, | ||
errorBody.error.code, | ||
errorBody.error.status, | ||
errorBody.error.details | ||
) | ||
); | ||
throw error; | ||
} | ||
throw new GoogleGenerativeAIError(errorMessage); | ||
} | ||
} |
Oops, something went wrong.