Skip to content

Commit

Permalink
add ml model to js binding
Browse files Browse the repository at this point in the history
Signed-off-by: Andrey Parfenov <[email protected]>
  • Loading branch information
Andrey1994 committed Sep 12, 2023
1 parent 60d5391 commit 06deccb
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nodejs_package/brainflow/board_shim.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
} from './brainflow.types';
import {BoardControllerCLikeFunctions as CLike, BoardControllerFunctions} from './functions.types';

class BrainFlowInputParams
export class BrainFlowInputParams
{
private inputParams: IBrainFlowInputParams = {
serialPort: '',
Expand Down
21 changes: 21 additions & 0 deletions nodejs_package/brainflow/brainflow.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ export enum BrainFlowExitCodes {
UNSUPPORTED_CLASSIFIER_AND_METRIC_COMBINATION_ERROR = 23,
}

export enum BrainFlowMetrics {
MINDFULNESS = 0,
RESTFULNESS = 1,
USER_DEFINED = 2,
}

export enum BrainFlowClassifiers {
DEFAULT_CLASSIFIER = 0,
USER_DEFINED = 1,
ONNX_CLASSIFIER = 2,
}

export interface IBrainFlowInputParams {
serialPort: string;
macAddress: string;
Expand All @@ -237,3 +249,12 @@ export interface IBrainFlowInputParams {
fileAnc: string;
masterBoard: BoardIds;
}

export interface IBrainFlowModelParams {
metric: BrainFlowMetrics;
classifier: BrainFlowClassifiers;
file: string;
otherInfo: string;
outputName: string;
maxArraySize: number;
}
32 changes: 32 additions & 0 deletions nodejs_package/brainflow/functions.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -423,4 +423,36 @@ export class DataHandlerFunctions
len: number[],
maxLen: number,
) => BrainFlowExitCodes;
}

export enum MLModuleCLikeFunctions {
// logging and version methods
set_log_level_ml_module = 'int set_log_level_ml_module (int log_level)',
set_log_file_ml_module = 'int set_log_file_ml_module (const char *log_file)',
log_message_ml_module = 'int log_message_ml_module (int log_level, char *log_message)',
get_version_ml_module =
'int get_version_ml_module (_Inout_ char *version, _Inout_ int *len, int max)',
prepare = 'int prepare (const char *json_params)',
predict =
'int predict (double *data, int data_len, _Inout_ double *output, _Inout_ int *output_len, const char *json_params)',
release = 'int release (const char *json_params)',
release_all = 'int release_all ()'
}

export class MLModuleFunctions
{
// logging and version methods
setLogLevelMLModule!: (logLevel: LogLevels) => BrainFlowExitCodes;
setLogFileMLModule!: (logFile: string) => BrainFlowExitCodes;
logMessageMLModule!: (logLevel: LogLevels, logMessage: string) => BrainFlowExitCodes;
getVersionMLModule!: (
version: string[],
len: number[],
maxLen: number,
) => BrainFlowExitCodes;
prepare!: (inputJson: string) => BrainFlowExitCodes;
predict!: (data: number[], dataLen: number, output: number[], outputLen: number[],
inputJson: string) => BrainFlowExitCodes;
releaseAll!: () => BrainFlowExitCodes;
release!: (inputJson: string) => BrainFlowExitCodes;
}
3 changes: 2 additions & 1 deletion nodejs_package/brainflow/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export * from './board_shim';
export * from './brainflow.types';
export * from './data_filter';
export * from './complex';
export * from './complex';
export * from './ml_model';
205 changes: 205 additions & 0 deletions nodejs_package/brainflow/ml_model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import koffi from 'koffi';
import _ from 'lodash';
import * as os from 'os';

import {
BrainFlowClassifiers,
BrainFlowError,
BrainFlowExitCodes,
BrainFlowMetrics,
IBrainFlowModelParams,
LogLevels,
} from './brainflow.types';
import {MLModuleCLikeFunctions as CLike, MLModuleFunctions} from './functions.types';

export class BrainFlowModelParams
{
public inputParams: IBrainFlowModelParams = {
metric: BrainFlowMetrics.USER_DEFINED,
classifier: BrainFlowClassifiers.ONNX_CLASSIFIER,
file: '',
otherInfo: '',
outputName: '',
maxArraySize: 8192
};

constructor(metric: BrainFlowMetrics, classifier: BrainFlowClassifiers,
inputParams: Partial<IBrainFlowModelParams>)
{
this.inputParams = {...this.inputParams, ...inputParams };
this.inputParams.metric = metric;
this.inputParams.classifier = classifier;
}

public toJson(): string
{
const params: Record<string, any> = {};
Object.keys(this.inputParams).forEach((key) => {
params[_.snakeCase(key)] = this.inputParams[key as keyof IBrainFlowModelParams];
});
return JSON.stringify(params);
}
}

class MLModuleDLL extends MLModuleFunctions
{
private static instance: MLModuleDLL;

private libPath: string;
private dllPath: string;
private lib: koffi.IKoffiLib;

private constructor()
{
super ();
this.libPath = `${__dirname}/../brainflow/lib`;
this.dllPath = this.getDLLPath();
this.lib = this.getLib();

this.setLogLevelMLModule = this.lib.func(CLike.set_log_level_ml_module);
this.setLogFileMLModule = this.lib.func(CLike.set_log_file_ml_module);
this.logMessageMLModule = this.lib.func(CLike.log_message_ml_module);
this.getVersionMLModule = this.lib.func(CLike.get_version_ml_module);
this.prepare = this.lib.func(CLike.prepare);
this.predict = this.lib.func(CLike.predict);
this.release = this.lib.func(CLike.release);
this.releaseAll = this.lib.func(CLike.release_all);
}

private getDLLPath()
{
const platform = os.platform();
const arch = os.arch();
switch (platform)
{
case 'darwin':
return `${this.libPath}/libMLModule.dylib`;
case 'win32':
return arch === 'x64' ? `${this.libPath}/MLModule.dll` :
`${this.libPath}/MLModule32.dll`;
case 'linux':
return `${this.libPath}/libMLModule.so`;
default:
throw new BrainFlowError (
BrainFlowExitCodes.GENERAL_ERROR, `OS ${platform} is not supported.`);
}
}

private getLib()
{
try
{
const lib = koffi.load(this.dllPath);
return lib;
}
catch (err)
{
console.error(err);
throw new BrainFlowError (BrainFlowExitCodes.GENERAL_ERROR,
`${'Could not load MLModule DLL - path://'}${this.dllPath}`);
}
}

public static getInstance(): MLModuleDLL
{
if (!MLModuleDLL.instance)
{
MLModuleDLL.instance = new MLModuleDLL ();
}
return MLModuleDLL.instance;
}
}

export class MLModel
{
private inputJson: string;
private input: BrainFlowModelParams;

constructor(metric: BrainFlowMetrics, classifier: BrainFlowClassifiers,
inputParams: Partial<IBrainFlowModelParams>)
{
this.input = new BrainFlowModelParams (metric, classifier, inputParams);
this.inputJson = this.input.toJson();
}

// logging methods
public static getVersion(): string
{
const len = [0];
let out = ['\0'.repeat(512)];
const res = MLModuleDLL.getInstance().getVersionMLModule(out, len, 512);
if (res !== BrainFlowExitCodes.STATUS_OK)
{
throw new BrainFlowError (res, 'Could not get version info');
}
return out[0].substring(0, len[0]);
}

public static setLogLevel(logLevel: LogLevels): void
{
const res = MLModuleDLL.getInstance().setLogLevelMLModule(logLevel);
if (res !== BrainFlowExitCodes.STATUS_OK)
{
throw new BrainFlowError (res, 'Could not set log level properly');
}
}

public static setLogFile(file: string): void
{
const res = MLModuleDLL.getInstance().setLogFileMLModule(file);
if (res !== BrainFlowExitCodes.STATUS_OK)
{
throw new BrainFlowError (res, 'Could not redirect to log file');
}
}

public static logMessage(logLevel: LogLevels, message: string): void
{
const res = MLModuleDLL.getInstance().logMessageMLModule(logLevel, message);
if (res !== BrainFlowExitCodes.STATUS_OK)
{
throw new BrainFlowError (res, 'Could not writte message');
}
}

// model methods
public prepare(): void
{
const res = MLModuleDLL.getInstance().prepare(this.inputJson);
if (res !== BrainFlowExitCodes.STATUS_OK)
{
throw new BrainFlowError (res, 'Could not prepare model');
}
}

public predict(data: number[]): number[]
{
const len = [0];
const output = [...new Array (this.input.inputParams.maxArraySize).fill(0)];
const res =
MLModuleDLL.getInstance().predict(data, data.length, output, len, this.inputJson);
if (res !== BrainFlowExitCodes.STATUS_OK)
{
throw new BrainFlowError (res, 'Could not predict');
}
return output.slice(0, len[0]);
}

public release(): void
{
const res = MLModuleDLL.getInstance().release(this.inputJson);
if (res !== BrainFlowExitCodes.STATUS_OK)
{
throw new BrainFlowError (res, 'Could not release model');
}
}

public static releaseAll(): void
{
const res = MLModuleDLL.getInstance().releaseAll();
if (res !== BrainFlowExitCodes.STATUS_OK)
{
throw new BrainFlowError (res, 'Could not release models');
}
}
}
32 changes: 32 additions & 0 deletions nodejs_package/brainflow/tests/eeg_metrics.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import {BoardShim} from '../board_shim';
import {BoardIds, BrainFlowClassifiers, BrainFlowMetrics} from '../brainflow.types';
import {DataFilter} from '../data_filter'
import {MLModel} from '../ml_model';

function sleep (ms: number)
{
return new Promise ((resolve) => { setTimeout (resolve, ms); });
}

async function runExample (): Promise<void>
{
const boardId = BoardIds.SYNTHETIC_BOARD;
const board = new BoardShim (boardId, {});
board.prepareSession();
board.startStream();
await sleep (3000);
board.stopStream();
const data = board.getBoardData();
board.releaseSession()
console.info(data);
const eegChannels = BoardShim.getEegChannels(boardId);
const samplingRate = BoardShim.getSamplingRate(boardId);
const bands = DataFilter.getAvgBandPowers(data, eegChannels, samplingRate, true);
const model =
new MLModel (BrainFlowMetrics.RESTFULNESS, BrainFlowClassifiers.DEFAULT_CLASSIFIER, {});
model.prepare();
console.info(model.predict(bands[0]));
model.release();
}

runExample ();
3 changes: 2 additions & 1 deletion nodejs_package/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
"transforms": "ts-node brainflow/tests/transforms.ts",
"denoising": "ts-node brainflow/tests/denoising.ts",
"bandpower": "ts-node brainflow/tests/bandpower.ts",
"bandpower_all": "ts-node brainflow/tests/bandpower_all.ts"
"bandpower_all": "ts-node brainflow/tests/bandpower_all.ts",
"eeg_metrics": "ts-node brainflow/tests/eeg_metrics.ts"
},
"repository": {
"type": "git",
Expand Down

0 comments on commit 06deccb

Please sign in to comment.