From 9160d96a1859175c3333f769f5725a8f17d51599 Mon Sep 17 00:00:00 2001 From: JinmingYang <2214962083@qq.com> Date: Wed, 25 Sep 2024 02:16:02 +0800 Subject: [PATCH] feat: add local embedding model --- .vscode/settings.json | 22 +- package.json | 2 - pnpm-lock.yaml | 93 +- .../@xenova/transformers/transformers.js | 24 + .../@xenova/transformers/utils/image.js | 811 ++++++++++++++++++ .../@xenova/transformers/utils/sharp.js | 49 ++ .../fix-package/onnxruntime-node/binding.js | 12 + scripts/fix-package/vectordb/native.js | 31 - .../ai/embeddings/embedding-manager.ts | 116 +++ .../embeddings/transformer-js-embeddings.ts | 87 ++ src/extension/ai/embeddings/types.ts | 26 + src/extension/constants.ts | 2 - src/extension/file-utils/paths.ts | 29 +- src/extension/registers/index.ts | 8 +- src/extension/registers/model-register.ts | 31 + src/extension/registers/register-manager.ts | 15 +- .../vectordb/code-chunks-index-table.ts | 58 +- .../vectordb/codebase-indexer.ts | 19 +- vite.config.mts | 100 ++- 19 files changed, 1312 insertions(+), 223 deletions(-) create mode 100644 scripts/fix-package/@xenova/transformers/transformers.js create mode 100644 scripts/fix-package/@xenova/transformers/utils/image.js create mode 100644 scripts/fix-package/@xenova/transformers/utils/sharp.js create mode 100644 scripts/fix-package/onnxruntime-node/binding.js delete mode 100644 scripts/fix-package/vectordb/native.js create mode 100644 src/extension/ai/embeddings/embedding-manager.ts create mode 100644 src/extension/ai/embeddings/transformer-js-embeddings.ts create mode 100644 src/extension/ai/embeddings/types.ts create mode 100644 src/extension/registers/model-register.ts diff --git a/.vscode/settings.json b/.vscode/settings.json index 5bf2d1e..b8cdd55 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,6 +4,21 @@ "i18n-ally.localesPaths": ["."], "i18n-ally.enabledFrameworks": ["vscode", "react"], "i18n-ally.dirStructure": "file", + + // performance + "editor.largeFileOptimizations": true, + "search.exclude": { + "**/node_modules/**": true, + "**/dist/**": true, + "**/*.min.js": true + }, + "files.watcherExclude": { + "**/.git/objects/**": true, + "**/.git/subtree-cache/**": true, + "**/node_modules/**": true, + "**/dist/**": true + }, + // Enable eslint for all supported languages "eslint.validate": [ "javascript", @@ -18,9 +33,6 @@ ], "editor.tabSize": 2, "editor.detectIndentation": false, - "search.exclude": { - "package-lock.json": true - }, "editor.codeActionsOnSave": [ "source.addMissingImports", "source.fixAll.eslint" @@ -74,17 +86,20 @@ "lancedb", "langchain", "langgraph", + "logits", "Mhchem", "multistream", "Nicepkg", "nodir", "Nolebase", "Ollama", + "onnx", "onnxruntime", "openai", "pino", "Pipfile", "Pluggable", + "pretrained", "pyproject", "qwen", "rehype", @@ -104,6 +119,7 @@ "treeshake", "tsup", "undici", + "unsqueeze", "uuidv", "vectordb", "vectorstores", diff --git a/package.json b/package.json index 4b98833..e95b7be 100644 --- a/package.json +++ b/package.json @@ -449,8 +449,6 @@ "minimatch": "^10.0.1", "next-themes": "^0.3.0", "node-fetch": "^3.3.2", - "onnxruntime-common": "^1.19.2", - "onnxruntime-node": "^1.19.2", "p-limit": "^6.1.0", "pnpm": "^9.10.0", "postcss": "^8.4.47", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index d34be1c..ec3e1e4 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -167,7 +167,7 @@ importers: version: 10.4.20(postcss@8.4.47) babel-plugin-react-compiler: specifier: latest - version: 0.0.0-experimental-6067d4e-20240919 + version: 0.0.0-experimental-6067d4e-20240924 chalk: specifier: ^5.3.0 version: 5.3.0 @@ -233,7 +233,7 @@ importers: version: 5.2.1(@types/eslint@8.56.10)(eslint-config-prettier@9.1.0(eslint@8.57.0))(eslint@8.57.0)(prettier@3.3.3) eslint-plugin-react-compiler: specifier: latest - version: 0.0.0-experimental-92aaa43-20240919(eslint@8.57.0) + version: 0.0.0-experimental-92aaa43-20240924(eslint@8.57.0) eslint-plugin-simple-import-sort: specifier: ^12.1.1 version: 12.1.1(eslint@8.57.0) @@ -306,12 +306,6 @@ importers: node-fetch: specifier: ^3.3.2 version: 3.3.2 - onnxruntime-common: - specifier: ^1.19.2 - version: 1.19.2 - onnxruntime-node: - specifier: ^1.19.2 - version: 1.19.2 p-limit: specifier: ^6.1.0 version: 6.1.0 @@ -1658,10 +1652,6 @@ packages: resolution: {integrity: sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==} engines: {node: '>=12'} - '@isaacs/fs-minipass@4.0.1': - resolution: {integrity: sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==} - engines: {node: '>=18.0.0'} - '@jest/types@24.9.0': resolution: {integrity: sha512-XKK7ze1apu5JWQ5eZjHITP66AX+QsLlbaJRBGYr8pNzwcAE2JVkwnf0yqjHTsDRcjR0mujy/NmZMXw5kl+kGBw==} engines: {node: '>= 6'} @@ -4003,8 +3993,8 @@ packages: b4a@1.6.6: resolution: {integrity: sha512-5Tk1HLk6b6ctmjIkAcU/Ujv/1WqiDl0F0JdRCR80VsOcUlHcu7pWeWRlOqQLHfDEsVx9YH/aif5AG4ehoCtTmg==} - babel-plugin-react-compiler@0.0.0-experimental-6067d4e-20240919: - resolution: {integrity: sha512-3BHXXnd3GzOkHHWMhYLARTUa03PyMzhbAA3ptG+WXujJu0mx1BT3CslcqDlKMh7j508uspT5JCXRZh0ZIN9a0g==} + babel-plugin-react-compiler@0.0.0-experimental-6067d4e-20240924: + resolution: {integrity: sha512-Xprt5PqHZKqF2H8Di7y+o9j1RTFsNGJ6ntBcRFu8kcChy5sVSVVIKXq+FBezcBhVChzaRrUb+OV/nWZlJH1aJA==} bail@2.0.2: resolution: {integrity: sha512-0xO6mYd7JB2YesxDKplafRpsiOzPt9V02ddPCLbY1xYGPOX24NTyN50qnUxgCPcSoYMhKpAuBTjQoRZCAkUDRw==} @@ -4188,10 +4178,6 @@ packages: chownr@1.1.4: resolution: {integrity: sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==} - chownr@3.0.0: - resolution: {integrity: sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==} - engines: {node: '>=18'} - class-variance-authority@0.7.0: resolution: {integrity: sha512-jFI8IQw4hczaL4ALINxqLEXQbWcNjoSkloa4IaufXCJr6QawJyw7tuRysRsrE8w2p/4gGaxKIt/hX3qz/IbD1A==} @@ -5093,8 +5079,8 @@ packages: eslint-config-prettier: optional: true - eslint-plugin-react-compiler@0.0.0-experimental-92aaa43-20240919: - resolution: {integrity: sha512-l1tEUmxnZcMNkpUffbyNPAV91kZMZYLVeCCRZajK/s1QdP9FquunJQ9uT4c3f0RqdV6n0kloVfnJ0lPD1FTPlg==} + eslint-plugin-react-compiler@0.0.0-experimental-92aaa43-20240924: + resolution: {integrity: sha512-rPzsC8nMLCg8ol41MfW6Un9HdDUW9BFBbhtXdhwLjOlIZ8ffeNdRKhwEa34eZCkJDH/fTO5X4graW0ndj+ni7A==} engines: {node: ^14.17.0 || ^16.0.0 || >= 18.0.0} peerDependencies: eslint: '>=7' @@ -6620,21 +6606,12 @@ packages: minisearch@7.0.1: resolution: {integrity: sha512-xLeX/AwTJLzgBF2/bdUI7MEePwXtzaLExkRwu8YFGfLDwSe06KYkplqPodLANsqvfc5Ks/r5ItFUSjIp7+9xtw==} - minizlib@3.0.1: - resolution: {integrity: sha512-umcy022ILvb5/3Djuu8LWeqUa8D68JaBzlttKeMWen48SjabqS3iY5w/vzeMzMUNhLDifyhbOwKDSznB1vvrwg==} - engines: {node: '>= 18'} - mitt@3.0.1: resolution: {integrity: sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==} mkdirp-classic@0.5.3: resolution: {integrity: sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==} - mkdirp@3.0.1: - resolution: {integrity: sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==} - engines: {node: '>=10'} - hasBin: true - mlly@1.7.1: resolution: {integrity: sha512-rrVRZRELyQzrIUAVMHxP97kv+G786pHmOKzuFII8zDYahFBS7qnHh2AlYSl1GAHhaMPCz6/oHjVMcfFYgFYHgA==} @@ -6828,17 +6805,10 @@ packages: onnxruntime-common@1.14.0: resolution: {integrity: sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew==} - onnxruntime-common@1.19.2: - resolution: {integrity: sha512-a4R7wYEVFbZBlp0BfhpbFWqe4opCor3KM+5Wm22Az3NGDcQMiU2hfG/0MfnBs+1ZrlSGmlgWeMcXQkDk1UFb8Q==} - onnxruntime-node@1.14.0: resolution: {integrity: sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==} os: [win32, darwin, linux] - onnxruntime-node@1.19.2: - resolution: {integrity: sha512-9eHMP/HKbbeUcqte1JYzaaRC8JPn7ojWeCeoyShO86TOR97OCyIyAIOGX3V95ErjslVhJRXY8Em/caIUc0hm1Q==} - os: [win32, darwin, linux] - onnxruntime-web@1.14.0: resolution: {integrity: sha512-Kcqf43UMfW8mCydVGcX9OMXI2VN17c0p6XvR7IPSZzBf/6lteBzXHvcEVWDPmCKuGombl997HgLqj91F11DzXw==} @@ -7439,10 +7409,6 @@ packages: deprecated: Rimraf versions prior to v4 are no longer supported hasBin: true - rimraf@5.0.10: - resolution: {integrity: sha512-l0OE8wL34P4nJH/H2ffoaniAokM2qSmrtXHmlpvYr5AVVX8msAyW0l8NVJFDxlSK4u3Uh/f41cQheDVdnYijwQ==} - hasBin: true - rimraf@6.0.1: resolution: {integrity: sha512-9dkvaxAsk/xNXSJzMgFqqMCuFgt2+KsOFek3TMLfo8NCPfWpBmqwyNn5Y+NX56QUYfCtsyhF3ayiboEoUmJk/A==} engines: {node: 20 || >=22} @@ -7893,10 +7859,6 @@ packages: tar-stream@3.1.7: resolution: {integrity: sha512-qJj60CXt7IU1Ffyc3NJMjh6EkuCFej46zUqJ4J7pqYlThyd9bO0XBTmcOIhSzZJVWfsLks0+nle/j538YAW9RQ==} - tar@7.4.3: - resolution: {integrity: sha512-5S7Va8hKfV7W5U6g3aYxXmlPoZVAwUMy9AOKyF2fVuZa2UD3qZjg578OrLRt8PcNN1PleVaL/5/yYATNL0ICUw==} - engines: {node: '>=18'} - terminal-link@3.0.0: resolution: {integrity: sha512-flFL3m4wuixmf6IfhFJd1YPiLiMuxEc8uHRM1buzIeZPm22Au2pDqBJQgdo7n1WfPU1ONFGv7YDwpFBmHGF6lg==} engines: {node: '>=12'} @@ -8609,10 +8571,6 @@ packages: yallist@4.0.0: resolution: {integrity: sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==} - yallist@5.0.0: - resolution: {integrity: sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==} - engines: {node: '>=18'} - yaml@2.5.0: resolution: {integrity: sha512-2wWLbGbYDiSqqIKoPjar3MPgB94ErzCtrNE1FdqGuaO0pi2JGjmE8aW8TDZwzU7vuxcGRdL/4gPQwQ7hD5AMSw==} engines: {node: '>= 14'} @@ -9924,10 +9882,6 @@ snapshots: wrap-ansi: 8.1.0 wrap-ansi-cjs: wrap-ansi@7.0.0 - '@isaacs/fs-minipass@4.0.1': - dependencies: - minipass: 7.1.2 - '@jest/types@24.9.0': dependencies: '@types/istanbul-lib-coverage': 2.0.6 @@ -12235,7 +12189,7 @@ snapshots: b4a@1.6.6: {} - babel-plugin-react-compiler@0.0.0-experimental-6067d4e-20240919: + babel-plugin-react-compiler@0.0.0-experimental-6067d4e-20240924: dependencies: '@babel/generator': 7.2.0 '@babel/types': 7.25.2 @@ -12451,8 +12405,6 @@ snapshots: chownr@1.1.4: {} - chownr@3.0.0: {} - class-variance-authority@0.7.0: dependencies: clsx: 2.0.0 @@ -13569,7 +13521,7 @@ snapshots: '@types/eslint': 8.56.10 eslint-config-prettier: 9.1.0(eslint@8.57.0) - eslint-plugin-react-compiler@0.0.0-experimental-92aaa43-20240919(eslint@8.57.0): + eslint-plugin-react-compiler@0.0.0-experimental-92aaa43-20240924(eslint@8.57.0): dependencies: '@babel/core': 7.25.2 '@babel/parser': 7.25.3 @@ -15433,17 +15385,10 @@ snapshots: minisearch@7.0.1: {} - minizlib@3.0.1: - dependencies: - minipass: 7.1.2 - rimraf: 5.0.10 - mitt@3.0.1: {} mkdirp-classic@0.5.3: {} - mkdirp@3.0.1: {} - mlly@1.7.1: dependencies: acorn: 8.11.3 @@ -15636,18 +15581,11 @@ snapshots: onnxruntime-common@1.14.0: {} - onnxruntime-common@1.19.2: {} - onnxruntime-node@1.14.0: dependencies: onnxruntime-common: 1.14.0 optional: true - onnxruntime-node@1.19.2: - dependencies: - onnxruntime-common: 1.19.2 - tar: 7.4.3 - onnxruntime-web@1.14.0: dependencies: flatbuffers: 1.12.0 @@ -16306,10 +16244,6 @@ snapshots: dependencies: glob: 7.2.3 - rimraf@5.0.10: - dependencies: - glob: 10.4.3 - rimraf@6.0.1: dependencies: glob: 11.0.0 @@ -16882,15 +16816,6 @@ snapshots: fast-fifo: 1.3.2 streamx: 2.20.1 - tar@7.4.3: - dependencies: - '@isaacs/fs-minipass': 4.0.1 - chownr: 3.0.0 - minipass: 7.1.2 - minizlib: 3.0.1 - mkdirp: 3.0.1 - yallist: 5.0.0 - terminal-link@3.0.0: dependencies: ansi-escapes: 5.0.0 @@ -17647,8 +17572,6 @@ snapshots: yallist@4.0.0: {} - yallist@5.0.0: {} - yaml@2.5.0: {} yargs-parser@21.1.1: {} diff --git a/scripts/fix-package/@xenova/transformers/transformers.js b/scripts/fix-package/@xenova/transformers/transformers.js new file mode 100644 index 0000000..5f2ed2b --- /dev/null +++ b/scripts/fix-package/@xenova/transformers/transformers.js @@ -0,0 +1,24 @@ +/** + * @file Entry point for the Transformers.js library. Only the exports from this file + * are available to the end user, and are grouped as follows: + * + * 1. [Pipelines](./pipelines) + * 2. [Environment variables](./env) + * 3. [Models](./models) + * 4. [Tokenizers](./tokenizers) + * 5. [Processors](./processors) + * + * @module transformers + */ + +export * from './pipelines.js' +export * from './env.js' +// export * from './models.js'; +// export * from './tokenizers.js'; +// export * from './processors.js'; +// export * from './configs.js'; + +// export * from './utils/audio.js'; +// export * from './utils/image.js'; +// export * from './utils/tensor.js'; +export * from './utils/maths.js' diff --git a/scripts/fix-package/@xenova/transformers/utils/image.js b/scripts/fix-package/@xenova/transformers/utils/image.js new file mode 100644 index 0000000..67f71be --- /dev/null +++ b/scripts/fix-package/@xenova/transformers/utils/image.js @@ -0,0 +1,811 @@ +/** + * @file Helper module for image processing. + * + * These functions and classes are only used internally, + * meaning an end-user shouldn't need to access anything here. + * + * @module utils/image + */ + +import { env } from '../env.js' +import { getFile } from './hub.js' +// Will be empty (or not used) if running in browser or web-worker +import sharp from './sharp.js' +import { Tensor } from './tensor.js' + +const BROWSER_ENV = typeof self !== 'undefined' +const WEBWORKER_ENV = + BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope' + +let createCanvasFunction +let ImageDataClass +let loadImageFunction +if (BROWSER_ENV) { + // Running in browser or web-worker + createCanvasFunction = ( + /** @type {number} */ width, + /** @type {number} */ height + ) => { + if (!self.OffscreenCanvas) { + throw new Error('OffscreenCanvas not supported by this browser.') + } + return new self.OffscreenCanvas(width, height) + } + loadImageFunction = self.createImageBitmap + ImageDataClass = self.ImageData +} else if (sharp) { + // Running in Node.js, electron, or other non-browser environment + + loadImageFunction = async (/**@type {sharp.Sharp}*/ img) => { + const metadata = await img.metadata() + const rawChannels = metadata.channels + + let { data, info } = await img + .rotate() + .raw() + .toBuffer({ resolveWithObject: true }) + + const newImage = new RawImage( + new Uint8ClampedArray(data), + info.width, + info.height, + info.channels + ) + if (rawChannels !== undefined && rawChannels !== info.channels) { + // Make sure the new image has the same number of channels as the input image. + // This is necessary for grayscale images. + newImage.convert(rawChannels) + } + return newImage + } +} else { + throw new Error('Unable to load image processing library.') +} + +// Defined here: https://github.com/python-pillow/Pillow/blob/a405e8406b83f8bfb8916e93971edc7407b8b1ff/src/libImaging/Imaging.h#L262-L268 +const RESAMPLING_MAPPING = { + 0: 'nearest', + 1: 'lanczos', + 2: 'bilinear', + 3: 'bicubic', + 4: 'box', + 5: 'hamming' +} + +/** + * Mapping from file extensions to MIME types. + */ +const CONTENT_TYPE_MAP = new Map([ + ['png', 'image/png'], + ['jpg', 'image/jpeg'], + ['jpeg', 'image/jpeg'], + ['gif', 'image/gif'] +]) + +export class RawImage { + /** + * Create a new `RawImage` object. + * @param {Uint8ClampedArray|Uint8Array} data The pixel data. + * @param {number} width The width of the image. + * @param {number} height The height of the image. + * @param {1|2|3|4} channels The number of channels. + */ + constructor(data, width, height, channels) { + this.data = data + this.width = width + this.height = height + this.channels = channels + } + + /** + * Returns the size of the image (width, height). + * @returns {[number, number]} The size of the image (width, height). + */ + get size() { + return [this.width, this.height] + } + + /** + * Helper method for reading an image from a variety of input types. + * @param {RawImage|string|URL} input + * @returns The image object. + * + * **Example:** Read image from a URL. + * ```javascript + * let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); + * // RawImage { + * // "data": Uint8ClampedArray [ 25, 25, 25, 19, 19, 19, ... ], + * // "width": 800, + * // "height": 533, + * // "channels": 3 + * // } + * ``` + */ + static async read(input) { + if (input instanceof RawImage) { + return input + } else if (typeof input === 'string' || input instanceof URL) { + return await this.fromURL(input) + } else { + throw new Error(`Unsupported input type: ${typeof input}`) + } + } + + /** + * Read an image from a URL or file path. + * @param {string|URL} url The URL or file path to read the image from. + * @returns {Promise} The image object. + */ + static async fromURL(url) { + let response = await getFile(url) + if (response.status !== 200) { + throw new Error( + `Unable to read image from "${url}" (${response.status} ${response.statusText})` + ) + } + let blob = await response.blob() + return this.fromBlob(blob) + } + + /** + * Helper method to create a new Image from a blob. + * @param {Blob} blob The blob to read the image from. + * @returns {Promise} The image object. + */ + static async fromBlob(blob) { + if (BROWSER_ENV) { + // Running in environment with canvas + let img = await loadImageFunction(blob) + + const ctx = createCanvasFunction(img.width, img.height).getContext('2d') + + // Draw image to context + ctx.drawImage(img, 0, 0) + + return new this( + ctx.getImageData(0, 0, img.width, img.height).data, + img.width, + img.height, + 4 + ) + } else { + // Use sharp.js to read (and possible resize) the image. + let img = sharp(await blob.arrayBuffer()) + + return await loadImageFunction(img) + } + } + + /** + * Helper method to create a new Image from a tensor + * @param {Tensor} tensor + */ + static fromTensor(tensor, channel_format = 'CHW') { + if (tensor.dims.length !== 3) { + throw new Error( + `Tensor should have 3 dimensions, but has ${tensor.dims.length} dimensions.` + ) + } + + if (channel_format === 'CHW') { + tensor = tensor.transpose(1, 2, 0) + } else if (channel_format === 'HWC') { + // Do nothing + } else { + throw new Error(`Unsupported channel format: ${channel_format}`) + } + if ( + !( + tensor.data instanceof Uint8ClampedArray || + tensor.data instanceof Uint8Array + ) + ) { + throw new Error(`Unsupported tensor type: ${tensor.type}`) + } + switch (tensor.dims[2]) { + case 1: + case 2: + case 3: + case 4: + return new RawImage( + tensor.data, + tensor.dims[1], + tensor.dims[0], + tensor.dims[2] + ) + default: + throw new Error(`Unsupported number of channels: ${tensor.dims[2]}`) + } + } + + /** + * Convert the image to grayscale format. + * @returns {RawImage} `this` to support chaining. + */ + grayscale() { + if (this.channels === 1) { + return this + } + + let newData = new Uint8ClampedArray(this.width * this.height * 1) + switch (this.channels) { + case 3: // rgb to grayscale + case 4: // rgba to grayscale + for (let i = 0, offset = 0; i < this.data.length; i += this.channels) { + const red = this.data[i] + const green = this.data[i + 1] + const blue = this.data[i + 2] + + newData[offset++] = Math.round( + 0.2989 * red + 0.587 * green + 0.114 * blue + ) + } + break + default: + throw new Error( + `Conversion failed due to unsupported number of channels: ${this.channels}` + ) + } + return this._update(newData, this.width, this.height, 1) + } + + /** + * Convert the image to RGB format. + * @returns {RawImage} `this` to support chaining. + */ + rgb() { + if (this.channels === 3) { + return this + } + + let newData = new Uint8ClampedArray(this.width * this.height * 3) + + switch (this.channels) { + case 1: // grayscale to rgb + for (let i = 0, offset = 0; i < this.data.length; ++i) { + newData[offset++] = this.data[i] + newData[offset++] = this.data[i] + newData[offset++] = this.data[i] + } + break + case 4: // rgba to rgb + for (let i = 0, offset = 0; i < this.data.length; i += 4) { + newData[offset++] = this.data[i] + newData[offset++] = this.data[i + 1] + newData[offset++] = this.data[i + 2] + } + break + default: + throw new Error( + `Conversion failed due to unsupported number of channels: ${this.channels}` + ) + } + return this._update(newData, this.width, this.height, 3) + } + + /** + * Convert the image to RGBA format. + * @returns {RawImage} `this` to support chaining. + */ + rgba() { + if (this.channels === 4) { + return this + } + + let newData = new Uint8ClampedArray(this.width * this.height * 4) + + switch (this.channels) { + case 1: // grayscale to rgba + for (let i = 0, offset = 0; i < this.data.length; ++i) { + newData[offset++] = this.data[i] + newData[offset++] = this.data[i] + newData[offset++] = this.data[i] + newData[offset++] = 255 + } + break + case 3: // rgb to rgba + for (let i = 0, offset = 0; i < this.data.length; i += 3) { + newData[offset++] = this.data[i] + newData[offset++] = this.data[i + 1] + newData[offset++] = this.data[i + 2] + newData[offset++] = 255 + } + break + default: + throw new Error( + `Conversion failed due to unsupported number of channels: ${this.channels}` + ) + } + + return this._update(newData, this.width, this.height, 4) + } + + /** + * Resize the image to the given dimensions. This method uses the canvas API to perform the resizing. + * @param {number} width The width of the new image. + * @param {number} height The height of the new image. + * @param {Object} options Additional options for resizing. + * @param {0|1|2|3|4|5|string} [options.resample] The resampling method to use. + * @returns {Promise} `this` to support chaining. + */ + async resize(width, height, { resample = 2 } = {}) { + // Ensure resample method is a string + let resampleMethod = RESAMPLING_MAPPING[resample] ?? resample + + if (BROWSER_ENV) { + // TODO use `resample` in browser environment + + // Store number of channels before resizing + let numChannels = this.channels + + // Create canvas object for this image + let canvas = this.toCanvas() + + // Actually perform resizing using the canvas API + const ctx = createCanvasFunction(width, height).getContext('2d') + + // Draw image to context, resizing in the process + ctx.drawImage(canvas, 0, 0, width, height) + + // Create image from the resized data + let resizedImage = new RawImage( + ctx.getImageData(0, 0, width, height).data, + width, + height, + 4 + ) + + // Convert back so that image has the same number of channels as before + return resizedImage.convert(numChannels) + } else { + // Create sharp image from raw data, and resize + let img = this.toSharp() + + switch (resampleMethod) { + case 'box': + case 'hamming': + if (resampleMethod === 'box' || resampleMethod === 'hamming') { + console.warn( + `Resampling method ${resampleMethod} is not yet supported. Using bilinear instead.` + ) + resampleMethod = 'bilinear' + } + + case 'nearest': + case 'bilinear': + case 'bicubic': + // Perform resizing using affine transform. + // This matches how the python Pillow library does it. + img = img.affine([width / this.width, 0, 0, height / this.height], { + interpolator: resampleMethod + }) + break + + case 'lanczos': + // https://github.com/python-pillow/Pillow/discussions/5519 + // https://github.com/lovell/sharp/blob/main/docs/api-resize.md + img = img.resize({ + width, + height, + fit: 'fill', + kernel: 'lanczos3' // PIL Lanczos uses a kernel size of 3 + }) + break + + default: + throw new Error( + `Resampling method ${resampleMethod} is not supported.` + ) + } + + return await loadImageFunction(img) + } + } + + async pad([left, right, top, bottom]) { + left = Math.max(left, 0) + right = Math.max(right, 0) + top = Math.max(top, 0) + bottom = Math.max(bottom, 0) + + if (left === 0 && right === 0 && top === 0 && bottom === 0) { + // No padding needed + return this + } + + if (BROWSER_ENV) { + // Store number of channels before padding + let numChannels = this.channels + + // Create canvas object for this image + let canvas = this.toCanvas() + + let newWidth = this.width + left + right + let newHeight = this.height + top + bottom + + // Create a new canvas of the desired size. + const ctx = createCanvasFunction(newWidth, newHeight).getContext('2d') + + // Draw image to context, padding in the process + ctx.drawImage( + canvas, + 0, + 0, + this.width, + this.height, + left, + top, + newWidth, + newHeight + ) + + // Create image from the padded data + let paddedImage = new RawImage( + ctx.getImageData(0, 0, newWidth, newHeight).data, + newWidth, + newHeight, + 4 + ) + + // Convert back so that image has the same number of channels as before + return paddedImage.convert(numChannels) + } else { + let img = this.toSharp().extend({ left, right, top, bottom }) + return await loadImageFunction(img) + } + } + + async crop([x_min, y_min, x_max, y_max]) { + // Ensure crop bounds are within the image + x_min = Math.max(x_min, 0) + y_min = Math.max(y_min, 0) + x_max = Math.min(x_max, this.width - 1) + y_max = Math.min(y_max, this.height - 1) + + // Do nothing if the crop is the entire image + if ( + x_min === 0 && + y_min === 0 && + x_max === this.width - 1 && + y_max === this.height - 1 + ) { + return this + } + + const crop_width = x_max - x_min + 1 + const crop_height = y_max - y_min + 1 + + if (BROWSER_ENV) { + // Store number of channels before resizing + const numChannels = this.channels + + // Create canvas object for this image + const canvas = this.toCanvas() + + // Create a new canvas of the desired size. This is needed since if the + // image is too small, we need to pad it with black pixels. + const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d') + + // Draw image to context, cropping in the process + ctx.drawImage( + canvas, + x_min, + y_min, + crop_width, + crop_height, + 0, + 0, + crop_width, + crop_height + ) + + // Create image from the resized data + const resizedImage = new RawImage( + ctx.getImageData(0, 0, crop_width, crop_height).data, + crop_width, + crop_height, + 4 + ) + + // Convert back so that image has the same number of channels as before + return resizedImage.convert(numChannels) + } else { + // Create sharp image from raw data + const img = this.toSharp().extract({ + left: x_min, + top: y_min, + width: crop_width, + height: crop_height + }) + + return await loadImageFunction(img) + } + } + + async center_crop(crop_width, crop_height) { + // If the image is already the desired size, return it + if (this.width === crop_width && this.height === crop_height) { + return this + } + + // Determine bounds of the image in the new canvas + let width_offset = (this.width - crop_width) / 2 + let height_offset = (this.height - crop_height) / 2 + + if (BROWSER_ENV) { + // Store number of channels before resizing + let numChannels = this.channels + + // Create canvas object for this image + let canvas = this.toCanvas() + + // Create a new canvas of the desired size. This is needed since if the + // image is too small, we need to pad it with black pixels. + const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d') + + let sourceX = 0 + let sourceY = 0 + let destX = 0 + let destY = 0 + + if (width_offset >= 0) { + sourceX = width_offset + } else { + destX = -width_offset + } + + if (height_offset >= 0) { + sourceY = height_offset + } else { + destY = -height_offset + } + + // Draw image to context, cropping in the process + ctx.drawImage( + canvas, + sourceX, + sourceY, + crop_width, + crop_height, + destX, + destY, + crop_width, + crop_height + ) + + // Create image from the resized data + let resizedImage = new RawImage( + ctx.getImageData(0, 0, crop_width, crop_height).data, + crop_width, + crop_height, + 4 + ) + + // Convert back so that image has the same number of channels as before + return resizedImage.convert(numChannels) + } else { + // Create sharp image from raw data + let img = this.toSharp() + + if (width_offset >= 0 && height_offset >= 0) { + // Cropped image lies entirely within the original image + img = img.extract({ + left: Math.floor(width_offset), + top: Math.floor(height_offset), + width: crop_width, + height: crop_height + }) + } else if (width_offset <= 0 && height_offset <= 0) { + // Cropped image lies entirely outside the original image, + // so we add padding + let top = Math.floor(-height_offset) + let left = Math.floor(-width_offset) + img = img.extend({ + top: top, + left: left, + + // Ensures the resulting image has the desired dimensions + right: crop_width - this.width - left, + bottom: crop_height - this.height - top + }) + } else { + // Cropped image lies partially outside the original image. + // We first pad, then crop. + + let y_padding = [0, 0] + let y_extract = 0 + if (height_offset < 0) { + y_padding[0] = Math.floor(-height_offset) + y_padding[1] = crop_height - this.height - y_padding[0] + } else { + y_extract = Math.floor(height_offset) + } + + let x_padding = [0, 0] + let x_extract = 0 + if (width_offset < 0) { + x_padding[0] = Math.floor(-width_offset) + x_padding[1] = crop_width - this.width - x_padding[0] + } else { + x_extract = Math.floor(width_offset) + } + + img = img + .extend({ + top: y_padding[0], + bottom: y_padding[1], + left: x_padding[0], + right: x_padding[1] + }) + .extract({ + left: x_extract, + top: y_extract, + width: crop_width, + height: crop_height + }) + } + + return await loadImageFunction(img) + } + } + + async toBlob(type = 'image/png', quality = 1) { + if (!BROWSER_ENV) { + throw new Error('toBlob() is only supported in browser environments.') + } + + const canvas = this.toCanvas() + return await canvas.convertToBlob({ type, quality }) + } + + toTensor(channel_format = 'CHW') { + let tensor = new Tensor('uint8', new Uint8Array(this.data), [ + this.height, + this.width, + this.channels + ]) + + if (channel_format === 'HWC') { + // Do nothing + } else if (channel_format === 'CHW') { + // hwc -> chw + tensor = tensor.permute(2, 0, 1) + } else { + throw new Error(`Unsupported channel format: ${channel_format}`) + } + return tensor + } + + toCanvas() { + if (!BROWSER_ENV) { + throw new Error('toCanvas() is only supported in browser environments.') + } + + // Clone, and convert data to RGBA before drawing to canvas. + // This is because the canvas API only supports RGBA + let cloned = this.clone().rgba() + + // Create canvas object for the cloned image + let clonedCanvas = createCanvasFunction(cloned.width, cloned.height) + + // Draw image to context + let data = new ImageDataClass(cloned.data, cloned.width, cloned.height) + clonedCanvas.getContext('2d').putImageData(data, 0, 0) + + return clonedCanvas + } + + /** + * Helper method to update the image data. + * @param {Uint8ClampedArray} data The new image data. + * @param {number} width The new width of the image. + * @param {number} height The new height of the image. + * @param {1|2|3|4|null} [channels] The new number of channels of the image. + * @private + */ + _update(data, width, height, channels = null) { + this.data = data + this.width = width + this.height = height + if (channels !== null) { + this.channels = channels + } + return this + } + + /** + * Clone the image + * @returns {RawImage} The cloned image + */ + clone() { + return new RawImage( + this.data.slice(), + this.width, + this.height, + this.channels + ) + } + + /** + * Helper method for converting image to have a certain number of channels + * @param {number} numChannels The number of channels. Must be 1, 3, or 4. + * @returns {RawImage} `this` to support chaining. + */ + convert(numChannels) { + if (this.channels === numChannels) return this // Already correct number of channels + + switch (numChannels) { + case 1: + this.grayscale() + break + case 3: + this.rgb() + break + case 4: + this.rgba() + break + default: + throw new Error( + `Conversion failed due to unsupported number of channels: ${this.channels}` + ) + } + return this + } + + /** + * Save the image to the given path. + * @param {string} path The path to save the image to. + */ + async save(path) { + if (BROWSER_ENV) { + if (WEBWORKER_ENV) { + throw new Error('Unable to save an image from a Web Worker.') + } + + const extension = path.split('.').pop().toLowerCase() + const mime = CONTENT_TYPE_MAP.get(extension) ?? 'image/png' + + // Convert image to Blob + const blob = await this.toBlob(mime) + + // Convert the canvas content to a data URL + const dataURL = URL.createObjectURL(blob) + + // Create an anchor element with the data URL as the href attribute + const downloadLink = document.createElement('a') + downloadLink.href = dataURL + + // Set the download attribute to specify the desired filename for the downloaded image + downloadLink.download = path + + // Trigger the download + downloadLink.click() + + // Clean up: remove the anchor element from the DOM + downloadLink.remove() + } else if (!env.useFS) { + throw new Error( + 'Unable to save the image because filesystem is disabled in this environment.' + ) + } else { + const img = this.toSharp() + return await img.toFile(path) + } + } + + toSharp() { + if (BROWSER_ENV) { + throw new Error( + 'toSharp() is only supported in server-side environments.' + ) + } + + return sharp(this.data, { + raw: { + width: this.width, + height: this.height, + channels: this.channels + } + }) + } +} diff --git a/scripts/fix-package/@xenova/transformers/utils/sharp.js b/scripts/fix-package/@xenova/transformers/utils/sharp.js new file mode 100644 index 0000000..d2c726f --- /dev/null +++ b/scripts/fix-package/@xenova/transformers/utils/sharp.js @@ -0,0 +1,49 @@ +// mock sharp +class Sharp { + constructor(input) { + this.input = input + this.operations = [] + } + + resize(width, height, options = {}) { + this.operations.push({ type: 'resize', width, height, options }) + return this + } + + rotate(angle, options = {}) { + this.operations.push({ type: 'rotate', angle, options }) + return this + } + + blur(sigma = 1) { + this.operations.push({ type: 'blur', sigma }) + return this + } + + toBuffer() { + return new Promise(resolve => { + console.log('Simulated image processing:') + console.log('Input:', this.input) + this.operations.forEach(op => { + console.log('Operation:', op.type, op) + }) + resolve(Buffer.from('Simulated image buffer')) + }) + } + + toFile(output) { + return new Promise(resolve => { + console.log('Simulated image processing:') + console.log('Input:', this.input) + this.operations.forEach(op => { + console.log('Operation:', op.type, op) + }) + console.log('Output:', output) + resolve({ format: 'jpeg', width: 100, height: 100, channels: 3 }) + }) + } +} + +export default function (input) { + return new Sharp(input) +} diff --git a/scripts/fix-package/onnxruntime-node/binding.js b/scripts/fix-package/onnxruntime-node/binding.js new file mode 100644 index 0000000..e65fa3c --- /dev/null +++ b/scripts/fix-package/onnxruntime-node/binding.js @@ -0,0 +1,12 @@ +'use strict' +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +Object.defineProperty(exports, '__esModule', { value: true }) +exports.binding = void 0 +// export native binding +exports.binding = + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + require( + `./onnxruntime/bin/napi-v3/${process.platform}/${process.arch}/onnxruntime_binding.node` + ) +//# sourceMappingURL=binding.js.map diff --git a/scripts/fix-package/vectordb/native.js b/scripts/fix-package/vectordb/native.js deleted file mode 100644 index 4f12273..0000000 --- a/scripts/fix-package/vectordb/native.js +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2023 Lance Developers. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -const { currentTarget } = require('@neon-rs/load') - -let nativeLib - -try { - nativeLib = require(`@lancedb/vectordb-${currentTarget()}`) -} catch (e) { - throw new Error(`vectordb: failed to load native library. - You may need to run \`npm install @lancedb/vectordb-${currentTarget()}\`. - - If that does not work, please file a bug report at https://github.com/lancedb/lancedb/issues - - Source error: ${e}`) -} - -// Dynamic require for runtime. -module.exports = nativeLib diff --git a/src/extension/ai/embeddings/embedding-manager.ts b/src/extension/ai/embeddings/embedding-manager.ts new file mode 100644 index 0000000..22dab6d --- /dev/null +++ b/src/extension/ai/embeddings/embedding-manager.ts @@ -0,0 +1,116 @@ +import { OpenAIEmbeddings } from '@langchain/openai' + +import { TransformerJsEmbeddings } from './transformer-js-embeddings' +import { BaseEmbeddingModelInfo, BaseEmbeddings } from './types' + +export const embeddingModels = [ + { + type: 'transformer-js', + modelName: 'all-MiniLM-L6-v2', + dimensions: 384, + maxTokens: 512, + EmbeddingClass: TransformerJsEmbeddings, + buildConstructParams: modelInfo => ({ + modelInfo + }) + }, + { + type: 'openai', + modelName: 'text-embedding-ada-002', + dimensions: 1536, + maxTokens: 8191, + EmbeddingClass: OpenAIEmbeddings + } +] as const satisfies BaseEmbeddingModelInfo[] + +export type EmbeddingModelType = (typeof embeddingModels)[number]['type'] + +export class EmbeddingManager { + private static instance: EmbeddingManager + + private embeddings: Map = new Map() + + private activeModelKey: string | null = null + + // eslint-disable-next-line @typescript-eslint/no-empty-function + private constructor() {} + + static getInstance(): EmbeddingManager { + if (!EmbeddingManager.instance) { + EmbeddingManager.instance = new EmbeddingManager() + } + return EmbeddingManager.instance + } + + async getEmbedding( + modelInfo: BaseEmbeddingModelInfo + ): Promise { + const key = modelInfo.modelName + + if (!this.embeddings.has(key)) { + const { EmbeddingClass, buildConstructParams } = modelInfo + const embedding = new EmbeddingClass( + buildConstructParams?.(modelInfo) ?? {} + ) + await embedding.init?.() + this.embeddings.set(key, embedding) + } + + return this.embeddings.get(key)! + } + + async setActiveModel(modelInfo: BaseEmbeddingModelInfo): Promise { + await this.getEmbedding(modelInfo) + this.activeModelKey = modelInfo.modelName + } + + getActiveModelInfo(): BaseEmbeddingModelInfo { + if (!this.activeModelKey) throw new Error('No active embedding model set') + + const activeModel = embeddingModels.find( + model => model.modelName === this.activeModelKey + ) + if (!activeModel) throw new Error('No active embedding model set') + + return activeModel + } + + async getActiveEmbedding(): Promise { + const modelInfo = this.getActiveModelInfo() + return await this.getEmbedding(modelInfo) + } + + async embedDocuments(texts: string[]): Promise { + const embedding = await this.getActiveEmbedding() + return embedding.embedDocuments(texts) + } + + async embedQuery(text: string): Promise { + const embedding = await this.getActiveEmbedding() + return embedding.embedQuery(text) + } + + async embedDocumentsWithModel( + modelInfo: BaseEmbeddingModelInfo, + texts: string[] + ): Promise { + const embedding = await this.getEmbedding(modelInfo) + return embedding.embedDocuments(texts) + } + + async embedQueryWithModel( + modelInfo: BaseEmbeddingModelInfo, + text: string + ): Promise { + const embedding = await this.getEmbedding(modelInfo) + return embedding.embedQuery(text) + } + + dispose() { + for (const embedding of this.embeddings.values()) { + embedding.dispose?.() + } + this.embeddings.clear() + this.activeModelKey = null + } +} diff --git a/src/extension/ai/embeddings/transformer-js-embeddings.ts b/src/extension/ai/embeddings/transformer-js-embeddings.ts new file mode 100644 index 0000000..2f6a751 --- /dev/null +++ b/src/extension/ai/embeddings/transformer-js-embeddings.ts @@ -0,0 +1,87 @@ +import path from 'path' +import { logger } from '@extension/logger' +import type { EmbeddingsParams } from '@langchain/core/embeddings' +import { chunkArray } from '@langchain/core/utils/chunk_array' +import type { + FeatureExtractionPipeline, + PipelineType +} from '@xenova/transformers' +import { env, pipeline } from '@xenova/transformers' + +import { BaseEmbeddingModelInfo, BaseEmbeddings } from './types' + +export class TransformerJsEmbeddings extends BaseEmbeddings { + static buildConstructParams( + modelInfo: BaseEmbeddingModelInfo + ): BaseEmbeddingModelInfo { + return modelInfo + } + + private pipeline: FeatureExtractionPipeline | null = null + + private modelInfo: BaseEmbeddingModelInfo + + constructor( + params: { modelInfo: BaseEmbeddingModelInfo } & EmbeddingsParams + ) { + const { modelInfo, ...callerParams } = params + + super(callerParams) + + this.modelInfo = modelInfo + } + + async init() { + if (!this.pipeline) { + env.allowLocalModels = true + env.allowRemoteModels = false + env.localModelPath = path.join(__EXTENSION_DIST_PATH__, `models`) + + this.pipeline = (await pipeline( + 'feature-extraction' as PipelineType, + this.modelInfo.modelName + )) as FeatureExtractionPipeline + + logger.log('Local embedding provider initialized') + } + } + + async embedDocuments(texts: string[]): Promise { + await this.init() + const batches = chunkArray(texts, this.modelInfo.maxTokens) + + const batchRequests = batches.map(batch => this.embeddingWithRetry(batch)) + const batchResponses = await Promise.all(batchRequests) + + return batchResponses.flat() + } + + async embedQuery(text: string): Promise { + await this.init() + const response = await this.embeddingWithRetry([text]) + return response[0]! + } + + private async embeddingWithRetry(batch: string[]): Promise { + if (!this.pipeline) { + throw new Error('Pipeline not initialized') + } + + return this.caller.call(async () => { + try { + const output = await this.pipeline!(batch, { + pooling: 'mean', + normalize: true + }) + return output.tolist() + } catch (e) { + throw new Error(`Error during embedding: ${(e as Error).message}`) + } + }) + } + + dispose() { + this.pipeline?.dispose() + this.pipeline = null + } +} diff --git a/src/extension/ai/embeddings/types.ts b/src/extension/ai/embeddings/types.ts new file mode 100644 index 0000000..f74d5e3 --- /dev/null +++ b/src/extension/ai/embeddings/types.ts @@ -0,0 +1,26 @@ +import { Embeddings } from '@langchain/core/embeddings' + +type ClassType = (new (...args: any[]) => T) & { + [K in keyof typeof BaseEmbeddings]: (typeof BaseEmbeddings)[K] +} + +export interface BaseEmbeddingModelInfo< + T extends BaseEmbeddings = BaseEmbeddings +> { + type: string + modelName: string + dimensions: number + maxTokens: number + EmbeddingClass: ClassType + buildConstructParams?: (modelInfo: BaseEmbeddingModelInfo) => any +} + +export abstract class BaseEmbeddings extends Embeddings { + async init?(): Promise { + // Default implementation (can be overridden by subclasses) + } + + dispose?(): void { + // Default implementation (can be overridden by subclasses) + } +} diff --git a/src/extension/constants.ts b/src/extension/constants.ts index 7a70079..ccc731f 100644 --- a/src/extension/constants.ts +++ b/src/extension/constants.ts @@ -69,5 +69,3 @@ export const DEFAULT_IGNORE_FILETYPES = [ '**/*.bin' // "**/*.prompt", // can be incredibly confusing for the LLM to have another set of instructions injected into the prompt ] - -export const MAX_EMBEDDING_TOKENS = 512 diff --git a/src/extension/file-utils/paths.ts b/src/extension/file-utils/paths.ts index cc6f5d5..ecbc1ff 100644 --- a/src/extension/file-utils/paths.ts +++ b/src/extension/file-utils/paths.ts @@ -11,6 +11,23 @@ const AIDE_DIR = process.env.AIDE_GLOBAL_DIR ?? path.join(os.homedir(), '.aide') export const getExt = (filePath: string): string => path.extname(filePath).slice(1) +export const getSemanticHashName = ( + forSemantic: string, + forHash?: string +): string => { + const semanticsName = forSemantic.replace(/[^a-zA-Z0-9]/g, '_') + + if (!forHash) return semanticsName.toLowerCase() + + const hashName = crypto + .createHash('md5') + .update(forHash) + .digest('hex') + .substring(0, 8) + + return `${semanticsName}_${hashName}`.toLowerCase() +} + export class AidePaths { private static instance: AidePaths @@ -76,17 +93,7 @@ export class AidePaths { getNamespace = () => { const workspacePath = getWorkspaceFolder().uri.fsPath - const workspaceName = path - .basename(workspacePath) - .replace(/[^a-zA-Z0-9]/g, '_') - - const workspaceFullPathHash = crypto - .createHash('md5') - .update(workspacePath) - .digest('hex') - .substring(0, 8) - - return `${workspaceName}_${workspaceFullPathHash}`.toLowerCase() + return getSemanticHashName(workspacePath, path.basename(workspacePath)) } } diff --git a/src/extension/registers/index.ts b/src/extension/registers/index.ts index bdd62ad..7499046 100644 --- a/src/extension/registers/index.ts +++ b/src/extension/registers/index.ts @@ -2,6 +2,7 @@ import { AideKeyUsageStatusBarRegister } from './aide-key-usage-statusbar-regist import { AutoOpenCorrespondingFilesRegister } from './auto-open-corresponding-files-register' import { BaseRegister } from './base-register' import { CodebaseWatcherRegister } from './codebase-watcher-register' +import { ModelRegister } from './model-register' import { RegisterManager } from './register-manager' import { SystemSetupRegister } from './system-setup-register' import { TmpFileActionRegister } from './tmp-file-action-register' @@ -14,12 +15,11 @@ export const setupRegisters = async (registerManager: RegisterManager) => { AideKeyUsageStatusBarRegister, AutoOpenCorrespondingFilesRegister, WebviewRegister, + ModelRegister, CodebaseWatcherRegister ] satisfies (typeof BaseRegister)[] - const promises = Registers.map(async Register => { + for await (const Register of Registers) { await registerManager.setupRegister(Register) - }) - - await Promise.allSettled(promises) + } } diff --git a/src/extension/registers/model-register.ts b/src/extension/registers/model-register.ts new file mode 100644 index 0000000..618db15 --- /dev/null +++ b/src/extension/registers/model-register.ts @@ -0,0 +1,31 @@ +import { + EmbeddingManager, + embeddingModels +} from '@extension/ai/embeddings/embedding-manager' +import type { CommandManager } from '@extension/commands/command-manager' +import * as vscode from 'vscode' + +import { BaseRegister } from './base-register' +import type { RegisterManager } from './register-manager' + +export class ModelRegister extends BaseRegister { + private embeddingManager: EmbeddingManager + + constructor( + protected context: vscode.ExtensionContext, + protected registerManager: RegisterManager, + protected commandManager: CommandManager + ) { + super(context, registerManager, commandManager) + this.embeddingManager = EmbeddingManager.getInstance() + } + + async register(): Promise { + // Set the default active model + await this.embeddingManager.setActiveModel(embeddingModels[0]) + } + + cleanup(): void { + this.embeddingManager.dispose() + } +} diff --git a/src/extension/registers/register-manager.ts b/src/extension/registers/register-manager.ts index 10e0bf7..5df6bec 100644 --- a/src/extension/registers/register-manager.ts +++ b/src/extension/registers/register-manager.ts @@ -1,4 +1,5 @@ import type { CommandManager } from '@extension/commands/command-manager' +import { logger } from '@extension/logger' import * as vscode from 'vscode' import { BaseRegister } from './base-register' @@ -16,9 +17,17 @@ export class RegisterManager { ...args: ConstructorParameters ) => BaseRegister ): Promise { - const register = new RegisterClass(this.context, this, this.commandManager) - await register.register() - this.registers.push(register) + try { + const register = new RegisterClass( + this.context, + this, + this.commandManager + ) + await register.register() + this.registers.push(register) + } catch (e) { + logger.error('Failed to setup register', e) + } } async cleanup(): Promise { diff --git a/src/extension/webview-api/chat-context-processor/vectordb/code-chunks-index-table.ts b/src/extension/webview-api/chat-context-processor/vectordb/code-chunks-index-table.ts index 278887b..effdfb9 100644 --- a/src/extension/webview-api/chat-context-processor/vectordb/code-chunks-index-table.ts +++ b/src/extension/webview-api/chat-context-processor/vectordb/code-chunks-index-table.ts @@ -1,4 +1,5 @@ -import { aidePaths } from '@extension/file-utils/paths' +import { EmbeddingManager } from '@extension/ai/embeddings/embedding-manager' +import { aidePaths, getSemanticHashName } from '@extension/file-utils/paths' import { logger } from '@extension/logger' import { Field, @@ -15,25 +16,11 @@ import { CodeChunkRow } from './types' export class CodeChunksIndexTable { private lanceDb!: Connection - private schema = new Schema([ - new Field('relativePath', new Utf8()), - new Field('fullPath', new Utf8()), - new Field('fileHash', new Utf8()), - new Field('startLine', new Int32()), - new Field('startCharacter', new Int32()), - new Field('endLine', new Int32()), - new Field('endCharacter', new Int32()), - new Field( - 'embedding', - new FixedSizeList(384, new Field('emb', new Float32())) - ) - ]) - async initialize() { try { const lanceDbDir = aidePaths.getLanceDbPath() this.lanceDb = await connect(lanceDbDir) - logger.log('LanceDB initialized successfully') + logger.log('LanceDB initialized successfully', { lanceDbDir }) } catch (error) { logger.error('Failed to initialize LanceDB:', error) throw error @@ -41,15 +28,35 @@ export class CodeChunksIndexTable { } async getOrCreateTable(): Promise> { - const tableName = 'code_chunks_embeddings' + const { modelName, dimensions } = + EmbeddingManager.getInstance().getActiveModelInfo() + const semanticModelName = getSemanticHashName(modelName) + const tableName = `code_chunks_embeddings_${semanticModelName}` + try { const tables = await this.lanceDb.tableNames() - return tables.includes(tableName) - ? await this.lanceDb.openTable(tableName) - : await this.lanceDb.createTable({ - name: tableName, - schema: this.schema - }) + + if (tables.includes(tableName)) + return await this.lanceDb.openTable(tableName) + + const schema = new Schema([ + new Field('relativePath', new Utf8()), + new Field('fullPath', new Utf8()), + new Field('fileHash', new Utf8()), + new Field('startLine', new Int32()), + new Field('startCharacter', new Int32()), + new Field('endLine', new Int32()), + new Field('endCharacter', new Int32()), + new Field( + 'embedding', + new FixedSizeList(dimensions, new Field('emb', new Float32())) + ) + ]) + + return await this.lanceDb.createTable({ + name: tableName, + schema + }) } catch (error) { logger.error('Error getting or creating table:', error) throw error @@ -82,4 +89,9 @@ export class CodeChunksIndexTable { .filter(`fullPath = '${filePath}'`) .execute() } + + async getAllRows(): Promise { + const table = await this.getOrCreateTable() + return await table.filter('true').execute() + } } diff --git a/src/extension/webview-api/chat-context-processor/vectordb/codebase-indexer.ts b/src/extension/webview-api/chat-context-processor/vectordb/codebase-indexer.ts index 0eef58c..cdf672f 100644 --- a/src/extension/webview-api/chat-context-processor/vectordb/codebase-indexer.ts +++ b/src/extension/webview-api/chat-context-processor/vectordb/codebase-indexer.ts @@ -1,12 +1,12 @@ import crypto from 'crypto' -import { MAX_EMBEDDING_TOKENS } from '@extension/constants' +import { EmbeddingManager } from '@extension/ai/embeddings/embedding-manager' +import type { BaseEmbeddings } from '@extension/ai/embeddings/types' import { createShouldIgnore } from '@extension/file-utils/ignore-patterns' import { getExt } from '@extension/file-utils/paths' import { traverseFileOrFolders } from '@extension/file-utils/traverse-fs' import { VsCodeFS } from '@extension/file-utils/vscode-fs' import { logger } from '@extension/logger' import { getWorkspaceFolder } from '@extension/utils' -import { OpenAIEmbeddings } from '@langchain/openai' import { languageIdExts } from '@shared/utils/vscode-lang' import * as vscode from 'vscode' @@ -23,7 +23,7 @@ export class CodebaseIndexer { private progressReporter: ProgressReporter - private embeddings: OpenAIEmbeddings + private embeddings!: BaseEmbeddings private indexingQueue: string[] = [] @@ -34,10 +34,10 @@ export class CodebaseIndexer { workspaceRootPath || getWorkspaceFolder().uri.fsPath this.databaseManager = new CodeChunksIndexTable() this.progressReporter = new ProgressReporter() - this.embeddings = new OpenAIEmbeddings() } async initialize() { + this.embeddings = await EmbeddingManager.getInstance().getActiveEmbedding() await this.databaseManager.initialize() } @@ -171,6 +171,9 @@ export class CodebaseIndexer { try { const rows = await this.createCodeChunkRows(filePath) await this.databaseManager.addRows(rows) + + logger.dev.log(await this.databaseManager.getAllRows()) + logger.log(`Indexed file: ${filePath}`) } catch (error) { logger.error(`Error indexing file ${filePath}:`, error) @@ -187,8 +190,10 @@ export class CodebaseIndexer { const manager = CodeChunkerManager.getInstance() const chunker = await manager.getChunkerFromFilePath(filePath) const content = await VsCodeFS.readFile(filePath) + const { maxTokens } = EmbeddingManager.getInstance().getActiveModelInfo() + const chunks = await chunker.chunkCode(content, { - maxTokenLength: MAX_EMBEDDING_TOKENS, + maxTokenLength: maxTokens, removeDuplicates: true }) @@ -198,9 +203,7 @@ export class CodebaseIndexer { private async createCodeChunkRows(filePath: string): Promise { const chunks = await this.chunkCodeFile(filePath) - const table = await this.databaseManager.getOrCreateTable() - - logger.log('chunks', chunks, await table.filter('true').execute()) + logger.dev.log('code chunks', chunks) const chunkRowsPromises = chunks.map(async chunk => { const embedding = await this.embeddings.embedQuery(chunk.text) diff --git a/vite.config.mts b/vite.config.mts index b03b5ab..e072b07 100644 --- a/vite.config.mts +++ b/vite.config.mts @@ -1,3 +1,5 @@ +/* eslint-disable unused-imports/no-unused-vars */ +/* eslint-disable no-console */ import path, { dirname } from 'path' import { fileURLToPath } from 'url' import vscode from '@tomjs/vite-plugin-vscode' @@ -18,6 +20,9 @@ const resolvePath = (...paths: string[]) => path.resolve(dir, ...paths) const extensionDistPath = resolvePath('dist/extension') +const resolveExtensionDistPath = (...paths: string[]) => + path.resolve(extensionDistPath, ...paths) + const define: Record = { __EXTENSION_DIST_PATH__: JSON.stringify(extensionDistPath) } @@ -49,80 +54,73 @@ export default defineConfig(() => { external: [ './index.node' // shit, vectordb need this ], - async onSuccess() { - await tsupCopyFiles() - } - // esbuildOptions(options) { - // options.alias = { - // 'vectordb/native': resolvePath( - // 'scripts/fix-package/vectordb/native.js' - // ) - // } - // }, - // esbuildPlugins: [ - // { - // name: 'handle-native-node-modules', - // setup(build) { - // build.onResolve({ filter: /\.node$/ }, args => ({ - // path: path.resolve(args.resolveDir, args.path), - // external: true - // })) - // } - // } - // ] + // esbuildPlugins: [createRedirectPlugin(redirects) as any], + esbuildOptions(options) { + options.alias = { + ...options.alias, + 'onnxruntime-node': resolvePath( + 'node_modules/onnxruntime-node/dist/index.js' + ) + } + }, + plugins: [ + { + name: 'copy-files', + async buildStart() { + await tsupCopyFiles() + } + } + ] } }) ], - build: { - commonjsOptions: { - ignoreDynamicRequires: true, - dynamicRequireRoot: '/', - dynamicRequireTargets: ['./bin/napi-v3/**/onnxruntime_binding.node'] - } - }, resolve: { - dedupe: ['react', 'react-dom'], - alias: { - // hack onnxruntime-node - 'onnxruntime-node': path.join( - __dirname, - 'vendors/onnxruntime-node/index.cjs' - ) - } + dedupe: ['react', 'react-dom'] } } }) const tsupCopyFiles = async () => { const targets = [ + // copy node_modules to extension dist + { + src: resolvePath('node_modules/tree-sitter-wasms/out/*.wasm'), + dest: resolveExtensionDistPath('tree-sitter-wasms/') + }, { - src: 'node_modules/tree-sitter-wasms/out/*.wasm', - dest: 'tree-sitter-wasms/' + src: resolvePath('node_modules/web-tree-sitter/*.wasm'), + dest: resolveExtensionDistPath('./') }, { - src: 'node_modules/web-tree-sitter/*.wasm', - dest: './' + src: resolvePath('node_modules/onnxruntime-node/bin/**'), + dest: resolveExtensionDistPath('onnxruntime/bin/') }, { - src: 'node_modules/onnxruntime-node/bin/**', - dest: 'onnxruntime/' + src: resolvePath('node_modules/@lancedb/**'), + dest: resolveExtensionDistPath('node_modules/@lancedb/') }, { - src: 'node_modules/@lancedb/**', - dest: 'node_modules/@lancedb/' + src: resolvePath( + 'src/extension/webview-api/chat-context-processor/models/**' + ), + dest: resolveExtensionDistPath('models/') + }, + + // copy fix-packages to node_modules + { + src: resolvePath('scripts/fix-package/@xenova/transformers/**'), + dest: resolvePath('node_modules/@xenova/transformers/src/') }, { - src: 'src/extension/webview-api/chat-context-processor/models/**', - dest: 'models/' + src: resolvePath('scripts/fix-package/onnxruntime-node/**'), + dest: resolvePath('node_modules/onnxruntime-node/dist/') } ] const promises = targets.map(async ({ src, dest }) => { - const srcPath = resolvePath(src) - const destPath = path.join(extensionDistPath, dest) - - await cpy(srcPath, destPath, { - cwd: dir + await cpy(src, dest, { + cwd: dir, + overwrite: true }) })