diff --git a/label_studio_ml/examples/segment_anything_2_video/Dockerfile b/label_studio_ml/examples/segment_anything_2_video/Dockerfile new file mode 100644 index 000000000..25427f49c --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/Dockerfile @@ -0,0 +1,59 @@ +FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime +ARG DEBIAN_FRONTEND=noninteractive +ARG TEST_ENV + +WORKDIR /app + +RUN conda update conda -y + +RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \ + --mount=type=cache,target="/var/lib/apt/lists",sharing=locked \ + apt-get -y update \ + && apt-get install -y git \ + && apt-get install -y wget \ + && apt-get install -y g++ freeglut3-dev build-essential libx11-dev \ + libxmu-dev libxi-dev libglu1-mesa libglu1-mesa-dev libfreeimage-dev \ + && apt-get -y install ffmpeg libsm6 libxext6 libffi-dev python3-dev python3-pip gcc + +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_CACHE_DIR=/.cache \ + PORT=9090 \ + WORKERS=2 \ + THREADS=4 \ + CUDA_HOME=/usr/local/cuda \ + SEGMENT_ANYTHING_2_REPO_PATH=/segment-anything-2 + +RUN conda install -c "nvidia/label/cuda-12.1.1" cuda -y +ENV CUDA_HOME=/opt/conda \ + TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6+PTX;8.9;9.0" + +# install base requirements +COPY requirements-base.txt . +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -r requirements-base.txt + +COPY requirements.txt . +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip3 install -r requirements.txt + +# install segment-anything-2 +RUN cd / && git clone --depth 1 --branch main --single-branch https://github.com/facebookresearch/segment-anything-2.git +WORKDIR /segment-anything-2 +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip3 install -e . +RUN cd checkpoints && ./download_ckpts.sh + +WORKDIR /app + +# install test requirements if needed +COPY requirements-test.txt . +# build only when TEST_ENV="true" +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + if [ "$TEST_ENV" = "true" ]; then \ + pip3 install -r requirements-test.txt; \ + fi + +COPY . ./ + +CMD ["/app/start.sh"] diff --git a/label_studio_ml/examples/segment_anything_2_video/README.md b/label_studio_ml/examples/segment_anything_2_video/README.md new file mode 100644 index 000000000..6c01c273f --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/README.md @@ -0,0 +1,63 @@ +This guide describes the simplest way to start using **SegmentAnything 2** with Label Studio. + +This repository is specifically for working with object tracking in videos. For working with images, +see the [segment_anything_2_image repository](https://github.com/HumanSignal/label-studio-ml-backend/tree/master/label_studio_ml/examples/segment_anything_2_image) + +![sam2](./Sam2Video.gif) + +## Running from source + +1. To run the ML backend without Docker, you have to clone the repository and install all dependencies using pip: + +```bash +git clone https://github.com/HumanSignal/label-studio-ml-backend.git +cd label-studio-ml-backend +pip install -e . +cd label_studio_ml/examples/segment_anything_2_video +pip install -r requirements.txt +``` + +2. Download [`segment-anything-2` repo](https://github.com/facebookresearch/segment-anything-2) into the root directory. Install SegmentAnything model and download checkpoints using [the official Meta documentation](https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#installation). 
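+For reference, these commands mirror what the Dockerfile in this directory does (by default `model.py` looks for the repo in a `segment-anything-2` folder next to it, or wherever `SEGMENT_ANYTHING_2_REPO_PATH` points):
+
+```bash
+# run from label_studio_ml/examples/segment_anything_2_video
+git clone --depth 1 --branch main --single-branch https://github.com/facebookresearch/segment-anything-2.git
+cd segment-anything-2
+pip install -e .
+cd checkpoints && ./download_ckpts.sh
+cd ../..
+```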
Make sure that you complete the steps for downloading the checkpoint files!
+
+3. Export the following environment variables (fill them in with your credentials!):
+- LABEL_STUDIO_URL: the http:// or https:// link to your Label Studio instance (include the prefix!)
+- LABEL_STUDIO_API_KEY: your Label Studio API key, available in your user profile.
+
+4. Then you can start the ML backend on the default port `9090`:
+
+```bash
+cd ../
+label-studio-ml start ./segment_anything_2_video
+```
+Note that if you're running on a cloud server, you'll need to run on an exposed port. To change the port, add `-p <port>` to the end of the start command above.
+
+5. Connect the running ML backend server to Label Studio: go to your project `Settings -> Machine Learning -> Add Model` and specify `http://localhost:9090` as the URL. Read more in the official [Label Studio documentation](https://labelstud.io/guide/ml#Connect-the-model-to-Label-Studio).
+   Again, if you're running in the cloud, replace `localhost` with the external IP address of your container, along with the exposed port.
+
+# Labeling Config
+For your project, you can use any labeling config with video properties. Here's a basic one to get you started! The `Video` tag should be named `video` and the `VideoRectangle` tag `box`, since `model.py` returns predictions for those names.
+
+```xml
+<View>
+  <Labels name="videoLabels" toName="video">
+    <Label value="Player" background="#11A39E"/>
+    <Label value="Ball" background="#D4380D"/>
+  </Labels>
+  <Video name="video" value="$video"/>
+  <VideoRectangle name="box" toName="video"/>
+</View>
+```
+
+# Known limitations
+- As of 8/11/2024, SAM2 only runs on GPU servers.
+- Currently, we only support tracking one object per video, although SAM2 can support multiple.
+- Currently, we do not support video segmentation.
+- No Docker support
+
+If you want to contribute to this repository to help with some of these limitations, you can submit a PR.
+
+# Customization
+
+The ML backend can be customized by adding your own models and logic inside the `./segment_anything_2_video` directory.
\ No newline at end of file
diff --git a/label_studio_ml/examples/segment_anything_2_video/Sam2Video.gif b/label_studio_ml/examples/segment_anything_2_video/Sam2Video.gif
new file mode 100644
index 000000000..46ab19b6e
Binary files /dev/null and b/label_studio_ml/examples/segment_anything_2_video/Sam2Video.gif differ
diff --git a/label_studio_ml/examples/segment_anything_2_video/_wsgi.py b/label_studio_ml/examples/segment_anything_2_video/_wsgi.py
new file mode 100644
index 000000000..957b0dfe9
--- /dev/null
+++ b/label_studio_ml/examples/segment_anything_2_video/_wsgi.py
@@ -0,0 +1,121 @@
+import os
+import argparse
+import json
+import logging
+import logging.config
+
+logging.config.dictConfig({
+    "version": 1,
+    "formatters": {
+        "standard": {
+            "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
+        }
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "level": os.getenv('LOG_LEVEL'),
+            "stream": "ext://sys.stdout",
+            "formatter": "standard"
+        }
+    },
+    "root": {
+        "level": os.getenv('LOG_LEVEL'),
+        "handlers": [
+            "console"
+        ],
+        "propagate": True
+    }
+})
+
+from label_studio_ml.api import init_app
+from model import NewModel
+
+
+_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
+
+
+def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
+    if not os.path.exists(config_path):
+        return dict()
+    with open(config_path) as f:
+        config = json.load(f)
+    assert isinstance(config, dict)
+    return config
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Label studio')
+    parser.add_argument(
+        '-p', '--port', dest='port', type=int, default=9090,
+        help='Server port')
+    parser.add_argument(
+        '--host', dest='host', type=str,
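+        # '0.0.0.0' (the default below) binds to all interfaces, so the backend is
+        # reachable from other hosts/containers and not only via localhost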
default='0.0.0.0', + help='Server host') + parser.add_argument( + '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='), + help='Additional LabelStudioMLBase model initialization kwargs') + parser.add_argument( + '-d', '--debug', dest='debug', action='store_true', + help='Switch debug mode') + parser.add_argument( + '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None, + help='Logging level') + parser.add_argument( + '--model-dir', dest='model_dir', default=os.path.dirname(__file__), + help='Directory where models are stored (relative to the project directory)') + parser.add_argument( + '--check', dest='check', action='store_true', + help='Validate model instance before launching server') + parser.add_argument('--basic-auth-user', + default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None), + help='Basic auth user') + + parser.add_argument('--basic-auth-pass', + default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None), + help='Basic auth pass') + + args = parser.parse_args() + + # setup logging level + if args.log_level: + logging.root.setLevel(args.log_level) + + def isfloat(value): + try: + float(value) + return True + except ValueError: + return False + + def parse_kwargs(): + param = dict() + for k, v in args.kwargs: + if v.isdigit(): + param[k] = int(v) + elif v == 'True' or v == 'true': + param[k] = True + elif v == 'False' or v == 'false': + param[k] = False + elif isfloat(v): + param[k] = float(v) + else: + param[k] = v + return param + + kwargs = get_kwargs_from_config() + + if args.kwargs: + kwargs.update(parse_kwargs()) + + if args.check: + print('Check "' + NewModel.__name__ + '" instance creation..') + model = NewModel(**kwargs) + + app = init_app(model_class=NewModel, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass) + + app.run(host=args.host, port=args.port, debug=args.debug) + +else: + # for uWSGI use + app = init_app(model_class=NewModel) diff --git a/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml b/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml new file mode 100644 index 000000000..f73413a27 --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml @@ -0,0 +1,41 @@ +version: "3.8" + +services: + segment_anything_2_video: + container_name: segment_anything_2_video + image: humansignal/segment_anything_2_video:v0 + build: + context: . + args: + TEST_ENV: ${TEST_ENV} + environment: + # specify these parameters if you want to use basic auth for the model server + - BASIC_AUTH_USER= + - BASIC_AUTH_PASS= + # set the log level for the model server + - LOG_LEVEL=DEBUG + # any other parameters that you want to pass to the model server + - ANY=PARAMETER + # specify the number of workers and threads for the model server + - WORKERS=1 + - THREADS=8 + # specify the model directory (likely you don't need to change this) + - MODEL_DIR=/data/models + # specify device + - DEVICE=cuda # or 'cpu' (coming soon) + # SAM2 model config + - MODEL_CONFIG=sam2_hiera_l.yaml + # SAM2 checkpoint + - MODEL_CHECKPOINT=sam2_hiera_large.pt + + # Specify the Label Studio URL and API key to access + # uploaded, local storage and cloud storage files. + # Do not use 'localhost' as it does not work within Docker containers. + # Use prefix 'http://' or 'https://' for the URL always. + # Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows). 
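+      # Example (hypothetical values):
+      #   LABEL_STUDIO_URL=http://192.168.1.42:8080
+      #   LABEL_STUDIO_API_KEY=your-api-token-from-your-user-profile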
+ - LABEL_STUDIO_URL= + - LABEL_STUDIO_API_KEY= + ports: + - "9090:9090" + volumes: + - "./data/server:/data" diff --git a/label_studio_ml/examples/segment_anything_2_video/model.py b/label_studio_ml/examples/segment_anything_2_video/model.py new file mode 100644 index 000000000..2af5b4046 --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/model.py @@ -0,0 +1,336 @@ +import torch +import numpy as np +import os +import pathlib +import cv2 +import tempfile +import logging + +from typing import List, Dict, Optional +from uuid import uuid4 +from label_studio_ml.model import LabelStudioMLBase +from label_studio_ml.response import ModelResponse +from label_studio_sdk._extensions.label_studio_tools.core.utils.io import get_local_path +from label_studio_sdk.label_interface.objects import PredictionValue +from PIL import Image +from sam2.build_sam import build_sam2, build_sam2_video_predictor + +logger = logging.getLogger(__name__) + + +DEVICE = os.getenv('DEVICE', 'cuda') +SEGMENT_ANYTHING_2_REPO_PATH = os.getenv('SEGMENT_ANYTHING_2_REPO_PATH', 'segment-anything-2') +MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'sam2_hiera_l.yaml') +MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2_hiera_large.pt') +MAX_FRAMES_TO_TRACK = int(os.getenv('MAX_FRAMES_TO_TRACK', 10)) + +if DEVICE == 'cuda': + # use bfloat16 for the entire notebook + torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + + if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +# build path to the model checkpoint +sam2_checkpoint = str(pathlib.Path(__file__).parent / SEGMENT_ANYTHING_2_REPO_PATH / "checkpoints" / MODEL_CHECKPOINT) +predictor = build_sam2_video_predictor(MODEL_CONFIG, sam2_checkpoint) + + +# manage cache for inference state +# TODO: make it process-safe and implement cache invalidation +_predictor_state_key = '' +_inference_state = None + +def get_inference_state(video_dir): + global _predictor_state_key, _inference_state + if _predictor_state_key != video_dir: + _predictor_state_key = video_dir + _inference_state = predictor.init_state(video_path=video_dir) + return _inference_state + + +class NewModel(LabelStudioMLBase): + """Custom ML Backend model + """ + + def split_frames(self, video_path, temp_dir, start_frame=0, end_frame=100): + # Open the video file + logger.debug(f'Opening video file: {video_path}') + video = cv2.VideoCapture(video_path) + + # check if loaded correctly + if not video.isOpened(): + raise ValueError(f"Could not open video file: {video_path}") + else: + # display number of frames + logger.debug(f'Number of frames: {int(video.get(cv2.CAP_PROP_FRAME_COUNT))}') + + frame_count = 0 + while True: + # Read a frame from the video + success, frame = video.read() + if frame_count < start_frame: + continue + if frame_count + start_frame >= end_frame: + break + + # If frame is read correctly, success is True + if not success: + logger.error(f'Failed to read frame {frame_count}') + break + + # Generate a filename for the frame using the pattern with frame number: '%05d.jpg' + frame_filename = os.path.join(temp_dir, f'{frame_count:05d}.jpg') + if os.path.exists(frame_filename): + logger.debug(f'Frame {frame_count}: {frame_filename} already exists') + yield frame_filename, frame + else: + # Save the frame as an image file + cv2.imwrite(frame_filename, frame) + 
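+                # SAM2's video predictor reads its input as a directory of JPEG frames,
+                # so every frame is dumped as a zero-padded '<frame>.jpg' before
+                # predictor.init_state() is pointed at this directory
+                # (see get_inference_state above).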
logger.debug(f'Frame {frame_count}: {frame_filename}') + yield frame_filename, frame + + frame_count += 1 + + # Release the video object + video.release() + + def get_prompts(self, context) -> List[Dict]: + logger.debug(f'Extracting keypoints from context: {context}') + prompts = [] + for ctx in context['result']: + # Process each video tracking object separately + obj_id = ctx['id'] + for obj in ctx['value']['sequence']: + x = obj['x'] / 100 + y = obj['y'] / 100 + box_width = obj['width'] / 100 + box_height = obj['height'] / 100 + frame_idx = obj['frame'] - 1 + + # SAM2 video works with keypoints - convert the rectangle to the set of keypoints within the rectangle + + # bbox (x, y) is top-left corner + kps = [ + # center of the bbox + [x + box_width / 2, y + box_height / 2], + # half of the bbox width to the left + [x + box_width / 4, y + box_height / 2], + # half of the bbox width to the right + [x + 3 * box_width / 4, y + box_height / 2], + # half of the bbox height to the top + [x + box_width / 2, y + box_height / 4], + # half of the bbox height to the bottom + [x + box_width / 2, y + 3 * box_height / 4] + ] + + points = np.array(kps, dtype=np.float32) + labels = np.array([1] * len(kps), dtype=np.int32) + prompts.append({ + 'points': points, + 'labels': labels, + 'frame_idx': frame_idx, + 'obj_id': obj_id + }) + + return prompts + + def _get_fps(self, context): + # get the fps from the context + frames_count = context['result'][0]['value']['framesCount'] + duration = context['result'][0]['value']['duration'] + return frames_count, duration + + # def convert_mask_to_bbox(self, mask): + # # convert mask to bbox + # h, w = mask.shape[-2:] + # mask_int = mask.reshape(h, w, 1).astype(np.uint8) + # contours, _ = cv2.findContours(mask_int, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + # if len(contours) == 0: + # return None + # x, y, w, h = cv2.boundingRect(contours[0]) + # return { + # 'x': x, + # 'y': y, + # 'width': w, + # 'height': h + # } + + def convert_mask_to_bbox(self, mask): + # squeeze + mask = mask.squeeze() + + y_indices, x_indices = np.where(mask == 1) + if len(x_indices) == 0 or len(y_indices) == 0: + return None + + # Find the min and max indices + xmin, xmax = np.min(x_indices), np.max(x_indices) + ymin, ymax = np.min(y_indices), np.max(y_indices) + + # Get mask dimensions + height, width = mask.shape + + # Calculate bounding box dimensions + box_width = xmax - xmin + 1 + box_height = ymax - ymin + 1 + + # Normalize and scale to percentage + x_pct = (xmin / width) * 100 + y_pct = (ymin / height) * 100 + width_pct = (box_width / width) * 100 + height_pct = (box_height / height) * 100 + + return { + "x": round(x_pct, 2), + "y": round(y_pct, 2), + "width": round(width_pct, 2), + "height": round(height_pct, 2) + } + + + def dump_image_with_mask(self, frame, mask, output_file, obj_id=None, random_color=False): + from matplotlib import pyplot as plt + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + cmap = plt.get_cmap("tab10") + cmap_idx = 0 if obj_id is None else obj_id + color = np.array([*cmap(cmap_idx)[:3], 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + + # create an image file to display image overlayed with mask + mask_image = (mask_image * 255).astype(np.uint8) + mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGRA2BGR) + mask_image = cv2.addWeighted(frame, 1.0, mask_image, 0.8, 0) + logger.debug(f'Shapes: frame={frame.shape}, mask={mask.shape}, 
mask_image={mask_image.shape}') + # save in file + logger.debug(f'Saving image with mask to {output_file}') + cv2.imwrite(output_file, mask_image) + + + def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse: + """ Returns the predicted mask for a smart keypoint that has been placed.""" + + from_name, to_name, value = self.get_first_tag_occurence('VideoRectangle', 'Video') + + task = tasks[0] + task_id = task['id'] + # Get the video URL from the task + video_url = task['data'][value] + + # cache the video locally + video_path = get_local_path(video_url, task_id=task_id) + logger.debug(f'Video path: {video_path}') + + # get prompts from context + prompts = self.get_prompts(context) + all_obj_ids = set(p['obj_id'] for p in prompts) + # create a map from obj_id to integer + obj_ids = {obj_id: i for i, obj_id in enumerate(all_obj_ids)} + # find the last frame index + first_frame_idx = min(p['frame_idx'] for p in prompts) if prompts else 0 + last_frame_idx = max(p['frame_idx'] for p in prompts) if prompts else 0 + frames_count, duration = self._get_fps(context) + fps = frames_count / duration + + logger.debug( + f'Number of prompts: {len(prompts)}, ' + f'first frame index: {first_frame_idx}, ' + f'last frame index: {last_frame_idx}, ' + f'obj_ids: {obj_ids}') + + frames_to_track = MAX_FRAMES_TO_TRACK + + # Split the video into frames + with tempfile.TemporaryDirectory() as temp_dir: + + # # use persisted dir for debug + # temp_dir = '/tmp/frames' + # os.makedirs(temp_dir, exist_ok=True) + + # get all frames + frames = list(self.split_frames( + video_path, temp_dir, + start_frame=first_frame_idx, + end_frame=last_frame_idx + frames_to_track + 1 + )) + height, width, _ = frames[0][1].shape + logger.debug(f'Video width={width}, height={height}') + + # get inference state + inference_state = get_inference_state(temp_dir) + predictor.reset_state(inference_state) + + for prompt in prompts: + # multiply points by the frame size + prompt['points'][:, 0] *= width + prompt['points'][:, 1] *= height + + _, out_obj_ids, out_mask_logits = predictor.add_new_points( + inference_state=inference_state, + frame_idx=prompt['frame_idx'], + obj_id=obj_ids[prompt['obj_id']], + points=prompt['points'], + labels=prompt['labels'] + ) + + sequence = [] + + debug_dir = './debug-frames' + os.makedirs(debug_dir, exist_ok=True) + + logger.info(f'Propagating in video from frame {last_frame_idx} to {last_frame_idx + frames_to_track}') + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( + inference_state=inference_state, + start_frame_idx=last_frame_idx, + max_frame_num_to_track=frames_to_track + ): + real_frame_idx = out_frame_idx + first_frame_idx + for i, out_obj_id in enumerate(out_obj_ids): + mask = (out_mask_logits[i] > 0.0).cpu().numpy() + + # to debug, save the mask as an image + # self.dump_image_with_mask(frames[out_frame_idx][1], mask, f'{debug_dir}/{out_frame_idx:05d}_{out_obj_id}.jpg', obj_id=out_obj_id, random_color=True) + + bbox = self.convert_mask_to_bbox(mask) + if bbox: + sequence.append({ + 'frame': real_frame_idx + 1, + # 'x': bbox['x'] / width * 100, + # 'y': bbox['y'] / height * 100, + # 'width': bbox['width'] / width * 100, + # 'height': bbox['height'] / height * 100, + 'x': bbox['x'], + 'y': bbox['y'], + 'width': bbox['width'], + 'height': bbox['height'], + 'enabled': True, + 'rotation': 0, + 'time': out_frame_idx / fps + }) + + context_result_sequence = context['result'][0]['value']['sequence'] + + prediction = PredictionValue( + 
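+                # Return one VideoRectangle region: its 'sequence' combines the
+                # annotator's original keyframes (from the context) with the boxes
+                # tracked above, so Label Studio shows a single continuous track.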
result=[{ + 'value': { + 'framesCount': frames_count, + 'duration': duration, + 'sequence': context_result_sequence + sequence, + }, + 'from_name': 'box', + 'to_name': 'video', + 'type': 'videorectangle', + 'origin': 'manual', + # TODO: current limitation is tracking only one object + 'id': list(all_obj_ids)[0] + }] + ) + logger.debug(f'Prediction: {prediction.model_dump()}') + + return ModelResponse(predictions=[prediction]) diff --git a/label_studio_ml/examples/segment_anything_2_video/requirements-base.txt b/label_studio_ml/examples/segment_anything_2_video/requirements-base.txt new file mode 100644 index 000000000..68ce357c7 --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/requirements-base.txt @@ -0,0 +1,2 @@ +gunicorn==22.0.0 +label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git \ No newline at end of file diff --git a/label_studio_ml/examples/segment_anything_2_video/requirements-test.txt b/label_studio_ml/examples/segment_anything_2_video/requirements-test.txt new file mode 100644 index 000000000..cffeec658 --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/requirements-test.txt @@ -0,0 +1,2 @@ +pytest +pytest-cov \ No newline at end of file diff --git a/label_studio_ml/examples/segment_anything_2_video/requirements.txt b/label_studio_ml/examples/segment_anything_2_video/requirements.txt new file mode 100644 index 000000000..473f15715 --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/requirements.txt @@ -0,0 +1,2 @@ +opencv-python +cuda-python \ No newline at end of file diff --git a/label_studio_ml/examples/segment_anything_2_video/start.sh b/label_studio_ml/examples/segment_anything_2_video/start.sh new file mode 100755 index 000000000..449c16e31 --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/start.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +# Execute the gunicorn command +exec gunicorn --bind :${PORT:-9090} --workers ${WORKERS:-1} --threads ${THREADS:-4} --timeout 0 _wsgi:app diff --git a/label_studio_ml/examples/segment_anything_2_video/test_api.py b/label_studio_ml/examples/segment_anything_2_video/test_api.py new file mode 100644 index 000000000..ca7767be1 --- /dev/null +++ b/label_studio_ml/examples/segment_anything_2_video/test_api.py @@ -0,0 +1,47 @@ +""" +This file contains tests for the API of your model. You can run these tests by installing test requirements: + + ```bash + pip install -r requirements-test.txt + ``` +Then execute `pytest` in the directory of this file. + +- Change `NewModel` to the name of the class in your model.py file. +- Change the `request` and `expected_response` variables to match the input and output of your model. +""" + +import pytest +import json +from model import NewModel + + +@pytest.fixture +def client(): + from _wsgi import init_app + app = init_app(model_class=NewModel) + app.config['TESTING'] = True + with app.test_client() as client: + yield client + + +def test_predict(client): + request = { + 'tasks': [{ + 'data': { + # Your input test data here + } + }], + # Your labeling configuration here + 'label_config': '' + } + + expected_response = { + 'results': [{ + # Your expected result here + }] + } + + response = client.post('/predict', data=json.dumps(request), content_type='application/json') + assert response.status_code == 200 + response = json.loads(response.data) + assert response == expected_response
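+
+# Note: this backend builds its prompts from the interactive 'context' (the
+# VideoRectangle keyframes drawn by the annotator), so a meaningful request should
+# also include a 'context' field next to 'tasks' -- see get_prompts() in model.py
+# for the expected shape.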