-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.ts
61 lines (51 loc) · 1.5 KB
/
main.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import { app, BrowserWindow, ipcMain } from 'electron';
import * as path from 'path';
import * as tf from '@tensorflow/tfjs-node';
// Handle to the single application window. It has no initializer, so it is
// `undefined` until createWindow() assigns it, and it is set to `null` by the
// window's 'closed' handler. NOTE(review): the 'activate' handler compares it
// with `=== null`, so a window is only recreated after one has been closed;
// adding `= null` here would change that behavior.
let mainWindow: BrowserWindow | null;
// Model trained at startup by loadModel(). The declared type is non-optional,
// but the variable holds `undefined` until loadModel() assigns it.
let model: tf.Sequential;
/**
 * Creates the main application window and loads `index.html` into it.
 *
 * The 'closed' listener is attached before the page is loaded. That way the
 * module-level `mainWindow` reference is cleared even when loading fails, or
 * when the window is closed before `loadFile` settles.
 *
 * @returns A promise that resolves once `index.html` has loaded and rejects
 *   if loading fails.
 */
async function createWindow(): Promise<void> {
  const win = new BrowserWindow({
    width: 800,
    height: 600,
    webPreferences: {
      // Renderer runs without Node integration and with context isolation;
      // the preload script next to this file is injected into it.
      nodeIntegration: false,
      contextIsolation: true,
      preload: path.join(__dirname, 'preload.js'),
    },
  });
  mainWindow = win;

  win.on('closed', () => {
    // Only clear the shared reference if it still points at this window.
    if (mainWindow === win) {
      mainWindow = null;
    }
  });

  await win.loadFile('index.html');
}
/**
 * Builds and trains the demo regression model, then publishes it to the
 * module-level `model` variable.
 *
 * The model is a single dense unit on a one-feature input. It is trained with
 * SGD on mean-squared error for 250 epochs over six hard-coded samples that
 * lie on y = 2x - 1. Replace this with real model loading as needed.
 *
 * `model` is assigned only after training completes, so it never exposes a
 * partially trained network. If training throws, `model` is left unchanged.
 * The training tensors are disposed explicitly, because TF.js does not free
 * tensor memory when tensors merely go out of scope.
 */
async function loadModel(): Promise<void> {
  const trained = tf.sequential();
  trained.add(tf.layers.dense({ units: 1, inputShape: [1] }));
  trained.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });

  const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
  const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
  try {
    await trained.fit(xs, ys, { epochs: 250 });
  } finally {
    xs.dispose();
    ys.dispose();
  }

  model = trained;
}
// Startup: train the model first, so it exists before the renderer (created by
// createWindow) can ask for predictions, then open the window. Any startup
// failure is logged and the app quits, rather than leaving an unhandled
// promise rejection behind with no usable window.
app.on('ready', async () => {
  try {
    await loadModel();
    await createWindow();
  } catch (err: unknown) {
    console.error('Failed to start application:', err);
    app.quit();
  }
});
// When the last window closes, quit the app on every platform except macOS
// ('darwin'), where the process is left running.
app.on('window-all-closed', () => {
  const isMac = process.platform === 'darwin';
  if (isMac) {
    return;
  }
  app.quit();
});
// On activation, recreate the main window if the previous one has been closed
// (its 'closed' handler sets `mainWindow` to null). The strict `=== null`
// check is intentional: before the first window exists, `mainWindow` is still
// undefined and nothing happens here. createWindow() is async, so its
// rejection is caught and logged instead of being left unhandled.
app.on('activate', () => {
  if (mainWindow === null) {
    void createWindow().catch((err: unknown) => {
      console.error('Failed to recreate main window:', err);
    });
  }
});
/**
 * IPC endpoint 'runPrediction': runs the trained model on one scalar input
 * and resolves with the single predicted value.
 *
 * The argument arrives from the renderer, so it is validated rather than
 * trusted. The input tensor is released by tf.tidy, and the output tensor is
 * disposed once its data has been read, so repeated calls do not leak tensor
 * memory.
 *
 * @throws TypeError if `input` is not a finite number.
 * @throws Error if the model has not been assigned by loadModel() yet.
 */
ipcMain.handle('runPrediction', async (_event, input: number) => {
  if (typeof input !== 'number' || !Number.isFinite(input)) {
    const received = typeof input === 'number' ? String(input) : typeof input;
    throw new TypeError(`runPrediction expects a finite number, got ${received}`);
  }
  if (model === undefined) {
    throw new Error('runPrediction called before the model finished loading');
  }

  // tidy() disposes every tensor created inside except the returned one.
  const outputTensor = tf.tidy(
    () => model.predict(tf.tensor2d([input], [1, 1])) as tf.Tensor,
  );
  try {
    const outputData = await outputTensor.data();
    return outputData[0];
  } finally {
    outputTensor.dispose();
  }
});