Skip to content

Commit

Permalink
Added modified isnet that works on webgpu
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielHauschildt committed Mar 25, 2024
1 parent 5b2ac0d commit af09efb
Show file tree
Hide file tree
Showing 22 changed files with 145 additions and 98 deletions.
3 changes: 3 additions & 0 deletions bundle/models/isnet
Git LFS file not shown
3 changes: 3 additions & 0 deletions bundle/models/isnet_fp16
Git LFS file not shown
3 changes: 3 additions & 0 deletions bundle/models/isnet_quint8
Git LFS file not shown
3 changes: 0 additions & 3 deletions bundle/models/large

This file was deleted.

3 changes: 0 additions & 3 deletions bundle/models/medium

This file was deleted.

3 changes: 0 additions & 3 deletions bundle/models/small

This file was deleted.

10 changes: 5 additions & 5 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion packages/node-examples/src/example_001.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async function run() {
);
},
// model: 'small',
model: 'medium',
model: 'isnet',
// model: 'large',
// model: 'modnet',
// model: 'modnet_fp16',
Expand Down
17 changes: 1 addition & 16 deletions packages/node/.resources.mjs
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
export default [
{
path: '/models/',
source: '../../bundle/models/small',
mime: 'application/octet-steam'
},
{
path: '/models/',
source: '../../bundle/models/medium',
mime: 'application/octet-steam'
},
{
path: '/models/',
source: '../../bundle/models/large',
mime: 'application/octet-steam'
},
{
path: '/models/',
source: '../../bundle/models/modnet*',
source: '../../bundle/models/*',
mime: 'application/octet-steam'
}
];
10 changes: 8 additions & 2 deletions packages/node/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.4.5]
## [Unreleased]

### Added

- Added ThirdPartyLicenses.json Added
- Added Modnet models Added

- Added isnet model for webgpu Added

## [1.4.5]

Added ThirdPartyLicenses.json Added

## [1.4.0]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ type: Added
# type: Security
# private: true
description: |
Added Modnet models Added
Added Modnet models Added
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
type: Added
# type: Added
# type: Changed
# type: Removed
# type: Security
# private: true
description: |
Added isnet model for webgpu Added
37 changes: 25 additions & 12 deletions packages/node/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,18 @@ const ConfigSchema = z
.describe('Progress callback.')
.optional(),
model: z
.enum(['small', 'medium', 'large', 'modnet', 'modnet_fp16'])
.preprocess((val) => {
switch (val) {
case 'large':
return 'isnet';
case 'small':
return 'isnet_quint8';
case 'medium':
return 'isnet_fp16';
default:
return val;
}
}, z.enum(['isnet', 'isnet_fp16', 'isnet_quint8', 'modnet', 'modnet_fp16' /*, 'modnet_quint8'*/]))
.default('medium'),
output: z
.object({
Expand All @@ -63,19 +74,21 @@ const ConfigSchema = z
})
.default({})
})
.default({});
.default({})
.transform((config) => {
if (config.debug) console.log('Config:', config);
if (config.debug && !config.progress) {
config.progress =
config.progress ??
((key, current, total) => {
console.debug(`Downloading ${key}: ${current} of ${total}`);
});
}
return config;
});

type Config = z.infer<typeof ConfigSchema>;

