diff --git a/changelog.d/20250115_125619_sekachev.bs_updated_mil.md b/changelog.d/20250115_125619_sekachev.bs_updated_mil.md new file mode 100644 index 000000000000..533a158df252 --- /dev/null +++ b/changelog.d/20250115_125619_sekachev.bs_updated_mil.md @@ -0,0 +1,4 @@ +### Changed + +- Enhanced MIL tracker. Optimized memory usage. Now it is runnable on many frames, and applicable to drawn rectangles. + () diff --git a/cvat-core/src/frames.ts b/cvat-core/src/frames.ts index b772aca4ca4a..c611bd5cec0a 100644 --- a/cvat-core/src/frames.ts +++ b/cvat-core/src/frames.ts @@ -1,5 +1,5 @@ // Copyright (C) 2021-2022 Intel Corporation -// Copyright (C) 2022-2024 CVAT.ai Corporation +// Copyright (C) 2022-2025 CVAT.ai Corporation // // SPDX-License-Identifier: MIT @@ -300,7 +300,11 @@ export class FrameData { ); } - async data(onServerRequest = () => {}): Promise { + async data(onServerRequest = () => {}): Promise<{ + renderWidth: number; + renderHeight: number; + imageData: ImageBitmap | Blob; + }> { const result = await PluginRegistry.apiWrapper.call(this, FrameData.prototype.data, onServerRequest); return result; } @@ -372,7 +376,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { renderWidth: number; renderHeight: number; imageData: ImageBitmap | Blob; - } | Blob>((resolve, reject) => { + }>((resolve, reject) => { const requestId = +_.uniqueId(); const requestedDataFrameNumber = meta.getDataFrameNumber(this.number - jobStartFrame); const chunkIndex = meta.getFrameChunkIndex(requestedDataFrameNumber); diff --git a/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx b/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx index 509dd42b9b35..05723e21e4e1 100644 --- a/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx +++ b/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2023-2024 CVAT.ai Corporation +// Copyright (C) 2023-2025 CVAT.ai Corporation // // SPDX-License-Identifier: MIT @@ -37,7 +37,6 @@ const core = getCore(); interface State { actions: BaseAction[]; activeAction: BaseAction | null; - initialized: boolean; fetching: boolean; progress: number | null; progressMessage: string | null; @@ -50,7 +49,6 @@ interface State { } enum ReducerActionType { - SET_INITIALIZED = 'SET_INITIALIZED', SET_ANNOTATIONS_ACTIONS = 'SET_ANNOTATIONS_ACTIONS', SET_ACTIVE_ANNOTATIONS_ACTION = 'SET_ACTIVE_ANNOTATIONS_ACTION', UPDATE_PROGRESS = 'UPDATE_PROGRESS', @@ -65,9 +63,6 @@ enum ReducerActionType { } export const reducerActions = { - setInitialized: (initialized: boolean) => ( - createAction(ReducerActionType.SET_INITIALIZED, { initialized }) - ), setAnnotationsActions: (actions: BaseAction[]) => ( createAction(ReducerActionType.SET_ANNOTATIONS_ACTIONS, { actions }) ), @@ -105,7 +100,6 @@ export const reducerActions = { const defaultState = { actions: [], - initialized: false, fetching: false, activeAction: null, progress: null, @@ -119,23 +113,13 @@ const defaultState = { }; const reducer = (state: State = { ...defaultState }, action: ActionUnion): State => { - if (action.type === ReducerActionType.SET_INITIALIZED) { - return { - ...state, - initialized: action.payload.initialized, - }; - } - if (action.type === ReducerActionType.SET_ANNOTATIONS_ACTIONS) { const { actions } = action.payload; - const { targetObjectState } = state; - const filteredActions = targetObjectState ? actions - .filter((_action) => _action.isApplicableForObject(targetObjectState)) : actions; return { ...state, actions, - activeAction: filteredActions[0] ?? null, + activeAction: state.activeAction ?? actions[0] ?? null, }; } @@ -246,7 +230,6 @@ type ActionParameterProps = NonNullable[keyof BaseActi const componentStorage = createStore(reducer, { actions: [], - initialized: false, fetching: false, activeAction: null, progress: null, @@ -319,15 +302,16 @@ function ActionParameterComponent(props: ActionParameterProps & { onChange: (val interface Props { onClose: () => void; targetObjectState?: ObjectState; + defaultAnnotationAction?: string; } function AnnotationsActionsModalContent(props: Props): JSX.Element { - const { onClose, targetObjectState: defaultTargetObjectState } = props; + const { onClose, targetObjectState: defaultTargetObjectState, defaultAnnotationAction } = props; const dispatch = useDispatch(); const storage = getCVATStore(); const cancellationRef = useRef(false); const { - initialized, actions, activeAction, fetching, targetObjectState, cancelled, + actions, activeAction, fetching, targetObjectState, cancelled, progress, progressMessage, frameFrom, frameTo, actionParameters, modalVisible, } = useSelector((state: State) => ({ ...state }), shallowEqual); @@ -337,20 +321,25 @@ function AnnotationsActionsModalContent(props: Props): JSX.Element { const currentFrameAction = activeAction instanceof BaseCollectionAction || targetObjectState !== null; useEffect(() => { - dispatch(reducerActions.setVisible(true)); - dispatch(reducerActions.updateFrameFrom(jobInstance.startFrame)); - dispatch(reducerActions.updateFrameTo(jobInstance.stopFrame)); - dispatch(reducerActions.updateTargetObjectState(defaultTargetObjectState ?? null)); - }, []); + core.actions.list().then((list: BaseAction[]) => { + dispatch(reducerActions.setAnnotationsActions(list)); + + if (defaultAnnotationAction) { + const defaultAction = list.find((action) => action.name === defaultAnnotationAction); + if ( + defaultAction && + (!defaultTargetObjectState || defaultAction.isApplicableForObject(defaultTargetObjectState)) + ) { + dispatch(reducerActions.setActiveAnnotationsAction(defaultAction)); + } + } - useEffect(() => { - if (!initialized) { - core.actions.list().then((list: BaseAction[]) => { - dispatch(reducerActions.setAnnotationsActions(list)); - dispatch(reducerActions.setInitialized(true)); - }); - } - }, [initialized]); + dispatch(reducerActions.setVisible(true)); + dispatch(reducerActions.updateFrameFrom(jobInstance.startFrame)); + dispatch(reducerActions.updateFrameTo(jobInstance.stopFrame)); + dispatch(reducerActions.updateTargetObjectState(defaultTargetObjectState ?? null)); + }); + }, []); return ( { root.unmount(); div.remove(); diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/opencv-control.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/opencv-control.tsx index 6a8f3b5fc968..98f7cc5e3ede 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/opencv-control.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/opencv-control.tsx @@ -1,5 +1,5 @@ // Copyright (C) 2021-2022 Intel Corporation -// Copyright (C) 2023-2024 CVAT.ai Corporation +// Copyright (C) 2023-2025 CVAT.ai Corporation // // SPDX-License-Identifier: MIT @@ -14,12 +14,13 @@ import Button from 'antd/lib/button'; import Progress from 'antd/lib/progress'; import Select from 'antd/lib/select'; import notification from 'antd/lib/notification'; -import message from 'antd/lib/message'; +import Alert from 'antd/lib/alert'; + import { throttle } from 'lodash'; import { OpenCVIcon } from 'icons'; import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper'; -import { getCore, ObjectState } from 'cvat-core-wrapper'; +import { getCore, Job, ObjectState } from 'cvat-core-wrapper'; import openCVWrapper from 'utils/opencv-wrapper/opencv-wrapper'; import { IntelligentScissors } from 'utils/opencv-wrapper/intelligent-scissors'; import { @@ -27,27 +28,25 @@ import { } from 'reducers'; import { interactWithCanvas, - fetchAnnotationsAsync, createAnnotationsAsync, - switchNavigationBlocked as switchNavigationBlockedAction, } from 'actions/annotation-actions'; import LabelSelector from 'components/label-selector/label-selector'; import CVATTooltip from 'components/common/cvat-tooltip'; import ApproximationAccuracy, { thresholdFromAccuracy, } from 'components/annotation-page/standard-workspace/controls-side-bar/approximation-accuracy'; -import { OpenCVTracker, TrackerModel } from 'utils/opencv-wrapper/opencv-interfaces'; -import { enableImageFilter as enableImageFilterAction, disableImageFilter as disableImageFilterAction, switchToolsBlockerState } from 'actions/settings-actions'; +import { OpenCVTracker } from 'utils/opencv-wrapper/opencv-interfaces'; +import { enableImageFilter as enableImageFilterAction, disableImageFilter as disableImageFilterAction } from 'actions/settings-actions'; import { ImageFilter, ImageFilterAlias, hasFilter } from 'utils/image-processing'; +import { openAnnotationsActionModal } from 'components/annotation-page/annotations-actions/annotations-actions-modal'; import withVisibilityHandling from './handle-popover-visibility'; interface Props { labels: any[]; canvasInstance: Canvas; canvasReady: boolean; - jobInstance: any; + jobInstance: Job; isActivated: boolean; - states: any[]; frame: number; curZOrder: number; defaultApproxPolyAccuracy: number; @@ -59,31 +58,19 @@ interface Props { interface DispatchToProps { createAnnotations: (states: ObjectState[]) => Promise; - fetchAnnotations: () => Promise; onInteractionStart: typeof interactWithCanvas; - onSwitchToolsBlockerState: typeof switchToolsBlockerState; - switchNavigationBlocked: typeof switchNavigationBlockedAction; enableImageFilter: typeof enableImageFilterAction; disableImageFilter: typeof disableImageFilterAction; } -interface TrackedShape { - clientID: number; - shapePoints: number[]; - trackerModel: TrackerModel; -} - interface State { libraryInitialized: boolean; initializationError: boolean; initializationProgress: number; activeLabelID: number; approxPolyAccuracy: number; - mode: 'interaction' | 'tracking'; - trackedShapes: TrackedShape[]; activeTracker: OpenCVTracker | null; trackers: OpenCVTracker[]; - lastTrackedFrame: number | null; } const core = getCore(); @@ -93,7 +80,6 @@ function mapStateToProps(state: CombinedState): Props { const { annotation: { annotations: { - states, zLayer: { cur: curZOrder }, }, job: { instance: jobInstance, labels }, @@ -114,10 +100,9 @@ function mapStateToProps(state: CombinedState): Props { canvasInstance: canvasInstance as Canvas, canvasReady, defaultApproxPolyAccuracy, - jobInstance, + jobInstance: jobInstance as Job, curZOrder, labels, - states, frame, frameData, toolsBlockerState, @@ -127,10 +112,7 @@ function mapStateToProps(state: CombinedState): Props { const mapDispatchToProps = { onInteractionStart: interactWithCanvas, - fetchAnnotations: fetchAnnotationsAsync, createAnnotations: createAnnotationsAsync, - onSwitchToolsBlockerState: switchToolsBlockerState, - switchNavigationBlocked: switchNavigationBlockedAction, enableImageFilter: enableImageFilterAction, disableImageFilter: disableImageFilterAction, }; @@ -151,11 +133,8 @@ class OpenCVControlComponent extends React.PureComponent shape.trackerModel.delete()); } private interactionListener = async (e: Event): Promise => { - const { mode } = this.state; - - if (mode === 'interaction') { - await this.onInteraction(e); - } - - if (mode === 'tracking') { - await this.onTracking(e); - } + await this.onInteraction(e); }; private onInteraction = async (e: Event): Promise => { @@ -314,106 +281,6 @@ class OpenCVControlComponent extends React.PureComponent => { - const { - isActivated, jobInstance, frame, curZOrder, fetchAnnotations, - } = this.props; - - if (!isActivated) { - return; - } - - const { activeLabelID, trackedShapes, activeTracker } = this.state; - const [label] = jobInstance.labels.filter((_label: any): boolean => _label.id === activeLabelID); - - const { isDone, shapesUpdated } = (e as CustomEvent).detail; - if (!isDone || !shapesUpdated || !activeTracker) { - return; - } - - try { - const { points } = (e as CustomEvent).detail.shapes[0]; - const imageData = this.getCanvasImageData(); - const trackerModel = activeTracker.model(); - trackerModel.init(imageData, points); - const state = new core.classes.ObjectState({ - shapeType: ShapeType.RECTANGLE, - objectType: ObjectType.TRACK, - source: core.enums.Source.SEMI_AUTO, - zOrder: curZOrder, - label, - points, - frame, - occluded: false, - attributes: {}, - descriptions: [`Trackable (${activeTracker.name})`], - }); - const [clientID] = await jobInstance.annotations.put([state]); - this.setState({ - trackedShapes: [ - ...trackedShapes, - { - clientID, - trackerModel, - shapePoints: points, - }, - ], - }); - - // update annotations on a canvas - fetchAnnotations(); - } catch (error: any) { - notification.error({ - description: error.toString(), - message: 'Tracking error occurred', - }); - } - }; - - private getCanvasImageData = ():ImageData => { - const canvas: HTMLCanvasElement | null = window.document.getElementById('cvat_canvas_background') as - | HTMLCanvasElement - | null; - if (!canvas) { - throw new Error('Element #cvat_canvas_background was not found'); - } - const { width, height } = canvas; - const context = canvas.getContext('2d'); - if (!context) { - throw new Error('Canvas context is empty'); - } - return context.getImageData(0, 0, width, height); - }; - - private applyTracking = (imageData: ImageData, shape: TrackedShape, - objectState: any): Promise => new Promise((resolve, reject) => { - setTimeout(() => { - try { - const stateIsRelevant = - objectState.points.length === shape.shapePoints.length && - objectState.points.every( - (coord: number, index: number) => coord === shape.shapePoints[index], - ); - if (!stateIsRelevant) { - shape.trackerModel.reinit(objectState.points); - shape.shapePoints = objectState.points; - } - const { updated, points } = shape.trackerModel.update(imageData); - if (updated) { - objectState.points = points; - objectState.save().then(() => { - shape.shapePoints = points; - }).catch((error: any) => { - reject(error); - }); - } - resolve(); - } catch (error) { - reject(error); - } - }); - }); - private setActiveTracker = (value: string): void => { const { trackers } = this.state; this.setState({ @@ -421,86 +288,6 @@ class OpenCVControlComponent extends React.PureComponent( - (acc: AccumulatorType, trackedShape: TrackedShape): AccumulatorType => { - const [clientState] = objectStates.filter( - (_state: any): boolean => _state.clientID === trackedShape.clientID, - ); - if ( - !clientState || - clientState.keyframes.prev !== frame - 1 || - clientState.keyframes.last >= frame - ) { - return acc; - } - - const { name: trackerName } = trackedShape.trackerModel; - if (!acc[trackerName]) { - acc[trackerName] = []; - } - acc[trackerName].push(trackedShape); - return acc; - }, {}, - ); - - if (Object.keys(trackingData).length === 0) { - return; - } - - try { - switchNavigationBlocked(true); - for (const trackerID of Object.keys(trackingData)) { - const numOfObjects = trackingData[trackerID].length; - const hideMessage = message.loading({ - content: `${trackerID}: ${numOfObjects} ${ - numOfObjects > 1 ? 'objects are' : 'object is' - } being tracked..`, - duration: 0, - className: 'cvat-tracking-notice', - }); - const imageData = this.getCanvasImageData(); - for (const shape of trackingData[trackerID]) { - const [objectState] = objectStates.filter( - (_state: any): boolean => _state.clientID === shape.clientID, - ); - - this.applyTracking(imageData, shape, objectState) - .catch((error) => { - notification.error({ - message: 'Tracking error', - description: error.toString(), - }); - }); - } - setTimeout(() => { - if (hideMessage) hideMessage(); - }); - } - } finally { - setTimeout(() => { - fetchAnnotations(); - switchNavigationBlocked(false); - }); - } - } - } - private async runCVAlgorithm(pressedPoints: number[]): Promise { if (!this.activeTool || pressedPoints.length === 0) { return []; @@ -558,7 +345,6 @@ class OpenCVControlComponent extends React.PureComponent { - this.setState({ mode: 'interaction' }); this.activeTool = openCVWrapper.segmentation.intelligentScissorsFactory(); canvasInstance.cancel(); @@ -612,10 +398,8 @@ class OpenCVControlComponent extends React.PureComponent @@ -630,18 +414,8 @@ class OpenCVControlComponent extends React.PureComponent - - Label - - - - - this.setState({ activeLabelID: value.id })} - /> + + @@ -670,21 +444,12 @@ class OpenCVControlComponent extends React.PureComponent