function validateConfig(configuration?: Config): Config {
const config = ConfigSchema.parse(configuration ?? {});
if (config.debug) console.log('Config:', config);
if (config.debug && !config.progress) {
config.progress =
config.progress ??
((key, current, total) => {
console.debug(`Downloading ${key}: ${current} of ${total}`);
});
}
return config;
return ConfigSchema.parse(configuration ?? {});
}
17 changes: 1 addition & 16 deletions packages/web-data/.resources.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,7 @@ export default [
},
{
path: '/models/',
source: '../../bundle/models/small',
mime: 'application/octet-steam'
},
{
path: '/models/',
source: '../../bundle/models/medium',
mime: 'application/octet-steam'
},
// {
// path: '/models/',
// source: '../../bundle/models/large',
// mime: 'application/octet-steam'
// },
{
path: '/models/',
source: '../../bundle/models/modnet*',
source: '../../bundle/models/*',
mime: 'application/octet-steam'
}
];
18 changes: 9 additions & 9 deletions packages/web-examples/vite-project/src/App.vue
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export default {
name: 'App',
setup() {
const images = [
'https://images.unsplash.com/photo-1686002359940-6a51b0d64f68?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=2048&q=80',
'https://images.unsplash.com/photo-1686002359940-6a51b0d64f68?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1024&q=80',
'https://images.unsplash.com/photo-1590523278191-995cbcda646b?ixlib=rb-1.2.1&q=80&fm=jpg&crop=entropy&cs=tinysrgb&w=1080&fit=max&ixid=eyJhcHBfaWQiOjEyMDd9',
'https://images.unsplash.com/photo-1709248835088-03bb0946d6ab?q=80&w=3387&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D'
];
Expand All @@ -27,27 +27,27 @@ export default {
const publicPath = new URL(import.meta.url);
publicPath.pathname = '/js/';
const config = {
// debug: false,
debug: true,
debug: false,
// debug: true,
publicPath: publicPath.href,
progress: (key, current, total) => {
const [type, subtype] = key.split(':');
caption.value = `${type} ${subtype} ${((current / total) * 100).toFixed(
0
)}%`;
},
device: 'gpu',
// device: 'cpu',
// model: 'small',
// model: 'medium',
// model: 'large',
// model: 'isnet',
// model: 'isnet_fp16',
// model: 'isnet_quint8',
// model: 'modnet',
// model: 'modnet_fp16', //# does not work on webgpu
// model: 'modnet_quint8',
output: {
quality: 0.8,
// format: 'image/png'
format: 'image/jpeg'
format: 'image/png'
// format: 'image/jpeg'
// format: 'image/webp'
//format: 'image/x-rgba8'
//format: 'image/x-alpha8'
Expand Down
10 changes: 10 additions & 0 deletions packages/web/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- Added Modnet models Added

- Added option to execute on gpu (webgpu) and cpu Added

- Added isnet model for webgpu Added

## [1.4.5]

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ type: Added
# type: Security
# private: true
description: |
Added Modnet models Added
Added Modnet models Added
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ type: Added
# type: Security
# private: true
description: |
Added option to execute on gpu (webgpu) and cpu Added
Added option to execute on gpu (webgpu) and cpu Added
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
type: Added
# type: Added
# type: Changed
# type: Removed
# type: Security
# private: true
description: |
Added isnet model for webgpu Added
9 changes: 6 additions & 3 deletions packages/web/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,25 @@ async function removeBackground(
): Promise<Blob> {
const { config, session } = await init(configuration);

if (config.progress) config.progress('compute:decode', 0, 4);

const imageTensor = await utils.imageSourceToImageData(image, config);
const [width, height, channels] = imageTensor.shape;

config.progress?.('compute:inference', 1, 4);
const alphamask = await runInference(imageTensor, config, session);
const stride = width * height;

config.progress?.('compute:mask', 2, 4);
const outImageTensor = imageTensor;
for (let i = 0; i < stride; i += 1) {
outImageTensor.data[4 * i + 3] = alphamask.data[i];
}

config.progress?.('compute:encode', 3, 4);
const outImage = await utils.imageEncode(
outImageTensor,
config.output.quality,
config.output.format
);
config.progress?.('compute:encode', 4, 4);

return outImage;
}
Expand Down
4 changes: 0 additions & 4 deletions packages/web/src/inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,11 @@ async function runInference(
config: Config,
session: any
): Promise<NdArray<Uint8Array>> {
if (config.progress) config.progress('compute:inference', 0, 1);
const resolution = 1024;
const [srcHeight, srcWidth, srcChannels] = imageTensor.shape;

let tensorImage = tensorResizeBilinear(imageTensor, resolution, resolution);
const inputTensor = tensorHWCtoBCHW(tensorImage); // this converts also from float to rgba

// run
const predictionsDict = await runOnnxSession(
session,
[['input', inputTensor]],
Expand All @@ -45,6 +42,5 @@ async function runInference(
let alphamaskU8 = convertFloat32ToUint8(alphamask);
alphamaskU8 = tensorResizeBilinear(alphamaskU8, srcWidth, srcHeight);

if (config.progress) config.progress('compute:inference', 1, 1);
return alphamaskU8;
}
Loading

0 comments on commit af09efb

Please sign in to comment.