From ddb6d74f46cb6a7f096362ad6aa54b60d891fa51 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 8 May 2024 10:48:37 +0200 Subject: [PATCH] initial --- .github/ISSUE_TEMPLATE/bug_report.yml | 76 +++ .github/ISSUE_TEMPLATE/config.yml | 5 + .github/ISSUE_TEMPLATE/feature_request.yml | 33 + .github/dependabot.yml | 12 + .github/release.yml | 24 + .github/verify_pr_labels.py | 87 +++ .github/workflows/builds.yml | 72 ++ .github/workflows/doc-status.yml | 22 + .github/workflows/docker.yml | 36 + .github/workflows/docs.yml | 51 ++ .github/workflows/main.yml | 122 ++++ .github/workflows/pr-labels.yml | 29 + .github/workflows/public_docker_images.yml | 86 +++ .github/workflows/publish.yml | 113 ++++ .github/workflows/pull_requests.yml | 32 + .github/workflows/style.yml | 55 ++ .gitignore | 9 + .pre-commit-config.yaml | 23 + CODE_OF_CONDUCT.md | 128 ++++ Makefile | 22 + README.md | 6 + onnxtr/__init__.py | 2 + onnxtr/contrib/__init__.py | 0 onnxtr/contrib/artefacts.py | 131 ++++ onnxtr/contrib/base.py | 105 +++ onnxtr/file_utils.py | 33 + onnxtr/io/__init__.py | 5 + onnxtr/io/elements.py | 460 +++++++++++++ onnxtr/io/html.py | 28 + onnxtr/io/image/__init__.py | 1 + onnxtr/io/image/base.py | 56 ++ onnxtr/io/pdf.py | 42 ++ onnxtr/io/reader.py | 85 +++ onnxtr/models/__init__.py | 4 + onnxtr/models/_utils.py | 141 ++++ onnxtr/models/builder.py | 355 ++++++++++ onnxtr/models/classification/__init__.py | 2 + .../models/classification/models/__init__.py | 1 + .../models/classification/models/mobilenet.py | 125 ++++ .../classification/predictor/__init__.py | 1 + .../models/classification/predictor/base.py | 59 ++ onnxtr/models/classification/zoo.py | 93 +++ onnxtr/models/detection/__init__.py | 2 + onnxtr/models/detection/_utils/__init__.py | 1 + onnxtr/models/detection/_utils/base.py | 41 ++ onnxtr/models/detection/core.py | 101 +++ onnxtr/models/detection/models/__init__.py | 3 + .../models/differentiable_binarization.py | 158 +++++ onnxtr/models/detection/models/fast.py | 158 +++++ onnxtr/models/detection/models/linknet.py | 158 +++++ .../detection/postprocessor/__init__.py | 0 onnxtr/models/detection/postprocessor/base.py | 144 ++++ onnxtr/models/detection/predictor/__init__.py | 1 + onnxtr/models/detection/predictor/base.py | 57 ++ onnxtr/models/detection/zoo.py | 73 ++ onnxtr/models/engine.py | 30 + onnxtr/models/predictor/__init__.py | 1 + onnxtr/models/predictor/base.py | 170 +++++ onnxtr/models/predictor/predictor.py | 145 ++++ onnxtr/models/preprocessor/__init__.py | 1 + onnxtr/models/preprocessor/base.py | 114 ++++ onnxtr/models/recognition/__init__.py | 2 + onnxtr/models/recognition/core.py | 28 + onnxtr/models/recognition/models/__init__.py | 5 + onnxtr/models/recognition/models/crnn.py | 225 +++++++ onnxtr/models/recognition/models/master.py | 144 ++++ onnxtr/models/recognition/models/parseq.py | 130 ++++ onnxtr/models/recognition/models/sar.py | 133 ++++ onnxtr/models/recognition/models/vitstr.py | 162 +++++ .../models/recognition/predictor/__init__.py | 1 + onnxtr/models/recognition/predictor/_utils.py | 86 +++ onnxtr/models/recognition/predictor/base.py | 80 +++ onnxtr/models/recognition/utils.py | 89 +++ onnxtr/models/recognition/zoo.py | 69 ++ onnxtr/models/zoo.py | 114 ++++ onnxtr/transforms/__init__.py | 1 + onnxtr/transforms/base.py | 102 +++ onnxtr/utils/__init__.py | 4 + onnxtr/utils/common_types.py | 18 + onnxtr/utils/data.py | 126 ++++ onnxtr/utils/fonts.py | 41 ++ onnxtr/utils/geometry.py | 456 +++++++++++++ onnxtr/utils/multithreading.py | 50 ++ onnxtr/utils/reconstitution.py | 126 
++++ onnxtr/utils/repr.py | 64 ++ onnxtr/utils/visualization.py | 388 +++++++++++ onnxtr/utils/vocabs.py | 71 ++ pyproject.toml | 181 +++++ setup.py | 23 + tests/common/test_contrib.py | 37 ++ tests/common/test_core.py | 14 + tests/common/test_headers.py | 23 + tests/common/test_io.py | 99 +++ tests/common/test_io_elements.py | 283 ++++++++ tests/common/test_models.py | 71 ++ tests/common/test_models_builder.py | 123 ++++ tests/common/test_models_detection.py | 60 ++ .../test_models_recognition_predictor.py | 39 ++ tests/common/test_models_recognition_utils.py | 31 + tests/common/test_transforms.py | 1 + tests/common/test_utils_data.py | 46 ++ tests/common/test_utils_fonts.py | 10 + tests/common/test_utils_geometry.py | 247 +++++++ tests/common/test_utils_multithreading.py | 31 + tests/common/test_utils_reconstitution.py | 12 + tests/common/test_utils_visualization.py | 32 + tests/conftest.py | 132 ++++ tests/pytorch/test_datasets_pt.py | 623 ++++++++++++++++++ tests/pytorch/test_file_utils_pt.py | 5 + tests/pytorch/test_io_image_pt.py | 52 ++ .../pytorch/test_models_classification_pt.py | 194 ++++++ tests/pytorch/test_models_detection_pt.py | 187 ++++++ tests/pytorch/test_models_factory.py | 69 ++ tests/pytorch/test_models_preprocessor_pt.py | 46 ++ tests/pytorch/test_models_recognition_pt.py | 155 +++++ tests/pytorch/test_models_utils_pt.py | 65 ++ tests/pytorch/test_models_zoo_pt.py | 327 +++++++++ tests/pytorch/test_transforms_pt.py | 351 ++++++++++ tests/tensorflow/test_datasets_loader_tf.py | 75 +++ tests/tensorflow/test_datasets_tf.py | 605 +++++++++++++++++ tests/tensorflow/test_file_utils_tf.py | 5 + tests/tensorflow/test_io_image_tf.py | 52 ++ .../test_models_classification_tf.py | 227 +++++++ tests/tensorflow/test_models_detection_tf.py | 270 ++++++++ tests/tensorflow/test_models_factory.py | 70 ++ .../tensorflow/test_models_preprocessor_tf.py | 43 ++ .../tensorflow/test_models_recognition_tf.py | 233 +++++++ tests/tensorflow/test_models_utils_tf.py | 60 ++ tests/tensorflow/test_models_zoo_tf.py | 325 +++++++++ tests/tensorflow/test_transforms_tf.py | 492 ++++++++++++++ 130 files changed, 12871 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.yml create mode 100644 .github/ISSUE_TEMPLATE/config.yml create mode 100644 .github/ISSUE_TEMPLATE/feature_request.yml create mode 100644 .github/dependabot.yml create mode 100644 .github/release.yml create mode 100644 .github/verify_pr_labels.py create mode 100644 .github/workflows/builds.yml create mode 100644 .github/workflows/doc-status.yml create mode 100644 .github/workflows/docker.yml create mode 100644 .github/workflows/docs.yml create mode 100644 .github/workflows/main.yml create mode 100644 .github/workflows/pr-labels.yml create mode 100644 .github/workflows/public_docker_images.yml create mode 100644 .github/workflows/publish.yml create mode 100644 .github/workflows/pull_requests.yml create mode 100644 .github/workflows/style.yml create mode 100644 .pre-commit-config.yaml create mode 100644 CODE_OF_CONDUCT.md create mode 100644 Makefile create mode 100644 onnxtr/__init__.py create mode 100644 onnxtr/contrib/__init__.py create mode 100644 onnxtr/contrib/artefacts.py create mode 100644 onnxtr/contrib/base.py create mode 100644 onnxtr/file_utils.py create mode 100644 onnxtr/io/__init__.py create mode 100644 onnxtr/io/elements.py create mode 100644 onnxtr/io/html.py create mode 100644 onnxtr/io/image/__init__.py create mode 100644 onnxtr/io/image/base.py create mode 100644 onnxtr/io/pdf.py create mode 100644 
onnxtr/io/reader.py create mode 100644 onnxtr/models/__init__.py create mode 100644 onnxtr/models/_utils.py create mode 100644 onnxtr/models/builder.py create mode 100644 onnxtr/models/classification/__init__.py create mode 100644 onnxtr/models/classification/models/__init__.py create mode 100644 onnxtr/models/classification/models/mobilenet.py create mode 100644 onnxtr/models/classification/predictor/__init__.py create mode 100644 onnxtr/models/classification/predictor/base.py create mode 100644 onnxtr/models/classification/zoo.py create mode 100644 onnxtr/models/detection/__init__.py create mode 100644 onnxtr/models/detection/_utils/__init__.py create mode 100644 onnxtr/models/detection/_utils/base.py create mode 100644 onnxtr/models/detection/core.py create mode 100644 onnxtr/models/detection/models/__init__.py create mode 100644 onnxtr/models/detection/models/differentiable_binarization.py create mode 100644 onnxtr/models/detection/models/fast.py create mode 100644 onnxtr/models/detection/models/linknet.py create mode 100644 onnxtr/models/detection/postprocessor/__init__.py create mode 100644 onnxtr/models/detection/postprocessor/base.py create mode 100644 onnxtr/models/detection/predictor/__init__.py create mode 100644 onnxtr/models/detection/predictor/base.py create mode 100644 onnxtr/models/detection/zoo.py create mode 100644 onnxtr/models/engine.py create mode 100644 onnxtr/models/predictor/__init__.py create mode 100644 onnxtr/models/predictor/base.py create mode 100644 onnxtr/models/predictor/predictor.py create mode 100644 onnxtr/models/preprocessor/__init__.py create mode 100644 onnxtr/models/preprocessor/base.py create mode 100644 onnxtr/models/recognition/__init__.py create mode 100644 onnxtr/models/recognition/core.py create mode 100644 onnxtr/models/recognition/models/__init__.py create mode 100644 onnxtr/models/recognition/models/crnn.py create mode 100644 onnxtr/models/recognition/models/master.py create mode 100644 onnxtr/models/recognition/models/parseq.py create mode 100644 onnxtr/models/recognition/models/sar.py create mode 100644 onnxtr/models/recognition/models/vitstr.py create mode 100644 onnxtr/models/recognition/predictor/__init__.py create mode 100644 onnxtr/models/recognition/predictor/_utils.py create mode 100644 onnxtr/models/recognition/predictor/base.py create mode 100644 onnxtr/models/recognition/utils.py create mode 100644 onnxtr/models/recognition/zoo.py create mode 100644 onnxtr/models/zoo.py create mode 100644 onnxtr/transforms/__init__.py create mode 100644 onnxtr/transforms/base.py create mode 100644 onnxtr/utils/__init__.py create mode 100644 onnxtr/utils/common_types.py create mode 100644 onnxtr/utils/data.py create mode 100644 onnxtr/utils/fonts.py create mode 100644 onnxtr/utils/geometry.py create mode 100644 onnxtr/utils/multithreading.py create mode 100644 onnxtr/utils/reconstitution.py create mode 100644 onnxtr/utils/repr.py create mode 100644 onnxtr/utils/visualization.py create mode 100644 onnxtr/utils/vocabs.py create mode 100644 pyproject.toml create mode 100644 setup.py create mode 100644 tests/common/test_contrib.py create mode 100644 tests/common/test_core.py create mode 100644 tests/common/test_headers.py create mode 100644 tests/common/test_io.py create mode 100644 tests/common/test_io_elements.py create mode 100644 tests/common/test_models.py create mode 100644 tests/common/test_models_builder.py create mode 100644 tests/common/test_models_detection.py create mode 100644 tests/common/test_models_recognition_predictor.py create mode 
100644 tests/common/test_models_recognition_utils.py create mode 100644 tests/common/test_transforms.py create mode 100644 tests/common/test_utils_data.py create mode 100644 tests/common/test_utils_fonts.py create mode 100644 tests/common/test_utils_geometry.py create mode 100644 tests/common/test_utils_multithreading.py create mode 100644 tests/common/test_utils_reconstitution.py create mode 100644 tests/common/test_utils_visualization.py create mode 100644 tests/conftest.py create mode 100644 tests/pytorch/test_datasets_pt.py create mode 100644 tests/pytorch/test_file_utils_pt.py create mode 100644 tests/pytorch/test_io_image_pt.py create mode 100644 tests/pytorch/test_models_classification_pt.py create mode 100644 tests/pytorch/test_models_detection_pt.py create mode 100644 tests/pytorch/test_models_factory.py create mode 100644 tests/pytorch/test_models_preprocessor_pt.py create mode 100644 tests/pytorch/test_models_recognition_pt.py create mode 100644 tests/pytorch/test_models_utils_pt.py create mode 100644 tests/pytorch/test_models_zoo_pt.py create mode 100644 tests/pytorch/test_transforms_pt.py create mode 100644 tests/tensorflow/test_datasets_loader_tf.py create mode 100644 tests/tensorflow/test_datasets_tf.py create mode 100644 tests/tensorflow/test_file_utils_tf.py create mode 100644 tests/tensorflow/test_io_image_tf.py create mode 100644 tests/tensorflow/test_models_classification_tf.py create mode 100644 tests/tensorflow/test_models_detection_tf.py create mode 100644 tests/tensorflow/test_models_factory.py create mode 100644 tests/tensorflow/test_models_preprocessor_tf.py create mode 100644 tests/tensorflow/test_models_recognition_tf.py create mode 100644 tests/tensorflow/test_models_utils_tf.py create mode 100644 tests/tensorflow/test_models_zoo_tf.py create mode 100644 tests/tensorflow/test_transforms_tf.py diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000..7cd34a6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,76 @@ +name: 🐛 Bug report +description: Create a report to help us improve the library +labels: 'type: bug' + +body: +- type: markdown + attributes: + value: > + #### Before reporting a bug, please check that the issue hasn't already been addressed in [the existing and past issues](https://github.com/mindee/onnxtr/issues?q=is%3Aissue). +- type: textarea + attributes: + label: Bug description + description: | + A clear and concise description of what the bug is. + + Please explain the result you observed and the behavior you were expecting. + placeholder: | + A clear and concise description of what the bug is. + validations: + required: true + +- type: textarea + attributes: + label: Code snippet to reproduce the bug + description: | + Sample code to reproduce the problem. + + Please wrap your code snippet with ```` ```triple quotes blocks``` ```` for readability. + placeholder: | + ```python + Sample code to reproduce the problem + ``` + validations: + required: true +- type: textarea + attributes: + label: Error traceback + description: | + The error message you received running the code snippet, with the full traceback. + + Please wrap your error message with ```` ```triple quotes blocks``` ```` for readability. + placeholder: | + ``` + The error message you got, with the full traceback. + ``` + validations: + required: true +- type: textarea + attributes: + label: Environment + description: | + Please run the following command and paste the output below. 
+ ```sh + wget https://raw.githubusercontent.com/mindee/onnxtr/main/scripts/collect_env.py + # For security purposes, please check the contents of collect_env.py before running it. + python collect_env.py + ``` + validations: + required: true +- type: textarea + attributes: + label: Deep Learning backend + description: | + Please run the following snippet and paste the output below. + ```python + from onnxtr.file_utils import is_tf_available, is_torch_available + + print(f"is_tf_available: {is_tf_available()}") + print(f"is_torch_available: {is_torch_available()}") + ``` + validations: + required: true +- type: markdown + attributes: + value: > + Thanks for helping us improve the library! diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..b80828e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: true +contact_links: + - name: Usage questions + url: https://github.com/mindee/onnxtr/discussions + about: Ask questions and discuss with other OnnxTR community members diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..8f47ec7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,33 @@ +name: 🚀 Feature request +description: Submit a proposal/request for a new feature for OnnxTR +labels: 'type: enhancement' + +body: +- type: textarea + attributes: + label: 🚀 The feature + description: > + A clear and concise description of the feature proposal + validations: + required: true +- type: textarea + attributes: + label: Motivation, pitch + description: > + Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. + validations: + required: true +- type: textarea + attributes: + label: Alternatives + description: > + A description of any alternative solutions or features you've considered, if any. +- type: textarea + attributes: + label: Additional context + description: > + Add any other context or screenshots about the feature request. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉 diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..2c69a5b --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + open-pull-requests-limit: 10 + target-branch: "main" + labels: ["topic: build"] + schedule: + interval: weekly + day: sunday + reviewers: + - "felixdittrich92" diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 0000000..2efdfdf --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,24 @@ +changelog: + exclude: + labels: + - ignore-for-release + categories: + - title: Breaking Changes 🛠 + labels: + - "type: breaking change" + # NEW FEATURES + - title: New Features + labels: + - "type: new feature" + # BUG FIXES + - title: Bug Fixes + labels: + - "type: bug" + # IMPROVEMENTS + - title: Improvements + labels: + - "type: enhancement" + # MISC + - title: Miscellaneous + labels: + - "type: misc" diff --git a/.github/verify_pr_labels.py b/.github/verify_pr_labels.py new file mode 100644 index 0000000..37869a3 --- /dev/null +++ b/.github/verify_pr_labels.py @@ -0,0 +1,87 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. 
+ +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +"""Borrowed & adapted from https://github.com/pytorch/vision/blob/main/.github/process_commit.py +This script finds the merger responsible for labeling a PR by a commit SHA. It is used by the workflow in +'.github/workflows/pr-labels.yml'. If there exists no PR associated with the commit or the PR is properly labeled, +this script is a no-op. +Note: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to torchvision +with no labeling responsibility, so we don't want to bother them. +""" + +from typing import Any, Set, Tuple + +import requests + +# For a PR to be properly labeled it should have one primary label and one secondary label + +# Should specify the type of change +PRIMARY_LABELS = { + "type: new feature", + "type: bug", + "type: enhancement", + "type: misc", +} + +# Should specify what has been modified +SECONDARY_LABELS = { + "topic: documentation", + "module: datasets", + "module: io", + "module: models", + "module: transforms", + "module: utils", + "ext: api", + "ext: demo", + "ext: docs", + "ext: notebooks", + "ext: references", + "ext: scripts", + "ext: tests", + "topic: build", + "topic: ci", + "topic: docker", +} + +GH_ORG = "felixdittrich92" +GH_REPO = "onnxtr" + + +def query_repo(cmd: str, *, accept) -> Any: + response = requests.get(f"https://api.github.com/repos/{GH_ORG}/{GH_REPO}/{cmd}", headers=dict(Accept=accept)) + return response.json() + + +def get_pr_merger_and_labels(pr_number: int) -> Tuple[str, Set[str]]: + # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request + data = query_repo(f"pulls/{pr_number}", accept="application/vnd.github.v3+json") + merger = data.get("merged_by", {}).get("login") + labels = {label["name"] for label in data["labels"]} + return merger, labels + + +def main(args): + merger, labels = get_pr_merger_and_labels(args.pr) + is_properly_labeled = bool(PRIMARY_LABELS.intersection(labels) and SECONDARY_LABELS.intersection(labels)) + if isinstance(merger, str) and not is_properly_labeled: + print(f"@{merger}") + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="PR label checker", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("pr", type=int, help="PR number") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/.github/workflows/builds.yml b/.github/workflows/builds.yml new file mode 100644 index 0000000..6897e62 --- /dev/null +++ b/.github/workflows/builds.yml @@ -0,0 +1,72 @@ +name: builds + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python: ["3.9", "3.10"] + framework: [tensorflow, pytorch] + steps: + - uses: actions/checkout@v4 + - if: matrix.os == 'macos-latest' + name: Install MacOS prerequisites + run: brew install cairo pango gdk-pixbuf libffi + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v4 + with: + 
path: ~/.cache/pip
+          key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}
+      - if: matrix.framework == 'tensorflow'
+        name: Install package (TF)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .[tf,viz,html] --upgrade
+      - if: matrix.framework == 'pytorch'
+        name: Install package (PT)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .[torch,viz,html] --upgrade
+      - name: Import package
+        run: python -c "import onnxtr; print(onnxtr.__version__)"
+
+  conda:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: conda-incubator/setup-miniconda@v3
+        with:
+          auto-update-conda: true
+          python-version: 3.9
+          channels: pypdfium2-team,bblanchon,defaults,conda-forge
+          channel-priority: strict
+      - name: Install dependencies
+        shell: bash -el {0}
+        run: conda install -y conda-build conda-verify anaconda-client
+      - name: Build and verify
+        shell: bash -el {0}
+        run: |
+          python setup.py sdist
+          mkdir conda-dist
+          conda build .conda/ --output-folder conda-dist
+          conda-verify conda-dist/linux-64/*tar.bz2 --ignore=C1115
diff --git a/.github/workflows/doc-status.yml b/.github/workflows/doc-status.yml
new file mode 100644
index 0000000..294f3dc
--- /dev/null
+++ b/.github/workflows/doc-status.yml
@@ -0,0 +1,22 @@
+name: doc-status
+on:
+  page_build
+
+jobs:
+  see-page-build-payload:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.9"
+          architecture: x64
+      - name: check status
+        run: |
+          import os
+          status, errormsg = os.getenv('STATUS'), os.getenv('ERROR')
+          if status != 'built': raise AssertionError(f"There was an error building the page on GitHub pages.\n\nStatus: {status}\n\nError message: {errormsg}")
+        shell: python
+        env:
+          STATUS: ${{ github.event.build.status }}
+          ERROR: ${{ github.event.build.error.message }}
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
new file mode 100644
index 0000000..70302c9
--- /dev/null
+++ b/.github/workflows/docker.yml
@@ -0,0 +1,36 @@
+name: docker
+
+on:
+  push:
+    branches: main
+  pull_request:
+    branches: main
+
+jobs:
+  docker-package:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Build docker image
+        run: docker build -t doctr-tf-py3.9-slim --build-arg SYSTEM=cpu .
+ - name: Run docker container + run: docker run doctr-tf-py3.9-slim python3 -c 'import doctr' + + pytest-api: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.9"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Build & run docker + run: cd api && make lock && make run + - name: Ping server + run: wget --spider --tries=12 http://localhost:8080/docs + - name: Run docker test + run: cd api && make test diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..79965b5 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,51 @@ +name: docs +on: + push: + branches: main + +jobs: + docs-deploy: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.9"] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-docs + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf,viz,html] + pip install -e .[docs] + + - name: Build documentation + run: cd docs && bash build.sh + + - name: Documentation sanity check + run: test -e docs/build/index.html || exit + + - name: Install SSH Client 🔑 + uses: webfactory/ssh-agent@v0.4.1 + with: + ssh-private-key: ${{ secrets.SSH_DEPLOY_KEY }} + + - name: Deploy to Github Pages + uses: JamesIves/github-pages-deploy-action@3.7.1 + with: + BRANCH: gh-pages + FOLDER: 'docs/build' + COMMIT_MESSAGE: '[skip ci] Documentation updates' + CLEAN: true + SSH: true diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..ac34c9c --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,122 @@ +name: tests + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + pytest-common: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.9"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-tests + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf,viz,html] --upgrade + pip install -e .[testing] + - name: Run unittests + run: | + coverage run -m pytest tests/common/ -rs + coverage xml -o coverage-common.xml + - uses: actions/upload-artifact@v4 + with: + name: coverage-common + path: ./coverage-common.xml + if-no-files-found: error + + pytest-tf: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.9"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-tests + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf,viz,html] --upgrade + pip 
install -e .[testing] + - name: Run unittests + run: | + coverage run -m pytest tests/tensorflow/ -rs + coverage xml -o coverage-tf.xml + - uses: actions/upload-artifact@v4 + with: + name: coverage-tf + path: ./coverage-tf.xml + if-no-files-found: error + + pytest-torch: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.9"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-tests + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[torch,viz,html] --upgrade + pip install -e .[testing] + + - name: Run unittests + run: | + coverage run -m pytest tests/pytorch/ -rs + coverage xml -o coverage-pt.xml + + - uses: actions/upload-artifact@v4 + with: + name: coverage-pytorch + path: ./coverage-pt.xml + if-no-files-found: error + + codecov-upload: + runs-on: ubuntu-latest + needs: [ pytest-common, pytest-tf, pytest-torch ] + steps: + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + flags: unittests + fail_ci_if_error: true + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/pr-labels.yml b/.github/workflows/pr-labels.yml new file mode 100644 index 0000000..6c79591 --- /dev/null +++ b/.github/workflows/pr-labels.yml @@ -0,0 +1,29 @@ +name: pr-labels + +on: + pull_request: + branches: main + types: closed + +jobs: + is-properly-labeled: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Set up python + uses: actions/setup-python@v5 + - name: Install requests + run: pip install requests + - name: Process commit and find merger responsible for labeling + id: commit + run: echo "::set-output name=merger::$(python .github/verify_pr_labels.py ${{ github.event.pull_request.number }})" + - name: 'Comment PR' + uses: actions/github-script@0.3.0 + if: ${{ steps.commit.outputs.merger != '' }} + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const { issue: { number: issue_number }, repo: { owner, repo } } = context; + github.issues.createComment({ issue_number, owner, repo, body: 'Hey ${{ steps.commit.outputs.merger }} 👋\nYou merged this PR, but it is not correctly labeled. 
The list of valid labels is available at https://github.com/mindee/onnxtr/blob/main/.github/verify_pr_labels.py' }); diff --git a/.github/workflows/public_docker_images.yml b/.github/workflows/public_docker_images.yml new file mode 100644 index 0000000..2ccdb66 --- /dev/null +++ b/.github/workflows/public_docker_images.yml @@ -0,0 +1,86 @@ +# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages +# +name: Docker image on ghcr.io + +on: + push: + tags: + - 'v*' + pull_request: + branches: main + schedule: + - cron: '0 2 29 * *' # At 02:00 on day-of-month 29 + +env: + REGISTRY: ghcr.io + +jobs: + build-and-push-image: + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + # Must match version at https://www.python.org/ftp/python/ + python: ["3.9.18", "3.10.13", "3.11.8"] + framework: ["tf", "torch", "tf,viz,html,contrib", "torch,viz,html,contrib"] + system: ["cpu", "gpu"] + + # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Log in to the Container registry + uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata (tags, labels) for Docker + id: meta + uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 + with: + images: ${{ env.REGISTRY }}/${{ github.repository }} + tags: | + # used only on schedule event + type=schedule,pattern={{date 'YYYY-MM'}},prefix=${{ matrix.framework }}-py${{ matrix.python }}-${{ matrix.system }}- + # used only if a tag following semver is published + type=semver,pattern={{raw}},prefix=${{ matrix.framework }}-py${{ matrix.python }}-${{ matrix.system }}- + + - name: Build Docker image + id: build + uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 + with: + context: . + build-args: | + FRAMEWORK=${{ matrix.framework }} + PYTHON_VERSION=${{ matrix.python }} + SYSTEM=${{ matrix.system }} + DOCTR_REPO=${{ github.repository }} + DOCTR_VERSION=${{ github.sha }} + push: false # push only if `import doctr` works + tags: ${{ steps.meta.outputs.tags }} + + - name: Check if `import doctr` works + run: docker run ${{ steps.build.outputs.imageid }} python3 -c 'import doctr' + + - name: Push Docker image + # Push only if the CI is not triggered by "PR on main" + if: github.ref == 'refs/heads/main' && github.event_name != 'pull_request' + uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 + with: + context: . 
+ build-args: | + FRAMEWORK=${{ matrix.framework }} + PYTHON_VERSION=${{ matrix.python }} + SYSTEM=${{ matrix.system }} + DOCTR_REPO=${{ github.repository }} + DOCTR_VERSION=${{ github.sha }} + push: true + tags: ${{ steps.meta.outputs.tags }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..3fd0b7c --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,113 @@ +name: publish + +on: + release: + types: [published] + +jobs: + pypi: + if: "!github.event.release.prerelease" + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.9"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine --upgrade + - name: Get release tag + id: release_tag + run: | + echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//} + - name: Build and publish + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + VERSION: ${{ steps.release_tag.outputs.VERSION }} + run: | + BUILD_VERSION=$VERSION python setup.py sdist bdist_wheel + twine check dist/* + twine upload dist/* + + pypi-check: + needs: pypi + if: "!github.event.release.prerelease" + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.9"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Install package + run: | + python -m pip install --upgrade pip + pip install python-doctr[torch] + python -c "import doctr; print(doctr.__version__)" + + conda: + if: "!github.event.release.prerelease" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + python-version: 3.9 + channels: pypdfium2-team,bblanchon,defaults,conda-forge + channel-priority: strict + - name: Install dependencies + shell: bash -el {0} + run: conda install -y conda-build conda-verify anaconda-client + - name: Get release tag + id: release_tag + run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//} + - name: Build and publish + shell: bash -el {0} + env: + ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_TOKEN }} + VERSION: ${{ steps.release_tag.outputs.VERSION }} + run: | + echo "BUILD_VERSION=${VERSION}" >> $GITHUB_ENV + python setup.py sdist + mkdir conda-dist + conda build .conda/ --output-folder conda-dist + conda-verify conda-dist/linux-64/*tar.bz2 --ignore=C1115 + anaconda upload conda-dist/linux-64/*tar.bz2 + + conda-check: + if: "!github.event.release.prerelease" + runs-on: ubuntu-latest + needs: conda + steps: + - uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + python-version: 3.9 + - name: Install package + shell: bash -el {0} + run: | + conda config --set channel_priority strict + conda install pytorch torchvision torchaudio cpuonly -c pytorch + conda install -c techMindee -c pypdfium2-team -c bblanchon -c defaults -c conda-forge python-doctr + python -c "import doctr; print(doctr.__version__)" diff --git a/.github/workflows/pull_requests.yml 
b/.github/workflows/pull_requests.yml new file mode 100644 index 0000000..045c467 --- /dev/null +++ b/.github/workflows/pull_requests.yml @@ -0,0 +1,32 @@ +name: pull_requests + +on: + pull_request: + branches: main + +jobs: + docs-build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.9" + architecture: x64 + - name: Cache python modules + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-docs + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf,viz,html] --upgrade + pip install -e .[docs] + + - name: Build documentation + run: cd docs && bash build.sh + + - name: Documentation sanity check + run: test -e docs/build/index.html || exit diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml new file mode 100644 index 0000000..c755a06 --- /dev/null +++ b/.github/workflows/style.yml @@ -0,0 +1,55 @@ +name: style + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + ruff: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.9"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run ruff + run: | + pip install ruff --upgrade + ruff --version + ruff check --diff . + + mypy: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.9"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[dev] --upgrade + pip install mypy --upgrade + - name: Run mypy + run: | + mypy --version + mypy diff --git a/.gitignore b/.gitignore index b6e4761..afabe7e 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,12 @@ dmypy.json # Pyre type checker .pyre/ + +# Temp files +onnxtr/version.py +logs/ +wandb/ +.idea/ + +# Model files +*.onnx diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8e66bb4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-ast + - id: check-yaml + exclude: .conda + - id: check-toml + - id: check-json + - id: check-added-large-files + exclude: docs/images/ + - id: end-of-file-fixer + - id: trailing-whitespace + - id: debug-statements + - id: check-merge-conflict + - id: no-commit-to-branch + args: ['--branch', 'main'] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.2 + hooks: + - id: ruff + args: [ --fix ] + - id: ruff-format diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..ee84f1d --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, 
level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +contact@mindee.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. 
Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..3d79eb6
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,22 @@
+.PHONY: quality style test docs-single-version docs
+# this target runs checks on all files
+quality:
+	ruff check .
+	mypy onnxtr/
+
+# this target runs checks on all files and potentially modifies some of them
+style:
+	ruff format .
+	ruff check --fix .
+
+# Run tests for the library
+test:
+	coverage run -m pytest tests/common/ -rs
+
+# Check that docs can build
+docs-single-version:
+	sphinx-build docs/source docs/_build -a
+
+# Check that docs can build
+docs:
+	cd docs && bash build.sh
\ No newline at end of file
diff --git a/README.md b/README.md
index 35eb4a9..787daff 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,8 @@
 # onnxcr
 Todo
+
+
+- tests
+- actions
+- readme
+- testing
diff --git a/onnxtr/__init__.py b/onnxtr/__init__.py
new file mode 100644
index 0000000..44fb4e3
--- /dev/null
+++ b/onnxtr/__init__.py
@@ -0,0 +1,2 @@
+from . import io, models, contrib, transforms, utils
+from .version import __version__  # noqa: F401
diff --git a/onnxtr/contrib/__init__.py b/onnxtr/contrib/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/onnxtr/contrib/artefacts.py b/onnxtr/contrib/artefacts.py
new file mode 100644
index 0000000..f3e51ac
--- /dev/null
+++ b/onnxtr/contrib/artefacts.py
@@ -0,0 +1,131 @@
+# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+ +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np + +from onnxtr.file_utils import requires_package + +from .base import _BasePredictor + +__all__ = ["ArtefactDetector"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "yolov8_artefact": { + "input_shape": (3, 1024, 1024), + "labels": ["bar_code", "qr_code", "logo", "photo"], + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/yolo_artefact-f9d66f14.onnx&src=0", + }, +} + + +class ArtefactDetector(_BasePredictor): + """ + A class to detect artefacts in images + + >>> from onnxtr.io import DocumentFile + >>> from onnxtr.contrib.artefacts import ArtefactDetector + >>> doc = DocumentFile.from_images(["path/to/image.jpg"]) + >>> detector = ArtefactDetector() + >>> results = detector(doc) + + Args: + ---- + arch: the architecture to use + batch_size: the batch size to use + model_path: the path to the model to use + labels: the labels to use + input_shape: the input shape to use + mask_labels: the mask labels to use + conf_threshold: the confidence threshold to use + iou_threshold: the intersection over union threshold to use + **kwargs: additional arguments to be passed to `download_from_url` + """ + + def __init__( + self, + arch: str = "yolov8_artefact", + batch_size: int = 2, + model_path: Optional[str] = None, + labels: Optional[List[str]] = None, + input_shape: Optional[Tuple[int, int, int]] = None, + conf_threshold: float = 0.5, + iou_threshold: float = 0.5, + **kwargs: Any, + ) -> None: + super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs) + self.labels = labels or default_cfgs[arch]["labels"] + self.input_shape = input_shape or default_cfgs[arch]["input_shape"] + self.conf_threshold = conf_threshold + self.iou_threshold = iou_threshold + + def preprocess(self, img: np.ndarray) -> np.ndarray: + return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0) + + def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]: + results = [] + + for batch in zip(output, input_images): + for out, img in zip(batch[0], batch[1]): + org_height, org_width = img.shape[:2] + width_scale, height_scale = org_width / self.input_shape[2], org_height / self.input_shape[1] + for res in out: + sample_results = [] + for row in np.transpose(np.squeeze(res)): + classes_scores = row[4:] + max_score = np.amax(classes_scores) + if max_score >= self.conf_threshold: + class_id = np.argmax(classes_scores) + x, y, w, h = row[0], row[1], row[2], row[3] + # to rescaled xmin, ymin, xmax, ymax + xmin = int((x - w / 2) * width_scale) + ymin = int((y - h / 2) * height_scale) + xmax = int((x + w / 2) * width_scale) + ymax = int((y + h / 2) * height_scale) + + sample_results.append({ + "label": self.labels[class_id], + "confidence": float(max_score), + "box": [xmin, ymin, xmax, ymax], + }) + + # Filter out overlapping boxes + boxes = [res["box"] for res in sample_results] + scores = [res["confidence"] for res in sample_results] + keep_indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold) # type: ignore[arg-type] + sample_results = [sample_results[i] for i in keep_indices] + + results.append(sample_results) + + self._results = results + return results + + def show(self, **kwargs: Any) -> None: + """ + Display the results + + Args: + ---- + **kwargs: additional keyword arguments to be passed to `plt.show` + """ + requires_package("matplotlib", 
"`.show()` requires matplotlib installed") + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + + # visualize the results with matplotlib + if self._results and self._inputs: + for img, res in zip(self._inputs, self._results): + plt.figure(figsize=(10, 10)) + plt.imshow(img) + for obj in res: + xmin, ymin, xmax, ymax = obj["box"] + label = obj["label"] + plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red") + plt.gca().add_patch( + Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2) + ) + plt.show(**kwargs) diff --git a/onnxtr/contrib/base.py b/onnxtr/contrib/base.py new file mode 100644 index 0000000..08eb449 --- /dev/null +++ b/onnxtr/contrib/base.py @@ -0,0 +1,105 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List, Optional + +import numpy as np + +from onnxtr.file_utils import requires_package +from onnxtr.utils.data import download_from_url + + +class _BasePredictor: + """ + Base class for all predictors + + Args: + ---- + batch_size: the batch size to use + url: the url to use to download a model if needed + model_path: the path to the model to use + **kwargs: additional arguments to be passed to `download_from_url` + """ + + def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None: + self.batch_size = batch_size + self.session = self._init_model(url, model_path, **kwargs) + + self._inputs: List[np.ndarray] = [] + self._results: List[Any] = [] + + def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any: + """ + Download the model from the given url if needed + + Args: + ---- + url: the url to use + model_path: the path to the model to use + **kwargs: additional arguments to be passed to `download_from_url` + + Returns: + ------- + Any: the ONNX loaded model + """ + requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.") + import onnxruntime as ort + + if not url and not model_path: + raise ValueError("You must provide either a url or a model_path") + onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs)) # type: ignore[arg-type] + return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) + + def preprocess(self, img: np.ndarray) -> np.ndarray: + """ + Preprocess the input image + + Args: + ---- + img: the input image to preprocess + + Returns: + ------- + np.ndarray: the preprocessed image + """ + raise NotImplementedError + + def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any: + """ + Postprocess the model output + + Args: + ---- + output: the model output to postprocess + input_images: the input images used to generate the output + + Returns: + ------- + Any: the postprocessed output + """ + raise NotImplementedError + + def __call__(self, inputs: List[np.ndarray]) -> Any: + """ + Call the model on the given inputs + + Args: + ---- + inputs: the inputs to use + + Returns: + ------- + Any: the postprocessed output + """ + self._inputs = inputs + model_inputs = self.session.get_inputs() + + batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)] + processed_batches = [ + np.array([self.preprocess(img) for img in batch], 
dtype=np.float32) for batch in batched_inputs + ] + + outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches] + return self.postprocess(outputs, batched_inputs) diff --git a/onnxtr/file_utils.py b/onnxtr/file_utils.py new file mode 100644 index 0000000..3905a6f --- /dev/null +++ b/onnxtr/file_utils.py @@ -0,0 +1,33 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import importlib.metadata +import importlib.util +import logging +from typing import Optional + +__all__ = ["requires_package"] + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + + +def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover + """ + package requirement helper + + Args: + ---- + name: name of the package + extra_message: additional message to display if the package is not found + """ + try: + _pkg_version = importlib.metadata.version(name) + logging.info(f"{name} version {_pkg_version} available.") + except importlib.metadata.PackageNotFoundError: + raise ImportError( + f"\n\n{extra_message if extra_message is not None else ''} " + f"\nPlease install it with the following command: pip install {name}\n" + ) diff --git a/onnxtr/io/__init__.py b/onnxtr/io/__init__.py new file mode 100644 index 0000000..6eab8c2 --- /dev/null +++ b/onnxtr/io/__init__.py @@ -0,0 +1,5 @@ +from .elements import * +from .html import * +from .image import * +from .pdf import * +from .reader import * diff --git a/onnxtr/io/elements.py b/onnxtr/io/elements.py new file mode 100644 index 0000000..1e53e37 --- /dev/null +++ b/onnxtr/io/elements.py @@ -0,0 +1,460 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +from typing import Any, Dict, List, Optional, Tuple, Union + +from defusedxml import defuse_stdlib + +defuse_stdlib() +from xml.etree import ElementTree as ET +from xml.etree.ElementTree import Element as ETElement +from xml.etree.ElementTree import SubElement + +import numpy as np + +import onnxtr +from onnxtr.file_utils import requires_package +from onnxtr.utils.common_types import BoundingBox +from onnxtr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox +from onnxtr.utils.reconstitution import synthesize_page +from onnxtr.utils.repr import NestedObject + +try: # optional dependency for visualization + from onnxtr.utils.visualization import visualize_page +except ModuleNotFoundError: + pass + +__all__ = ["Element", "Word", "Artefact", "Line", "Block", "Page", "Document"] + + +class Element(NestedObject): + """Implements an abstract document element with exporting and text rendering capabilities""" + + _children_names: List[str] = [] + _exported_keys: List[str] = [] + + def __init__(self, **kwargs: Any) -> None: + for k, v in kwargs.items(): + if k in self._children_names: + setattr(self, k, v) + else: + raise KeyError(f"{self.__class__.__name__} object does not have any attribute named '{k}'") + + def export(self) -> Dict[str, Any]: + """Exports the object into a nested dict format""" + export_dict = {k: getattr(self, k) for k in self._exported_keys} + for children_name in self._children_names: + if children_name in ["predictions"]: + export_dict[children_name] = { + k: [item.export() for item in c] for k, c in getattr(self, children_name).items() + } + else: + export_dict[children_name] = [c.export() for c in getattr(self, children_name)] + + return export_dict + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + raise NotImplementedError + + def render(self) -> str: + raise NotImplementedError + + +class Word(Element): + """Implements a word element + + Args: + ---- + value: the text string of the word + confidence: the confidence associated with the text prediction + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size + crop_orientation: the general orientation of the crop in degrees and its confidence + """ + + _exported_keys: List[str] = ["value", "confidence", "geometry", "crop_orientation"] + _children_names: List[str] = [] + + def __init__( + self, + value: str, + confidence: float, + geometry: Union[BoundingBox, np.ndarray], + crop_orientation: Dict[str, Any], + ) -> None: + super().__init__() + self.value = value + self.confidence = confidence + self.geometry = geometry + self.crop_orientation = crop_orientation + + def render(self) -> str: + """Renders the full text of the element""" + return self.value + + def extra_repr(self) -> str: + return f"value='{self.value}', confidence={self.confidence:.2}" + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + return cls(**kwargs) + + +class Artefact(Element): + """Implements a non-textual element + + Args: + ---- + artefact_type: the type of artefact + confidence: the confidence of the type prediction + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. 
+ """ + + _exported_keys: List[str] = ["geometry", "type", "confidence"] + _children_names: List[str] = [] + + def __init__(self, artefact_type: str, confidence: float, geometry: BoundingBox) -> None: + super().__init__() + self.geometry = geometry + self.type = artefact_type + self.confidence = confidence + + def render(self) -> str: + """Renders the full text of the element""" + return f"[{self.type.upper()}]" + + def extra_repr(self) -> str: + return f"type='{self.type}', confidence={self.confidence:.2}" + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + return cls(**kwargs) + + +class Line(Element): + """Implements a line element as a collection of words + + Args: + ---- + words: list of word elements + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing + all words in it. + """ + + _exported_keys: List[str] = ["geometry"] + _children_names: List[str] = ["words"] + words: List[Word] = [] + + def __init__( + self, + words: List[Word], + geometry: Optional[Union[BoundingBox, np.ndarray]] = None, + ) -> None: + # Resolve the geometry using the smallest enclosing bounding box + if geometry is None: + # Check whether this is a rotated or straight box + box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox + geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator] + + super().__init__(words=words) + self.geometry = geometry + + def render(self) -> str: + """Renders the full text of the element""" + return " ".join(w.render() for w in self.words) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({ + "words": [Word.from_dict(_dict) for _dict in save_dict["words"]], + }) + return cls(**kwargs) + + +class Block(Element): + """Implements a block element as a collection of lines and artefacts + + Args: + ---- + lines: list of line elements + artefacts: list of artefacts + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing + all lines and artefacts in it. 
+ """ + + _exported_keys: List[str] = ["geometry"] + _children_names: List[str] = ["lines", "artefacts"] + lines: List[Line] = [] + artefacts: List[Artefact] = [] + + def __init__( + self, + lines: List[Line] = [], + artefacts: List[Artefact] = [], + geometry: Optional[Union[BoundingBox, np.ndarray]] = None, + ) -> None: + # Resolve the geometry using the smallest enclosing bounding box + if geometry is None: + line_boxes = [word.geometry for line in lines for word in line.words] + artefact_boxes = [artefact.geometry for artefact in artefacts] + box_resolution_fn = ( + resolve_enclosing_rbbox if isinstance(lines[0].geometry, np.ndarray) else resolve_enclosing_bbox + ) + geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore[operator] + + super().__init__(lines=lines, artefacts=artefacts) + self.geometry = geometry + + def render(self, line_break: str = "\n") -> str: + """Renders the full text of the element""" + return line_break.join(line.render() for line in self.lines) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({ + "lines": [Line.from_dict(_dict) for _dict in save_dict["lines"]], + "artefacts": [Artefact.from_dict(_dict) for _dict in save_dict["artefacts"]], + }) + return cls(**kwargs) + + +class Page(Element): + """Implements a page element as a collection of blocks + + Args: + ---- + page: image encoded as a numpy array in uint8 + blocks: list of block elements + page_idx: the index of the page in the input raw document + dimensions: the page size in pixels in format (height, width) + orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction + language: a dictionary with the language value and confidence of the prediction + """ + + _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"] + _children_names: List[str] = ["blocks"] + blocks: List[Block] = [] + + def __init__( + self, + page: np.ndarray, + blocks: List[Block], + page_idx: int, + dimensions: Tuple[int, int], + orientation: Optional[Dict[str, Any]] = None, + language: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(blocks=blocks) + self.page = page + self.page_idx = page_idx + self.dimensions = dimensions + self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None) + self.language = language if isinstance(language, dict) else dict(value=None, confidence=None) + + def render(self, block_break: str = "\n\n") -> str: + """Renders the full text of the element""" + return block_break.join(b.render() for b in self.blocks) + + def extra_repr(self) -> str: + return f"dimensions={self.dimensions}" + + def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None: + """Overlay the result on a given image + + Args: + interactive: whether the display should be interactive + preserve_aspect_ratio: pass True if you passed True to the predictor + **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method + """ + requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed") + requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed") + import matplotlib.pyplot as plt + + visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio) + plt.show(**kwargs) + + def synthesize(self, **kwargs) -> np.ndarray: + """Synthesize the page 
from the predictions + + Returns + ------- + synthesized page + """ + return synthesize_page(self.export(), **kwargs) + + def export_as_xml(self, file_title: str = "OnnxTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]: + """Export the page as XML (hOCR-format) + convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md + + Args: + ---- + file_title: the title of the XML file + + Returns: + ------- + a tuple of the XML byte string, and its ElementTree + """ + p_idx = self.page_idx + block_count: int = 1 + line_count: int = 1 + word_count: int = 1 + height, width = self.dimensions + language = self.language if "language" in self.language.keys() else "en" + # Create the XML root element + page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)}) + # Create the header / SubElements of the root element + head = SubElement(page_hocr, "head") + SubElement(head, "title").text = file_title + SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"}) + SubElement( + head, + "meta", + attrib={"name": "ocr-system", "content": f" {onnxtr.__version__}"}, # type: ignore[attr-defined] + ) + SubElement( + head, + "meta", + attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"}, + ) + # Create the body + body = SubElement(page_hocr, "body") + SubElement( + body, + "div", + attrib={ + "class": "ocr_page", + "id": f"page_{p_idx + 1}", + "title": f"image; bbox 0 0 {width} {height}; ppageno 0", + }, + ) + # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes + for block in self.blocks: + if len(block.geometry) != 2: + raise TypeError("XML export is only available for straight bounding boxes for now.") + (xmin, ymin), (xmax, ymax) = block.geometry + block_div = SubElement( + body, + "div", + attrib={ + "class": "ocr_carea", + "id": f"block_{block_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}", + }, + ) + paragraph = SubElement( + block_div, + "p", + attrib={ + "class": "ocr_par", + "id": f"par_{block_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}", + }, + ) + block_count += 1 + for line in block.lines: + (xmin, ymin), (xmax, ymax) = line.geometry + # NOTE: baseline, x_size, x_descenders, x_ascenders is currently initalized to 0 + line_span = SubElement( + paragraph, + "span", + attrib={ + "class": "ocr_line", + "id": f"line_{line_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}; \ + baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0", + }, + ) + line_count += 1 + for word in line.words: + (xmin, ymin), (xmax, ymax) = word.geometry + conf = word.confidence + word_div = SubElement( + line_span, + "span", + attrib={ + "class": "ocrx_word", + "id": f"word_{word_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}; \ + x_wconf {int(round(conf * 100))}", + }, + ) + # set the text + word_div.text = word.value + word_count += 1 + + return (ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in 
cls._exported_keys} + kwargs.update({"blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]]}) + return cls(**kwargs) + + +class Document(Element): + """Implements a document element as a collection of pages + + Args: + ---- + pages: list of page elements + """ + + _children_names: List[str] = ["pages"] + pages: List[Page] = [] + + def __init__( + self, + pages: List[Page], + ) -> None: + super().__init__(pages=pages) + + def render(self, page_break: str = "\n\n\n\n") -> str: + """Renders the full text of the element""" + return page_break.join(p.render() for p in self.pages) + + def show(self, **kwargs) -> None: + """Overlay the result on a given image""" + for result in self.pages: + result.show(**kwargs) + + def synthesize(self, **kwargs) -> List[np.ndarray]: + """Synthesize all pages from their predictions + + Returns + ------- + list of synthesized pages + """ + return [page.synthesize() for page in self.pages] + + def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]: + """Export the document as XML (hOCR-format) + + Args: + ---- + **kwargs: additional keyword arguments passed to the Page.export_as_xml method + + Returns: + ------- + list of tuple of (bytes, ElementTree) + """ + return [page.export_as_xml(**kwargs) for page in self.pages] + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({"pages": [Page.from_dict(page_dict) for page_dict in save_dict["pages"]]}) + return cls(**kwargs) diff --git a/onnxtr/io/html.py b/onnxtr/io/html.py new file mode 100644 index 0000000..e3d9269 --- /dev/null +++ b/onnxtr/io/html.py @@ -0,0 +1,28 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any + +__all__ = ["read_html"] + + +def read_html(url: str, **kwargs: Any) -> bytes: + """Read a PDF file and convert it into an image in numpy format + + >>> from onnxtr.io import read_html + >>> doc = read_html("https://www.yoursite.com") + + Args: + ---- + url: URL of the target web page + **kwargs: keyword arguments from `weasyprint.HTML` + + Returns: + ------- + decoded PDF file as a bytes stream + """ + from weasyprint import HTML + + return HTML(url, **kwargs).write_pdf() diff --git a/onnxtr/io/image/__init__.py b/onnxtr/io/image/__init__.py new file mode 100644 index 0000000..9b5ed21 --- /dev/null +++ b/onnxtr/io/image/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/onnxtr/io/image/base.py b/onnxtr/io/image/base.py new file mode 100644 index 0000000..8c44e91 --- /dev/null +++ b/onnxtr/io/image/base.py @@ -0,0 +1,56 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
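To make the element hierarchy above concrete, here is a small hand-built document; the coordinates and the keys of the `crop_orientation` dict are made-up values for illustration only (the predictor fills them in automatically in practice):

import numpy as np
from onnxtr.io import Block, Document, Line, Page, Word

# Two words on one line; geometries are relative ((xmin, ymin), (xmax, ymax)) boxes
words = [
    Word("Hello", 0.99, ((0.10, 0.10), (0.25, 0.15)), {"value": 0, "confidence": None}),
    Word("world", 0.98, ((0.30, 0.10), (0.45, 0.15)), {"value": 0, "confidence": None}),
]
page = Page(
    page=np.zeros((1024, 1024, 3), dtype=np.uint8),  # the raw page image
    blocks=[Block(lines=[Line(words)])],
    page_idx=0,
    dimensions=(1024, 1024),
)
doc = Document(pages=[page])

print(doc.render())                         # "Hello world"
export = doc.export()                       # nested dict mirroring pages/blocks/lines/words
xml_bytes, xml_tree = page.export_as_xml()  # hOCR export of the single page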
+ +from pathlib import Path +from typing import Optional, Tuple + +import cv2 +import numpy as np + +from onnxtr.utils.common_types import AbstractFile + +__all__ = ["read_img_as_numpy"] + + +def read_img_as_numpy( + file: AbstractFile, + output_size: Optional[Tuple[int, int]] = None, + rgb_output: bool = True, +) -> np.ndarray: + """Read an image file into numpy format + + >>> from onnxtr.io import read_img_as_numpy + >>> page = read_img_as_numpy("path/to/your/doc.jpg") + + Args: + ---- + file: the path to the image file + output_size: the expected output size of each page in format H x W + rgb_output: whether the output ndarray channel order should be RGB instead of BGR. + + Returns: + ------- + the page decoded as numpy ndarray of shape H x W x 3 + """ + if isinstance(file, (str, Path)): + if not Path(file).is_file(): + raise FileNotFoundError(f"unable to access {file}") + img = cv2.imread(str(file), cv2.IMREAD_COLOR) + elif isinstance(file, bytes): + _file: np.ndarray = np.frombuffer(file, np.uint8) + img = cv2.imdecode(_file, cv2.IMREAD_COLOR) + else: + raise TypeError("unsupported object type for argument 'file'") + + # Validity check + if img is None: + raise ValueError("unable to read file.") + # Resizing + if isinstance(output_size, tuple): + img = cv2.resize(img, output_size[::-1], interpolation=cv2.INTER_LINEAR) + # Switch the channel order + if rgb_output: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img diff --git a/onnxtr/io/pdf.py b/onnxtr/io/pdf.py new file mode 100644 index 0000000..d027186 --- /dev/null +++ b/onnxtr/io/pdf.py @@ -0,0 +1,42 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List, Optional + +import numpy as np +import pypdfium2 as pdfium + +from onnxtr.utils.common_types import AbstractFile + +__all__ = ["read_pdf"] + + +def read_pdf( + file: AbstractFile, + scale: float = 2, + rgb_mode: bool = True, + password: Optional[str] = None, + **kwargs: Any, +) -> List[np.ndarray]: + """Read a PDF file and convert it into an image in numpy format + + >>> from onnxtr.io import read_pdf + >>> doc = read_pdf("path/to/your/doc.pdf") + + Args: + ---- + file: the path to the PDF file + scale: rendering scale (1 corresponds to 72dpi) + rgb_mode: if True, the output will be RGB, otherwise BGR + password: a password to unlock the document, if encrypted + **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` + + Returns: + ------- + the list of pages decoded as numpy ndarray of shape H x W x C + """ + # Rasterise pages to numpy ndarrays with pypdfium2 + pdf = pdfium.PdfDocument(file, password=password, autoclose=True) + return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf] diff --git a/onnxtr/io/reader.py b/onnxtr/io/reader.py new file mode 100644 index 0000000..68290e1 --- /dev/null +++ b/onnxtr/io/reader.py @@ -0,0 +1,85 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
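The two low-level readers above can also be used on their own; a quick sketch (file paths are placeholders):

from onnxtr.io import read_img_as_numpy, read_pdf

# Decode an image file, resize it to 1024x1024 and return an RGB uint8 array of shape (H, W, 3)
page = read_img_as_numpy("path/to/page.jpg", output_size=(1024, 1024), rgb_output=True)

# Rasterise every page of a PDF at scale=2 (i.e. ~144 dpi, since 1 corresponds to 72 dpi)
pages = read_pdf("path/to/doc.pdf", scale=2)
print(len(pages), pages[0].shape)  # page count, (H, W, 3)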
+ +from pathlib import Path +from typing import List, Sequence, Union + +import numpy as np + +from onnxtr.file_utils import requires_package +from onnxtr.utils.common_types import AbstractFile + +from .html import read_html +from .image import read_img_as_numpy +from .pdf import read_pdf + +__all__ = ["DocumentFile"] + + +class DocumentFile: + """Read a document from multiple extensions""" + + @classmethod + def from_pdf(cls, file: AbstractFile, **kwargs) -> List[np.ndarray]: + """Read a PDF file + + >>> from onnxtr.io import DocumentFile + >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf") + + Args: + ---- + file: the path to the PDF file or a binary stream + **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` + + Returns: + ------- + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + return read_pdf(file, **kwargs) + + @classmethod + def from_url(cls, url: str, **kwargs) -> List[np.ndarray]: + """Interpret a web page as a PDF document + + >>> from onnxtr.io import DocumentFile + >>> doc = DocumentFile.from_url("https://www.yoursite.com") + + Args: + ---- + url: the URL of the target web page + **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` + + Returns: + ------- + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + requires_package( + "weasyprint", + "`.from_url` requires weasyprint installed.\n" + + "Installation instructions: https://doc.courtbouillon.org/weasyprint/stable/first_steps.html#installation", + ) + pdf_stream = read_html(url) + return cls.from_pdf(pdf_stream, **kwargs) + + @classmethod + def from_images(cls, files: Union[Sequence[AbstractFile], AbstractFile], **kwargs) -> List[np.ndarray]: + """Read an image file (or a collection of image files) and convert it into an image in numpy format + + >>> from onnxtr.io import DocumentFile + >>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"]) + + Args: + ---- + files: the path to the image file or a binary stream, or a collection of those + **kwargs: additional parameters to :meth:`onnxtr.io.image.read_img_as_numpy` + + Returns: + ------- + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + if isinstance(files, (str, Path, bytes)): + files = [files] + + return [read_img_as_numpy(file, **kwargs) for file in files] diff --git a/onnxtr/models/__init__.py b/onnxtr/models/__init__.py new file mode 100644 index 0000000..4e4f327 --- /dev/null +++ b/onnxtr/models/__init__.py @@ -0,0 +1,4 @@ +from .classification import * +from .detection import * +from .recognition import * +from .zoo import * diff --git a/onnxtr/models/_utils.py b/onnxtr/models/_utils.py new file mode 100644 index 0000000..2efd6ef --- /dev/null +++ b/onnxtr/models/_utils.py @@ -0,0 +1,141 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from math import floor +from statistics import median_low +from typing import List, Optional, Tuple + +import cv2 +import numpy as np +from langdetect import LangDetectException, detect_langs + +__all__ = ["estimate_orientation", "get_language"] + + +def get_max_width_length_ratio(contour: np.ndarray) -> float: + """Get the maximum shape ratio of a contour. 
+ + Args: + ---- + contour: the contour from cv2.findContour + + Returns: + ------- + the maximum shape ratio + """ + _, (w, h), _ = cv2.minAreaRect(contour) + return max(w / h, h / w) + + +def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int: + """Estimate the angle of the general document orientation based on the + lines of the document and the assumption that they should be horizontal. + + Args: + ---- + img: the img or bitmap to analyze (H, W, C) + n_ct: the number of contours used for the orientation estimation + ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines + + Returns: + ------- + the angle of the general document orientation + """ + assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported" + max_value = np.max(img) + min_value = np.min(img) + if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1): + thresh = img.astype(np.uint8) + if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3: + gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + gray_img = cv2.medianBlur(gray_img, 5) + thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] # type: ignore[assignment] + + # try to merge words in lines + (h, w) = img.shape[:2] + k_x = max(1, (floor(w / 100))) + k_y = max(1, (floor(h / 100))) + kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y)) + thresh = cv2.dilate(thresh, kernel, iterations=1) # type: ignore[assignment] + + # extract contours + contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + # Sort contours + contours = sorted(contours, key=get_max_width_length_ratio, reverse=True) + + angles = [] + for contour in contours[:n_ct]: + _, (w, h), angle = cv2.minAreaRect(contour) + if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines + angles.append(angle) + elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, substract 90 degree + angles.append(angle - 90) + + if len(angles) == 0: + return 0 # in case no angles is found + else: + median = -median_low(angles) + return round(median) if abs(median) != 0 else 0 + + +def rectify_crops( + crops: List[np.ndarray], + orientations: List[int], +) -> List[np.ndarray]: + """Rotate each crop of the list according to the predicted orientation: + 0: already straight, no rotation + 1: 90 ccw, rotate 3 times ccw + 2: 180, rotate 2 times ccw + 3: 270 ccw, rotate 1 time ccw + """ + # Inverse predictions (if angle of +90 is detected, rotate by -90) + orientations = [4 - pred if pred != 0 else 0 for pred in orientations] + return ( + [crop if orientation == 0 else np.rot90(crop, orientation) for orientation, crop in zip(orientations, crops)] + if len(orientations) > 0 + else [] + ) + + +def rectify_loc_preds( + page_loc_preds: np.ndarray, + orientations: List[int], +) -> Optional[np.ndarray]: + """Orient the quadrangle (Polygon4P) according to the predicted orientation, + so that the points are in this order: top L, top R, bot R, bot L if the crop is readable + """ + return ( + np.stack( + [ + np.roll(page_loc_pred, orientation, axis=0) + for orientation, page_loc_pred in zip(orientations, page_loc_preds) + ], + axis=0, + ) + if len(orientations) > 0 + else None + ) + + +def get_language(text: str) -> Tuple[str, float]: + """Get languages of a text using langdetect model. 
+ Get the language with the highest probability or no language if only a few words or a low probability + + Args: + ---- + text (str): text + + Returns: + ------- + The detected language in ISO 639 code and confidence score + """ + try: + lang = detect_langs(text.lower())[0] + except LangDetectException: + return "unknown", 0.0 + if len(text) <= 1 or (len(text) <= 5 and lang.prob <= 0.2): + return "unknown", 0.0 + return lang.lang, lang.prob diff --git a/onnxtr/models/builder.py b/onnxtr/models/builder.py new file mode 100644 index 0000000..4c07c47 --- /dev/null +++ b/onnxtr/models/builder.py @@ -0,0 +1,355 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from scipy.cluster.hierarchy import fclusterdata + +from onnxtr.io.elements import Block, Document, Line, Page, Word +from onnxtr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes +from onnxtr.utils.repr import NestedObject + +__all__ = ["DocumentBuilder"] + + +class DocumentBuilder(NestedObject): + """Implements a document builder + + Args: + ---- + resolve_lines: whether words should be automatically grouped into lines + resolve_blocks: whether lines should be automatically grouped into blocks + paragraph_break: relative length of the minimum space separating paragraphs + export_as_straight_boxes: if True, force straight boxes in the export (fit a rectangle + box to all rotated boxes). Else, keep the boxes format unchanged, no matter what it is. + """ + + def __init__( + self, + resolve_lines: bool = True, + resolve_blocks: bool = True, + paragraph_break: float = 0.035, + export_as_straight_boxes: bool = False, + ) -> None: + self.resolve_lines = resolve_lines + self.resolve_blocks = resolve_blocks + self.paragraph_break = paragraph_break + self.export_as_straight_boxes = export_as_straight_boxes + + @staticmethod + def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Sort bounding boxes from top to bottom, left to right + + Args: + ---- + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox) + + Returns: + ------- + tuple: indices of ordered boxes of shape (N,), boxes + If straight boxes are passed tpo the function, boxes are unchanged + else: boxes returned are straight boxes fitted to the straightened rotated boxes + so that we fit the lines afterwards to the straigthened page + """ + if boxes.ndim == 3: + boxes = rotate_boxes( + loc_preds=boxes, + angle=-estimate_page_angle(boxes), + orig_shape=(1024, 1024), + min_angle=5.0, + ) + boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1) + return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes + + def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[List[int]]: + """Split a line in sub_lines + + Args: + ---- + boxes: bounding boxes of shape (N, 4) + word_idcs: list of indexes for the words of the line + + Returns: + ------- + A list of (sub-)lines computed from the original line (words) + """ + lines = [] + # Sort words horizontally + word_idcs = [word_idcs[idx] for idx in boxes[word_idcs, 0].argsort().tolist()] + + # Eventually split line horizontally + if len(word_idcs) < 2: + lines.append(word_idcs) + else: + sub_line = [word_idcs[0]] + for i in word_idcs[1:]: + horiz_break = True + + prev_box = boxes[sub_line[-1]] + # Compute 
distance between boxes + dist = boxes[i, 0] - prev_box[2] + # If distance between boxes is lower than paragraph break, same sub-line + if dist < self.paragraph_break: + horiz_break = False + + if horiz_break: + lines.append(sub_line) + sub_line = [] + + sub_line.append(i) + lines.append(sub_line) + + return lines + + def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]: + """Order boxes to group them in lines + + Args: + ---- + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox + + Returns: + ------- + nested list of box indices + """ + # Sort boxes, and straighten the boxes if they are rotated + idxs, boxes = self._sort_boxes(boxes) + + # Compute median for boxes heights + y_med = np.median(boxes[:, 3] - boxes[:, 1]) + + lines = [] + words = [idxs[0]] # Assign the top-left word to the first line + # Define a mean y-center for the line + y_center_sum = boxes[idxs[0]][[1, 3]].mean() + + for idx in idxs[1:]: + vert_break = True + + # Compute y_dist + y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words)) + # If y-center of the box is close enough to mean y-center of the line, same line + if y_dist < y_med / 2: + vert_break = False + + if vert_break: + # Compute sub-lines (horizontal split) + lines.extend(self._resolve_sub_lines(boxes, words)) + words = [] + y_center_sum = 0 + + words.append(idx) + y_center_sum += boxes[idx][[1, 3]].mean() + + # Use the remaining words to form the last(s) line(s) + if len(words) > 0: + # Compute sub-lines (horizontal split) + lines.extend(self._resolve_sub_lines(boxes, words)) + + return lines + + @staticmethod + def _resolve_blocks(boxes: np.ndarray, lines: List[List[int]]) -> List[List[List[int]]]: + """Order lines to group them in blocks + + Args: + ---- + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) + lines: list of lines, each line is a list of idx + + Returns: + ------- + nested list of box indices + """ + # Resolve enclosing boxes of lines + if boxes.ndim == 3: + box_lines: np.ndarray = np.asarray([ + resolve_enclosing_rbbox([tuple(boxes[idx, :, :]) for idx in line]) # type: ignore[misc] + for line in lines + ]) + else: + _box_lines = [ + resolve_enclosing_bbox([(tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line]) + for line in lines + ] + box_lines = np.asarray([(x1, y1, x2, y2) for ((x1, y1), (x2, y2)) in _box_lines]) + + # Compute geometrical features of lines to clusterize + # Clusterizing only with box centers yield to poor results for complex documents + if boxes.ndim == 3: + box_features: np.ndarray = np.stack( + ( + (box_lines[:, 0, 0] + box_lines[:, 0, 1]) / 2, + (box_lines[:, 0, 0] + box_lines[:, 2, 0]) / 2, + (box_lines[:, 0, 0] + box_lines[:, 2, 1]) / 2, + (box_lines[:, 0, 1] + box_lines[:, 2, 1]) / 2, + (box_lines[:, 0, 1] + box_lines[:, 2, 0]) / 2, + (box_lines[:, 2, 0] + box_lines[:, 2, 1]) / 2, + ), + axis=-1, + ) + else: + box_features = np.stack( + ( + (box_lines[:, 0] + box_lines[:, 3]) / 2, + (box_lines[:, 1] + box_lines[:, 2]) / 2, + (box_lines[:, 0] + box_lines[:, 2]) / 2, + (box_lines[:, 1] + box_lines[:, 3]) / 2, + box_lines[:, 0], + box_lines[:, 1], + ), + axis=-1, + ) + # Compute clusters + clusters = fclusterdata(box_features, t=0.1, depth=4, criterion="distance", metric="euclidean") + + _blocks: Dict[int, List[int]] = {} + # Form clusters + for line_idx, cluster_idx in enumerate(clusters): + if cluster_idx in _blocks.keys(): + _blocks[cluster_idx].append(line_idx) + else: + _blocks[cluster_idx] = [line_idx] + + # Retrieve word-box level to return a fully nested 
structure + blocks = [[lines[idx] for idx in block] for block in _blocks.values()] + + return blocks + + def _build_blocks( + self, + boxes: np.ndarray, + word_preds: List[Tuple[str, float]], + crop_orientations: List[Dict[str, Any]], + ) -> List[Block]: + """Gather independent words in structured blocks + + Args: + ---- + boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2) + word_preds: list of all detected words of the page, of shape N + crop_orientations: list of dictoinaries containing + the general orientation (orientations + confidences) of the crops + + Returns: + ------- + list of block elements + """ + if boxes.shape[0] != len(word_preds): + raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}") + + if boxes.shape[0] == 0: + return [] + + # Decide whether we try to form lines + _boxes = boxes + if self.resolve_lines: + lines = self._resolve_lines(_boxes if _boxes.ndim == 3 else _boxes[:, :4]) + # Decide whether we try to form blocks + if self.resolve_blocks and len(lines) > 1: + _blocks = self._resolve_blocks(_boxes if _boxes.ndim == 3 else _boxes[:, :4], lines) + else: + _blocks = [lines] + else: + # Sort bounding boxes, one line for all boxes, one block for the line + lines = [self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4])[0]] # type: ignore[list-item] + _blocks = [lines] + + blocks = [ + Block([ + Line([ + Word( + *word_preds[idx], + tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type] + crop_orientations[idx], + ) + if boxes.ndim == 3 + else Word( + *word_preds[idx], + ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])), + crop_orientations[idx], + ) + for idx in line + ]) + for line in lines + ]) + for lines in _blocks + ] + + return blocks + + def extra_repr(self) -> str: + return ( + f"resolve_lines={self.resolve_lines}, resolve_blocks={self.resolve_blocks}, " + f"paragraph_break={self.paragraph_break}, " + f"export_as_straight_boxes={self.export_as_straight_boxes}" + ) + + def __call__( + self, + pages: List[np.ndarray], + boxes: List[np.ndarray], + text_preds: List[List[Tuple[str, float]]], + page_shapes: List[Tuple[int, int]], + crop_orientations: List[Dict[str, Any]], + orientations: Optional[List[Dict[str, Any]]] = None, + languages: Optional[List[Dict[str, Any]]] = None, + ) -> Document: + """Re-arrange detected words into structured blocks + + Args: + ---- + pages: list of N elements, where each element represents the page image + boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5) + or (*, 6) for all words for a given page + text_preds: list of N elements, where each element is the list of all word prediction (text + confidence) + page_shapes: shape of each page, of size N + crop_orientations: list of N elements, where each element is + a dictionary containing the general orientation (orientations + confidences) of the crops + orientations: optional, list of N elements, + where each element is a dictionary containing the orientation (orientation + confidence) + languages: optional, list of N elements, + where each element is a dictionary containing the language (language + confidence) + + Returns: + ------- + document object + """ + if len(boxes) != len(text_preds) != len(crop_orientations) or len(boxes) != len(page_shapes) != len( + crop_orientations + ): + raise ValueError("All arguments are expected to be lists of the same size") + + _orientations = ( + orientations if isinstance(orientations, 
list) else [None] * len(boxes) # type: ignore[list-item] + ) + _languages = languages if isinstance(languages, list) else [None] * len(boxes) # type: ignore[list-item] + if self.export_as_straight_boxes and len(boxes) > 0: + # If boxes are already straight OK, else fit a bounding rect + if boxes[0].ndim == 3: + # Iterate over pages and boxes + boxes = [np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1) for p_boxes in boxes] + + _pages = [ + Page( + page, + self._build_blocks( + page_boxes, + word_preds, + word_crop_orientations, + ), + _idx, + shape, + orientation, + language, + ) + for page, _idx, shape, page_boxes, word_preds, word_crop_orientations, orientation, language in zip( + pages, range(len(boxes)), page_shapes, boxes, text_preds, crop_orientations, _orientations, _languages + ) + ] + + return Document(_pages) diff --git a/onnxtr/models/classification/__init__.py b/onnxtr/models/classification/__init__.py new file mode 100644 index 0000000..cd47940 --- /dev/null +++ b/onnxtr/models/classification/__init__.py @@ -0,0 +1,2 @@ +from .models import * +from .zoo import * diff --git a/onnxtr/models/classification/models/__init__.py b/onnxtr/models/classification/models/__init__.py new file mode 100644 index 0000000..f1fd34d --- /dev/null +++ b/onnxtr/models/classification/models/__init__.py @@ -0,0 +1 @@ +from .mobilenet import * \ No newline at end of file diff --git a/onnxtr/models/classification/models/mobilenet.py b/onnxtr/models/classification/models/mobilenet.py new file mode 100644 index 0000000..e158b9b --- /dev/null +++ b/onnxtr/models/classification/models/mobilenet.py @@ -0,0 +1,125 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
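A minimal end-to-end sketch of feeding raw predictions into `DocumentBuilder`; the boxes, words and orientation dicts are made-up values purely for illustration:

import numpy as np
from onnxtr.models.builder import DocumentBuilder

builder = DocumentBuilder(resolve_lines=True, resolve_blocks=True)

pages = [np.zeros((1024, 1024, 3), dtype=np.uint8)]  # one page image
boxes = [np.array([                                  # relative (xmin, ymin, xmax, ymax, score)
    [0.10, 0.10, 0.25, 0.15, 0.90],
    [0.26, 0.10, 0.41, 0.15, 0.90],
])]
text_preds = [[("Hello", 0.99), ("world", 0.98)]]
crop_orientations = [[{"value": 0, "confidence": None}, {"value": 0, "confidence": None}]]

doc = builder(pages, boxes, text_preds, [(1024, 1024)], crop_orientations)
print(doc.render())  # "Hello world"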
+ +# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py + +from copy import deepcopy +from typing import Any, Dict, Optional + +import numpy as np + +from ...engine import Engine + +__all__ = [ + "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", +] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "mobilenet_v3_small_crop_orientation": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 256, 256), + "classes": [0, -90, 180, 90], + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_crop_orientation-f0847a18.pt&src=0", + }, + "mobilenet_v3_small_page_orientation": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 512, 512), + "classes": [0, -90, 180, 90], + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-8e60325c.pt&src=0", + }, +} + + +class MobileNetV3(Engine): + """MobileNetV3 Onnx loader + + Args: + ---- + model_path: path or url to onnx model file + cfg: configuration dictionary + """ + + def __init__( + self, + model_path: str, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(url=model_path) + self.cfg = cfg + + def __call__( + self, + x: np.ndarray, + ) -> np.ndarray: + return self.session.run(x) + + +def _mobilenet_v3( + arch: str, + model_path: str, + **kwargs: Any, +) -> MobileNetV3: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + return MobileNetV3(model_path, cfg=_cfg, **kwargs) + + +def mobilenet_v3_small_crop_orientation( + model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"], **kwargs: Any +) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + + >>> import numpy as np + >>> from onnxtr.models import mobilenet_v3_small_crop_orientation + >>> model = mobilenet_v3_small_crop_orientation(pretrained=False) + >>> input_tensor = np.random.rand((1, 3, 256, 256)) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + MobileNetV3 + """ + return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, **kwargs) + + +def mobilenet_v3_small_page_orientation( + model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"], **kwargs: Any +) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. 
+ + >>> import numpy as np + >>> from onnxtr.models import mobilenet_v3_small_page_orientation + >>> model = mobilenet_v3_small_page_orientation(pretrained=False) + >>> input_tensor = np.random.rand((1, 3, 512, 512)) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + MobileNetV3 + """ + return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, **kwargs) diff --git a/onnxtr/models/classification/predictor/__init__.py b/onnxtr/models/classification/predictor/__init__.py new file mode 100644 index 0000000..9b5ed21 --- /dev/null +++ b/onnxtr/models/classification/predictor/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/onnxtr/models/classification/predictor/base.py b/onnxtr/models/classification/predictor/base.py new file mode 100644 index 0000000..f644bf3 --- /dev/null +++ b/onnxtr/models/classification/predictor/base.py @@ -0,0 +1,59 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Union + +import numpy as np +from scipy.special import softmax + +from onnxtr.models.preprocessor import PreProcessor +from onnxtr.utils.repr import NestedObject + +from ...engine import Engine + +__all__ = ["OrientationPredictor"] + + +class OrientationPredictor(NestedObject): + """Implements an object able to detect the reading direction of a text box or a page. + 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. + + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core classification architecture (backbone + classification head) + """ + + _children_names: List[str] = ["pre_processor", "model"] + + def __init__( + self, + pre_processor: PreProcessor, + model: Engine, + ) -> None: + self.pre_processor = pre_processor + self.model = model + + def __call__( + self, + inputs: List[np.ndarray], + ) -> List[Union[List[int], List[float]]]: + # Dimension check + if any(input.ndim != 3 for input in inputs): + raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(inputs) + predicted_batches = [self.model(batch) for batch in processed_batches] + + # confidence + probs = [np.max(softmax(batch, axis=1), axis=1) for batch in predicted_batches] + # Postprocess predictions + predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches] + + class_idxs = [int(pred) for batch in predicted_batches for pred in batch] + classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs] + confs = [round(float(p), 2) for prob in probs for p in prob] + + return [class_idxs, classes, confs] diff --git a/onnxtr/models/classification/zoo.py b/onnxtr/models/classification/zoo.py new file mode 100644 index 0000000..35b077f --- /dev/null +++ b/onnxtr/models/classification/zoo.py @@ -0,0 +1,93 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, List + +from .. 
import classification +from ..preprocessor import PreProcessor +from .predictor import OrientationPredictor + +__all__ = ["crop_orientation_predictor", "page_orientation_predictor"] + +ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "mobilenet_v3_small_crop_orientation": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 256, 256), + "classes": [0, -90, 180, 90], + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_crop_orientation-f0847a18.pt&src=0", # TODO + }, + "mobilenet_v3_small_page_orientation": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 512, 512), + "classes": [0, -90, 180, 90], + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-8e60325c.pt&src=0", # TODO + }, +} + + +def _orientation_predictor(arch: str, **kwargs: Any) -> OrientationPredictor: + if arch not in ORIENTATION_ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + # Load directly classifier from backbone + _model = classification.__dict__[arch]() + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) + kwargs["std"] = kwargs.get("std", _model.cfg["std"]) + kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) + input_shape = _model.cfg["input_shape"][1:] + predictor = OrientationPredictor( + PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model + ) + return predictor + + +def crop_orientation_predictor( + arch: Any = "mobilenet_v3_small_crop_orientation", **kwargs: Any +) -> OrientationPredictor: + """Crop orientation classification architecture. + + >>> import numpy as np + >>> from onnxtr.models import crop_orientation_predictor + >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation') + >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8) + >>> out = model([input_crop]) + + Args: + ---- + arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation') + **kwargs: keyword arguments to be passed to the OrientationPredictor + + Returns: + ------- + OrientationPredictor + """ + return _orientation_predictor(arch, **kwargs) + + +def page_orientation_predictor( + arch: Any = "mobilenet_v3_small_page_orientation", **kwargs: Any +) -> OrientationPredictor: + """Page orientation classification architecture. + + >>> import numpy as np + >>> from onnxtr.models import page_orientation_predictor + >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation') + >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + arch: name of the architecture to use (e.g. 
'mobilenet_v3_small_page_orientation') + **kwargs: keyword arguments to be passed to the OrientationPredictor + + Returns: + ------- + OrientationPredictor + """ + return _orientation_predictor(arch, **kwargs) diff --git a/onnxtr/models/detection/__init__.py b/onnxtr/models/detection/__init__.py new file mode 100644 index 0000000..cd47940 --- /dev/null +++ b/onnxtr/models/detection/__init__.py @@ -0,0 +1,2 @@ +from .models import * +from .zoo import * diff --git a/onnxtr/models/detection/_utils/__init__.py b/onnxtr/models/detection/_utils/__init__.py new file mode 100644 index 0000000..9b5ed21 --- /dev/null +++ b/onnxtr/models/detection/_utils/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/onnxtr/models/detection/_utils/base.py b/onnxtr/models/detection/_utils/base.py new file mode 100644 index 0000000..b4686db --- /dev/null +++ b/onnxtr/models/detection/_utils/base.py @@ -0,0 +1,41 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import cv2 +import numpy as np + +__all__ = ["erode", "dilate"] + + +def erode(x: np.ndarray, kernel_size: int) -> np.ndarray: + """Performs erosion on a given tensor + + Args: + ---- + x: boolean tensor of shape (N, H, W, C) + kernel_size: the size of the kernel to use for erosion + + Returns: + ------- + the eroded tensor + """ + kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) + return 1 - cv2.erode(1 - x.astype(np.uint8), kernel, iterations=1) + + +def dilate(x: np.ndarray, kernel_size: int) -> np.ndarray: + """Performs dilation on a given tensor + + Args: + ---- + x: boolean tensor of shape (N, H, W, C) + kernel_size: the size of the kernel to use for dilation + + Returns: + ------- + the dilated tensor + """ + kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) + return cv2.dilate(x.astype(np.uint8), kernel, iterations=1) diff --git a/onnxtr/models/detection/core.py b/onnxtr/models/detection/core.py new file mode 100644 index 0000000..ff118db --- /dev/null +++ b/onnxtr/models/detection/core.py @@ -0,0 +1,101 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
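As a sketch of the intended call pattern for the two factory functions above (assuming the default ONNX weights resolve correctly), the predictor returns three parallel lists: class indices, angles in degrees and confidences:

import numpy as np
from onnxtr.models import crop_orientation_predictor

classifier = crop_orientation_predictor("mobilenet_v3_small_crop_orientation", batch_size=64)

crops = [(255 * np.random.rand(256, 256, 3)).astype(np.uint8) for _ in range(3)]
class_idxs, angles, confidences = classifier(crops)
print(angles)  # one entry per crop, drawn from the configured classes [0, -90, 180, 90]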
+ +from typing import List + +import cv2 +import numpy as np + +from onnxtr.utils.repr import NestedObject + +__all__ = ["DetectionPostProcessor"] + + +class DetectionPostProcessor(NestedObject): + """Abstract class to postprocess the raw output of the model + + Args: + ---- + box_thresh (float): minimal objectness score to consider a box + bin_thresh (float): threshold to apply to segmentation raw heatmap + assume straight_pages (bool): if True, fit straight boxes only + """ + + def __init__(self, box_thresh: float = 0.5, bin_thresh: float = 0.5, assume_straight_pages: bool = True) -> None: + self.box_thresh = box_thresh + self.bin_thresh = bin_thresh + self.assume_straight_pages = assume_straight_pages + self._opening_kernel: np.ndarray = np.ones((3, 3), dtype=np.uint8) + + def extra_repr(self) -> str: + return f"bin_thresh={self.bin_thresh}, box_thresh={self.box_thresh}" + + @staticmethod + def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool = True) -> float: + """Compute the confidence score for a polygon : mean of the p values on the polygon + + Args: + ---- + pred (np.ndarray): p map returned by the model + points: coordinates of the polygon + assume_straight_pages: if True, fit straight boxes only + + Returns: + ------- + polygon objectness + """ + h, w = pred.shape[:2] + + if assume_straight_pages: + xmin = np.clip(np.floor(points[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(points[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(points[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(points[:, 1].max()).astype(np.int32), 0, h - 1) + return pred[ymin : ymax + 1, xmin : xmax + 1].mean() + + else: + mask: np.ndarray = np.zeros((h, w), np.int32) + cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload] + product = pred * mask + return np.sum(product) / np.count_nonzero(product) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + raise NotImplementedError + + def __call__( + self, + proba_map, + ) -> List[List[np.ndarray]]: + """Performs postprocessing for a list of model outputs + + Args: + ---- + proba_map: probability map of shape (N, H, W, C) + + Returns: + ------- + list of N class predictions (for each input sample), where each class predictions is a list of C tensors + of shape (*, 5) or (*, 6) + """ + if proba_map.ndim != 4: + raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.") + + # Erosion + dilation on the binary map + bin_map = [ + [ + cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) + for idx in range(proba_map.shape[-1]) + ] + for bmap in (proba_map >= self.bin_thresh).astype(np.uint8) + ] + + return [ + [self.bitmap_to_boxes(pmaps[..., idx], bmaps[idx]) for idx in range(proba_map.shape[-1])] + for pmaps, bmaps in zip(proba_map, bin_map) + ] diff --git a/onnxtr/models/detection/models/__init__.py b/onnxtr/models/detection/models/__init__.py new file mode 100644 index 0000000..f3d012e --- /dev/null +++ b/onnxtr/models/detection/models/__init__.py @@ -0,0 +1,3 @@ +from .fast import * +from .differentiable_binarization import * +from .linknet import * \ No newline at end of file diff --git a/onnxtr/models/detection/models/differentiable_binarization.py b/onnxtr/models/detection/models/differentiable_binarization.py new file mode 100644 index 0000000..97c6aed --- /dev/null +++ b/onnxtr/models/detection/models/differentiable_binarization.py @@ -0,0 +1,158 @@ +# 
Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, Optional + +import numpy as np +from scipy.special import expit + +from ...engine import Engine +from ..postprocessor.base import GeneralDetectionPostProcessor + +__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "db_resnet50": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-79bd7d70.pt&src=0", + }, + "db_resnet34": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet34-cb6aed9e.pt&src=0", + }, + "db_mobilenet_v3_large": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-81e9b152.pt&src=0", + }, +} + + +class DBNet(Engine): + """DBNet Onnx loader + + Args: + ---- + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box + assume_straight_pages: if True, fit straight bounding boxes only + cfg: the configuration dict of the model + """ + + def __init__( + self, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.cfg = cfg + self.assume_straight_pages = assume_straight_pages + + self.postprocessor = GeneralDetectionPostProcessor( + assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + def __call__( + self, + x: np.ndarray, + return_model_output: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + logits = self.session.run(x) + + out: Dict[str, Any] = {} + + prob_map = expit(logits) + if return_model_output: + out["out_map"] = prob_map + + out["preds"] = [ + dict(zip(self.class_names, preds)) for preds in self.postprocessor(np.transpose(prob_map, (0, 2, 3, 1))) + ] + + return out + + +def _dbnet( + arch: str, + model_path: str, + **kwargs: Any, +) -> DBNet: + # Build the model + return DBNet(model_path, cfg=default_cfgs[arch], **kwargs) + + +def db_resnet34(model_path: str = default_cfgs["db_resnet34"], **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-34 backbone. + + >>> import numpy as np + >>> from onnxtr.models import db_resnet34 + >>> model = db_resnet34() + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _dbnet("db_resnet34", model_path, **kwargs) + + +def db_resnet50(model_path: str = default_cfgs["db_resnet50"], **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-50 backbone. 
+ + >>> import numpy as np + >>> from onnxtr.models import db_resnet50 + >>> model = db_resnet50(pretrained=True) + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _dbnet("db_resnet50", model_path, **kwargs) + + +def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"], **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a MobileNet V3 Large backbone. + + >>> import numpy as np + >>> from onnxtr.models import db_mobilenet_v3_large + >>> model = db_mobilenet_v3_large() + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _dbnet("db_mobilenet_v3_large", model_path**kwargs) diff --git a/onnxtr/models/detection/models/fast.py b/onnxtr/models/detection/models/fast.py new file mode 100644 index 0000000..e52e5d4 --- /dev/null +++ b/onnxtr/models/detection/models/fast.py @@ -0,0 +1,158 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, Optional + +import numpy as np +from scipy.special import expit + +from ...engine import Engine +from ..postprocessor.base import GeneralDetectionPostProcessor + +__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "fast_tiny": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-1acac421.pt&src=0", + }, + "fast_small": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-10952cc1.pt&src=0", + }, + "fast_base": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-688a8b34.pt&src=0", + }, +} + + +class FAST(Engine): + """FAST Onnx loader + + Args: + ---- + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box + assume_straight_pages: if True, fit straight bounding boxes only + cfg: the configuration dict of the model + """ + + def __init__( + self, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.cfg = cfg + self.assume_straight_pages = assume_straight_pages + + self.postprocessor = GeneralDetectionPostProcessor( + assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + def __call__( + self, + x: np.ndarray, + return_model_output: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + logits = self.session.run(x) + + out: Dict[str, Any] = {} + + prob_map = expit(logits) + if return_model_output: + out["out_map"] = prob_map + + out["preds"] = [ + dict(zip(self.class_names, preds)) for preds in 
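Building on the docstring examples above, a sketch of consuming a detection model's output dict (assuming the default weights resolve; the per-class keys of each prediction dict depend on the model configuration):

import numpy as np
from onnxtr.models import db_resnet50

det_model = db_resnet50()
input_tensor = np.random.rand(1, 3, 1024, 1024).astype(np.float32)

out = det_model(input_tensor, return_model_output=True)
prob_map = out["out_map"]     # sigmoid of the raw logits, only present with return_model_output=True
page_preds = out["preds"][0]  # per-class dict of boxes for the first page, (*, 5) arrays for straight boxes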
self.postprocessor(np.transpose(prob_map, (0, 2, 3, 1))) + ] + + return out + + +def _fast( + arch: str, + model_path: str, + **kwargs: Any, +) -> FAST: + # Build the model + return FAST(model_path, cfg=default_cfgs[arch], **kwargs) + + +def fast_tiny(model_path: str = default_cfgs["fast_tiny"], **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a tiny TextNet backbone. + + >>> import numpy as np + >>> from onnxtr.models import fast_tiny + >>> model = fast_tiny() + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast("fast_tiny", model_path, **kwargs) + + +def fast_small(model_path: str = default_cfgs["fast_small"], **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a small TextNet backbone. + + >>> import numpy as np + >>> from onnxtr.models import fast_small + >>> model = fast_small() + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast("fast_small", model_path, **kwargs) + + +def fast_base(model_path: str = default_cfgs["fast_base"], **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a base TextNet backbone. + + >>> import numpy as np + >>> from onnxtr.models import fast_base + >>> model = fast_base() + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast("fast_base", model_path, **kwargs) diff --git a/onnxtr/models/detection/models/linknet.py b/onnxtr/models/detection/models/linknet.py new file mode 100644 index 0000000..64b96fe --- /dev/null +++ b/onnxtr/models/detection/models/linknet.py @@ -0,0 +1,158 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
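The DBNet and FAST wrappers above (and the LinkNet wrapper that follows) share the same head-level postprocessing before box extraction; stripped of the model call, that step amounts to:

import numpy as np
from scipy.special import expit

logits = np.random.randn(1, 1, 1024, 1024).astype(np.float32)  # raw (N, C, H, W) model output
prob_map = expit(logits)                                       # sigmoid -> probabilities in [0, 1]
# channels-last, then binarize with the default bin_thresh=0.1 as DetectionPostProcessor does
bin_map = (np.transpose(prob_map, (0, 2, 3, 1)) >= 0.1).astype(np.uint8)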
+ +from typing import Any, Dict, Optional + +import numpy as np +from scipy.special import expit + +from ...engine import Engine +from ..postprocessor.base import GeneralDetectionPostProcessor + +__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "linknet_resnet18": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-e47a14dc.pt&src=0", + }, + "linknet_resnet34": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-9ca2df3e.pt&src=0", + }, + "linknet_resnet50": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-6cf565c1.pt&src=0", + }, +} + + +class LinkNet(Engine): + """LinkNet Onnx loader + + Args: + ---- + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box + assume_straight_pages: if True, fit straight bounding boxes only + cfg: the configuration dict of the model + """ + + def __init__( + self, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.cfg = cfg + self.assume_straight_pages = assume_straight_pages + + self.postprocessor = GeneralDetectionPostProcessor( + assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + def __call__( + self, + x: np.ndarray, + return_model_output: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + logits = self.session.run(x) + + out: Dict[str, Any] = {} + + prob_map = expit(logits) + if return_model_output: + out["out_map"] = prob_map + + out["preds"] = [ + dict(zip(self.class_names, preds)) for preds in self.postprocessor(np.transpose(prob_map, (0, 2, 3, 1))) + ] + + return out + + +def _linknet( + arch: str, + model_path: str, + **kwargs: Any, +) -> LinkNet: + # Build the model + return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs) + + +def linknet_resnet18(model_path: str = default_cfgs["linknet_resnet18"], **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. + + >>> import numpy as np + >>> from onnxtr.models import linknet_resnet18 + >>> model = linknet_resnet18() + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet("linknet_resnet18", model_path, **kwargs) + + +def linknet_resnet34(model_path: str = default_cfgs["linknet_resnet34"], **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. 
+ + >>> import numpy as np + >>> from onnxtr.models import linknet_resnet34 + >>> model = linknet_resnet34() + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet("linknet_resnet34", model_path, **kwargs) + + +def linknet_resnet50(model_path: str = default_cfgs["linknet_resnet50"], **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. + + >>> import numpy as np + >>> from onnxtr.models import linknet_resnet50 + >>> model = linknet_resnet50() + >>> input_tensor = np.random.rand(1, 3, 1024, 1024) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet("linknet_resnet50", model_path, **kwargs) diff --git a/onnxtr/models/detection/postprocessor/__init__.py b/onnxtr/models/detection/postprocessor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/onnxtr/models/detection/postprocessor/base.py b/onnxtr/models/detection/postprocessor/base.py new file mode 100644 index 0000000..1e5a934 --- /dev/null +++ b/onnxtr/models/detection/postprocessor/base.py @@ -0,0 +1,144 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from typing import List, Union + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +from ..core import DetectionPostProcessor + +__all__ = ["GeneralDetectionPostProcessor"] + + +class GeneralDetectionPostProcessor(DetectionPostProcessor): + """Implements a post processor for FAST model. + + Args: + ---- + bin_thresh: threshold used to binzarized p_map at inference time + box_thresh: minimal objectness score to consider a box + assume_straight_pages: whether the inputs were expected to have horizontal text elements + """ + + def __init__( + self, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + ) -> None: + super().__init__(box_thresh, bin_thresh, assume_straight_pages) + self.unclip_ratio = 1.0 + + def polygon_to_box( + self, + points: np.ndarray, + ) -> np.ndarray: + """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon + + Args: + ---- + points: The first parameter. 
+ + Returns: + ------- + a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) + """ + if not self.assume_straight_pages: + # Compute the rectangle polygon enclosing the raw polygon + rect = cv2.minAreaRect(points) + points = cv2.boxPoints(rect) + # Add 1 pixel to correct cv2 approx + area = (rect[1][0] + 1) * (1 + rect[1][1]) + length = 2 * (rect[1][0] + rect[1][1]) + 2 + else: + poly = Polygon(points) + area = poly.area + length = poly.length + distance = area * self.unclip_ratio / length # compute distance to expand polygon + offset = pyclipper.PyclipperOffset() + offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + _points = offset.Execute(distance) + # Take biggest stack of points + idx = 0 + if len(_points) > 1: + max_size = 0 + for _idx, p in enumerate(_points): + if len(p) > max_size: + idx = _idx + max_size = len(p) + # We ensure that _points can be correctly casted to a ndarray + _points = [_points[idx]] + expanded_points: np.ndarray = np.asarray(_points) # expand polygon + if len(expanded_points) < 1: + return None # type: ignore[return-value] + return ( + cv2.boundingRect(expanded_points) # type: ignore[return-value] + if self.assume_straight_pages + else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) + ) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + """Compute boxes from a bitmap/pred_map: find connected components then filter boxes + + Args: + ---- + pred: Pred map from differentiable linknet output + bitmap: Bitmap map computed from pred (binarized) + angle_tol: Comparison tolerance of the angle with the median angle across the page + ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop + + Returns: + ------- + np tensor boxes for the bitmap, each box is a 6-element list + containing x, y, w, h, alpha, score for the box + """ + height, width = bitmap.shape[:2] + boxes: List[Union[np.ndarray, List[float]]] = [] + # get contours from connected components on the bitmap + contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + # Check whether smallest enclosing bounding box is not too small + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): + continue + # Compute objectness + if self.assume_straight_pages: + x, y, w, h = cv2.boundingRect(contour) + points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) + score = self.box_score(pred, points, assume_straight_pages=True) + else: + score = self.box_score(pred, contour, assume_straight_pages=False) + + if score < self.box_thresh: # remove polygons with a weak objectness + continue + + if self.assume_straight_pages: + _box = self.polygon_to_box(points) + else: + _box = self.polygon_to_box(np.squeeze(contour)) + + if self.assume_straight_pages: + # compute relative polygon to get rid of img shape + x, y, w, h = _box + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + boxes.append([xmin, ymin, xmax, ymax, score]) + else: + # compute relative box to get rid of img shape + _box[:, 0] /= width + _box[:, 1] /= height + boxes.append(_box) + + if not self.assume_straight_pages: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype) + else: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype) diff --git 
a/onnxtr/models/detection/predictor/__init__.py b/onnxtr/models/detection/predictor/__init__.py new file mode 100644 index 0000000..9b5ed21 --- /dev/null +++ b/onnxtr/models/detection/predictor/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/onnxtr/models/detection/predictor/base.py b/onnxtr/models/detection/predictor/base.py new file mode 100644 index 0000000..45d61ec --- /dev/null +++ b/onnxtr/models/detection/predictor/base.py @@ -0,0 +1,57 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, List, Tuple, Union + +import numpy as np + +from onnxtr.models.preprocessor import PreProcessor +from onnxtr.utils.repr import NestedObject + +from ...engine import Engine + +__all__ = ["DetectionPredictor"] + + +class DetectionPredictor(NestedObject): + """Implements an object able to localize text elements in a document + + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + """ + + _children_names: List[str] = ["pre_processor", "model"] + + def __init__( + self, + pre_processor: PreProcessor, + model: Engine, + ) -> None: + self.pre_processor = pre_processor + self.model = model + + def __call__( + self, + pages: List[np.ndarray], + return_maps: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]: + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(pages) + predicted_batches = [ + self.model(batch, return_preds=True, return_model_output=True, training=False, **kwargs) + for batch in processed_batches + ] + + preds = [pred for batch in predicted_batches for pred in batch["preds"]] + if return_maps: + seg_maps = [pred for batch in predicted_batches for pred in batch["out_map"]] + return preds, seg_maps + return preds diff --git a/onnxtr/models/detection/zoo.py b/onnxtr/models/detection/zoo.py new file mode 100644 index 0000000..f743fed --- /dev/null +++ b/onnxtr/models/detection/zoo.py @@ -0,0 +1,73 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any + +from .. 
import detection +from ..preprocessor import PreProcessor +from .predictor import DetectionPredictor + +__all__ = ["detection_predictor"] + +ARCHS = [ + "db_resnet34", + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + "linknet_resnet34", + "linknet_resnet50", + "fast_tiny", + "fast_small", + "fast_base", +] + + +def _predictor(arch: Any, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor: + if isinstance(arch, str): + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages) + else: + if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)): + raise ValueError(f"unknown architecture: {type(arch)}") + + _model = arch + _model.assume_straight_pages = assume_straight_pages + + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) + kwargs["std"] = kwargs.get("std", _model.cfg["std"]) + kwargs["batch_size"] = kwargs.get("batch_size", 2) + predictor = DetectionPredictor( + PreProcessor(_model.cfg["input_shape"][1:], **kwargs), + _model, + ) + return predictor + + +def detection_predictor( + arch: Any = "db_resnet50", + assume_straight_pages: bool = True, + **kwargs: Any, +) -> DetectionPredictor: + """Text detection architecture. + + >>> import numpy as np + >>> from onnxtr.models import detection_predictor + >>> model = detection_predictor(arch='db_resnet50') + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + arch: name of the architecture or model itself to use (e.g. 'db_resnet50') + assume_straight_pages: If True, fit straight boxes to the page + **kwargs: optional keyword arguments passed to the architecture + + Returns: + ------- + Detection predictor + """ + return _predictor(arch, assume_straight_pages, **kwargs) diff --git a/onnxtr/models/engine.py b/onnxtr/models/engine.py new file mode 100644 index 0000000..1600a60 --- /dev/null +++ b/onnxtr/models/engine.py @@ -0,0 +1,30 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List + +import numpy as np +import onnxruntime + +from onnxtr.utils.data import download_from_url + + +class Engine: + """Implements an abstract class for the engine of a model + + Args: + ---- + url: the url to use to download a model if needed + **kwargs: additional arguments to be passed to `download_from_url` + """ + + def __init__(self, url: str, **kwargs: Any) -> None: + archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url + self.session = onnxruntime.InferenceSession( + archive_path, providers=["CPUExecutionProvider", "CUDAExecutionProvider"] + ) + + def run(self, inputs: np.ndarray) -> List[np.ndarray]: + return self.session.run(["logits"], {"input": inputs}) diff --git a/onnxtr/models/predictor/__init__.py b/onnxtr/models/predictor/__init__.py new file mode 100644 index 0000000..35eae60 --- /dev/null +++ b/onnxtr/models/predictor/__init__.py @@ -0,0 +1 @@ +from .predictor import * diff --git a/onnxtr/models/predictor/base.py b/onnxtr/models/predictor/base.py new file mode 100644 index 0000000..e5221a0 --- /dev/null +++ b/onnxtr/models/predictor/base.py @@ -0,0 +1,170 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
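Everything above funnels through the Engine wrapper defined in onnxtr/models/engine.py: it resolves a local path or an http(s) URL to an ONNX file, opens an onnxruntime InferenceSession with CPU and CUDA providers, and exposes a run() helper. A small sketch, assuming a hypothetical model file whose graph uses the input name "input" and the output name "logits" that run() expects:

import numpy as np
from onnxtr.models.engine import Engine

engine = Engine("path/to/model.onnx")  # local file: used as-is
# engine = Engine("https://example.com/model.onnx")  # URL: downloaded to the models cache first
dummy = np.zeros((1, 3, 1024, 1024), dtype=np.float32)
outputs = engine.run(dummy)  # list of arrays returned for the "logits" output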
+ +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np + +from onnxtr.models.builder import DocumentBuilder +from onnxtr.utils.geometry import extract_crops, extract_rcrops + +from .._utils import rectify_crops, rectify_loc_preds +from ..classification import crop_orientation_predictor +from ..classification.predictor import OrientationPredictor + +__all__ = ["_OCRPredictor"] + + +class _OCRPredictor: + """Implements an object able to localize and identify text elements in a set of documents + + Args: + ---- + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding) + symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically. + **kwargs: keyword args of `DocumentBuilder` + """ + + crop_orientation_predictor: Optional[OrientationPredictor] + + def __init__( + self, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + **kwargs: Any, + ) -> None: + self.assume_straight_pages = assume_straight_pages + self.straighten_pages = straighten_pages + self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True) + self.doc_builder = DocumentBuilder(**kwargs) + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + self.hooks: List[Callable] = [] + + @staticmethod + def _generate_crops( + pages: List[np.ndarray], + loc_preds: List[np.ndarray], + channels_last: bool, + assume_straight_pages: bool = False, + ) -> List[List[np.ndarray]]: + extraction_fn = extract_crops if assume_straight_pages else extract_rcrops + + crops = [ + extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator] + for page, _boxes in zip(pages, loc_preds) + ] + return crops + + @staticmethod + def _prepare_crops( + pages: List[np.ndarray], + loc_preds: List[np.ndarray], + channels_last: bool, + assume_straight_pages: bool = False, + ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: + crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages) + + # Avoid sending zero-sized crops + is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops] + crops = [ + [crop for crop, _kept in zip(page_crops, page_kept) if _kept] + for page_crops, page_kept in zip(crops, is_kept) + ] + loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)] + + return crops, loc_preds + + def _rectify_crops( + self, + crops: List[List[np.ndarray]], + loc_preds: List[np.ndarray], + ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]: + # Work at a page level + orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc] + rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)] + rect_loc_preds = [ + rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds + for page_loc_preds, 
orientation in zip(loc_preds, orientations) + ] + # Flatten to list of tuples with (value, confidence) + crop_orientations = [ + (orientation, prob) + for page_classes, page_probs in zip(classes, probs) + for orientation, prob in zip(page_classes, page_probs) + ] + return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value] + + def _remove_padding( + self, + pages: List[np.ndarray], + loc_preds: List[np.ndarray], + ) -> List[np.ndarray]: + if self.preserve_aspect_ratio: + # Rectify loc_preds to remove padding + rectified_preds = [] + for page, loc_pred in zip(pages, loc_preds): + h, w = page.shape[0], page.shape[1] + if h > w: + # y unchanged, dilate x coord + if self.symmetric_pad: + if self.assume_straight_pages: + loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1) + else: + loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1) + else: + if self.assume_straight_pages: + loc_pred[:, [0, 2]] *= h / w + else: + loc_pred[:, :, 0] *= h / w + elif w > h: + # x unchanged, dilate y coord + if self.symmetric_pad: + if self.assume_straight_pages: + loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1) + else: + loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1) + else: + if self.assume_straight_pages: + loc_pred[:, [1, 3]] *= w / h + else: + loc_pred[:, :, 1] *= w / h + rectified_preds.append(loc_pred) + return rectified_preds + return loc_preds + + @staticmethod + def _process_predictions( + loc_preds: List[np.ndarray], + word_preds: List[Tuple[str, float]], + crop_orientations: List[Dict[str, Any]], + ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]: + text_preds = [] + crop_orientation_preds = [] + if len(loc_preds) > 0: + # Text & crop orientation predictions at page level + _idx = 0 + for page_boxes in loc_preds: + text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]]) + crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]]) + _idx += page_boxes.shape[0] + + return loc_preds, text_preds, crop_orientation_preds + + def add_hook(self, hook: Callable) -> None: + """Add a hook to the predictor + + Args: + ---- + hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds` + """ + self.hooks.append(hook) diff --git a/onnxtr/models/predictor/predictor.py b/onnxtr/models/predictor/predictor.py new file mode 100644 index 0000000..2f61a08 --- /dev/null +++ b/onnxtr/models/predictor/predictor.py @@ -0,0 +1,145 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
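The add_hook mechanism at the end of _OCRPredictor gives callers a way to post-edit the detected boxes before recognition. A sketch of one possible hook, assuming straight pages so that each element of loc_preds is an (n, 5) array of relative (xmin, ymin, xmax, ymax, score) rows; the area threshold is purely illustrative:

import numpy as np

def drop_tiny_boxes(loc_preds):
    # keep only boxes whose relative area exceeds an arbitrary threshold
    return [
        boxes[(boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) > 1e-4]
        for boxes in loc_preds
    ]

# predictor.add_hook(drop_tiny_boxes)  # hooks are applied to loc_preds inside __call__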
+ +from typing import Any, List + +import numpy as np + +from onnxtr.io.elements import Document +from onnxtr.models._utils import estimate_orientation, get_language +from onnxtr.models.detection.predictor import DetectionPredictor +from onnxtr.models.recognition.predictor import RecognitionPredictor +from onnxtr.utils.geometry import rotate_image +from onnxtr.utils.repr import NestedObject + +from .base import _OCRPredictor + +__all__ = ["OCRPredictor"] + + +class OCRPredictor(NestedObject, _OCRPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + ---- + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + **kwargs: keyword args of `DocumentBuilder` + """ + + _children_names = ["det_predictor", "reco_predictor", "doc_builder"] + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs: Any, + ) -> None: + self.det_predictor = det_predictor + self.reco_predictor = reco_predictor + _OCRPredictor.__init__( + self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs + ) + self.detect_orientation = detect_orientation + self.detect_language = detect_language + + def __call__( + self, + pages: List[np.ndarray], + **kwargs: Any, + ) -> Document: + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + origin_page_shapes = [page.shape[:2] for page in pages] + + # Localize text elements + loc_preds_dict, out_maps = self.det_predictor(pages, return_maps=True, **kwargs) + + # Detect document rotation and rotate pages + seg_maps = [ + np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8) + for out_map in out_maps + ] + if self.detect_orientation: + origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps] + orientations = [ + {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations + ] + else: + orientations = None + if self.straighten_pages: + origin_page_orientations = ( + origin_page_orientations + if self.detect_orientation + else [estimate_orientation(seq_map) for seq_map in seg_maps] + ) + pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)] + # forward again to get predictions on straight pages + loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment] + + 
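+        # The detection models used by this predictor are single-class, so each
+        # per-page prediction dict is expected to hold exactly one entry; the
+        # assertion below guards that before the box arrays are unwrapped.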
assert all( + len(loc_pred) == 1 for loc_pred in loc_preds_dict + ), "Detection Model in ocr_predictor should output only one class" + loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr] + + # Rectify crops if aspect ratio + loc_preds = self._remove_padding(pages, loc_preds) + + # Apply hooks to loc_preds if any + for hook in self.hooks: + loc_preds = hook(loc_preds) + + # Crop images + crops, loc_preds = self._prepare_crops( + pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages + ) + # Rectify crop orientation and get crop orientation predictions + crop_orientations: Any = [] + if not self.assume_straight_pages: + crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds) + crop_orientations = [ + {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations + ] + + # Identify character sequences + word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs) + if not crop_orientations: + crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds] + + boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations) + + if self.detect_language: + languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds] + languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages] + else: + languages_dict = None + + out = self.doc_builder( + pages, + boxes, + text_preds, + origin_page_shapes, # type: ignore[arg-type] + crop_orientations, + orientations, + languages_dict, + ) + return out diff --git a/onnxtr/models/preprocessor/__init__.py b/onnxtr/models/preprocessor/__init__.py new file mode 100644 index 0000000..9b5ed21 --- /dev/null +++ b/onnxtr/models/preprocessor/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/onnxtr/models/preprocessor/base.py b/onnxtr/models/preprocessor/base.py new file mode 100644 index 0000000..2d0a3b1 --- /dev/null +++ b/onnxtr/models/preprocessor/base.py @@ -0,0 +1,114 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from typing import Any, List, Tuple, Union + +import cv2 +import numpy as np + +from onnxtr.transforms import Normalize, Resize +from onnxtr.utils.multithreading import multithread_exec +from onnxtr.utils.repr import NestedObject + +__all__ = ["PreProcessor"] + + +class PreProcessor(NestedObject): + """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. 
+ + Args: + ---- + output_size: expected size of each page in format (H, W) + batch_size: the size of page batches + mean: mean value of the training distribution by channel + std: standard deviation of the training distribution by channel + """ + + _children_names: List[str] = ["resize", "normalize"] + + def __init__( + self, + output_size: Tuple[int, int], + batch_size: int, + mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + std: Tuple[float, float, float] = (1.0, 1.0, 1.0), + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.resize = Resize(output_size, **kwargs) + self.normalize = Normalize(mean, std) + + def batch_inputs(self, samples: List[np.ndarray]) -> List[np.ndarray]: + """Gather samples into batches for inference purposes + + Args: + ---- + samples: list of samples (tf.Tensor) + + Returns: + ------- + list of batched samples + """ + num_batches = int(math.ceil(len(samples) / self.batch_size)) + batches = [ + np.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0) + for idx in range(int(num_batches)) + ] + + return batches + + def sample_transforms(self, x: np.ndarray) -> np.ndarray: + if x.ndim != 3: + raise AssertionError("expected list of 3D Tensors") + if isinstance(x, np.ndarray): + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + # Data type + if x.dtype == np.uint8: + x = x.astype(np.float32) + # Resizing + x = self.resize(x) + + return x + + def __call__(self, x: Union[np.ndarray, List[np.ndarray]]) -> List[np.ndarray]: + """Prepare document data for model forwarding + + Args: + ---- + x: list of images (np.array) already resized and batched + + Returns: + ------- + list of page batches + """ + # Input type check + if isinstance(x, np.ndarray): + if x.ndim != 4: + raise AssertionError("expected 4D Tensor") + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + + # Data type + if x.dtype == np.uint8: + x = x.astype(np.float32) + # Resizing + if (x.shape[1], x.shape[2]) != self.resize.output_size: + x = cv2.resize(x, self.resize.output_size, interpolation=self.resize.method) + batches = [x] + + elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x): + # Sample transform (resize) + samples = list(multithread_exec(self.sample_transforms, x)) + # Batching + batches = self.batch_inputs(samples) + else: + raise TypeError(f"invalid input type: {type(x)}") + + # Batch transforms (normalize) + batches = list(multithread_exec(self.normalize, batches)) + + return batches diff --git a/onnxtr/models/recognition/__init__.py b/onnxtr/models/recognition/__init__.py new file mode 100644 index 0000000..cd47940 --- /dev/null +++ b/onnxtr/models/recognition/__init__.py @@ -0,0 +1,2 @@ +from .models import * +from .zoo import * diff --git a/onnxtr/models/recognition/core.py b/onnxtr/models/recognition/core.py new file mode 100644 index 0000000..35e1c86 --- /dev/null +++ b/onnxtr/models/recognition/core.py @@ -0,0 +1,28 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
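The PreProcessor above is the glue between raw page arrays and the ONNX sessions: each sample is cast to float32 and resized, samples are stacked into fixed-size batches, and the batches are normalized. A short sketch with illustrative sizes:

import numpy as np
from onnxtr.models.preprocessor import PreProcessor

pre = PreProcessor(output_size=(1024, 1024), batch_size=2)
pages = [np.zeros((768, 512, 3), dtype=np.uint8) for _ in range(3)]
batches = pre(pages)
# ceil(3 / 2) == 2 batches: one with 2 resized pages, one with the remaining page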
+ + +from onnxtr.utils.repr import NestedObject + +__all__ = ["RecognitionPostProcessor"] + + +class RecognitionPostProcessor(NestedObject): + """Abstract class to postprocess the raw output of the model + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + self.vocab = vocab + self._embedding = list(self.vocab) + [""] + + def extra_repr(self) -> str: + return f"vocab_size={len(self.vocab)}" diff --git a/onnxtr/models/recognition/models/__init__.py b/onnxtr/models/recognition/models/__init__.py new file mode 100644 index 0000000..ba5975e --- /dev/null +++ b/onnxtr/models/recognition/models/__init__.py @@ -0,0 +1,5 @@ +from .crnn import * +from .sar import * +from .master import * +from .vitstr import * +from .parseq import * diff --git a/onnxtr/models/recognition/models/crnn.py b/onnxtr/models/recognition/models/crnn.py new file mode 100644 index 0000000..9cb7373 --- /dev/null +++ b/onnxtr/models/recognition/models/crnn.py @@ -0,0 +1,225 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from itertools import groupby +from typing import Any, Dict, List, Optional + +import numpy as np +from scipy.special import softmax + +from onnxtr.utils import VOCABS + +from ...engine import Engine +from ..core import RecognitionPostProcessor + +__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "crnn_vgg16_bn": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["legacy_french"], + "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_vgg16_bn-9762b0b0.pt&src=0", + }, + "crnn_mobilenet_v3_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small_pt-3b919a02.pt&src=0", + }, + "crnn_mobilenet_v3_large": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_large_pt-f5259ec2.pt&src=0", + }, +} + + +class CRNNPostProcessor(RecognitionPostProcessor): + """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__(self, vocab): + self.vocab = vocab + + def decode_sequence(self, sequence, vocab): + return "".join([vocab[int(char)] for char in sequence]) + + def ctc_best_path( + self, + logits, + vocab, + blank=0, + ): + """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from + `_. 
+ + Args: + ---- + logits: model output, shape: N x T x C + vocab: vocabulary to use + blank: index of blank label + + Returns: + ------- + A list of tuples: (word, confidence) + """ + # Gather the most confident characters, and assign the smallest conf among those to the sequence prob + probs = softmax(logits, axis=-1).max(axis=-1).min(axis=1) + + # collapse best path (using itertools.groupby), map to chars, join char list to string + words = [ + self.decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab) + for seq in np.argmax(logits, axis=-1) + ] + + return list(zip(words, probs.tolist())) + + def __call__(self, logits): + """Performs decoding of raw output with CTC and decoding of CTC predictions + with label_to_idx mapping dictionnary + + Args: + ---- + logits: raw output of the model, shape (N, C + 1, seq_len) + + Returns: + ------- + A tuple of 2 lists: a list of str (words) and a list of float (probs) + + """ + # Decode CTC + return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab)) + + +class CRNN(Engine): + """CRNN Onnx loader + + Args: + ---- + model_path: path or url to onnx model file + vocab: vocabulary used for encoding + cfg: configuration dictionary + """ + + _children_names: List[str] = ["postprocessor"] + + def __init__( + self, + model_path: str, + vocab: str, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(url=model_path) + self.vocab = vocab + self.cfg = cfg + self.postprocessor = CRNNPostProcessor(self.vocab) + + def __call__( + self, + x: np.ndarray, + return_model_output: bool = False, + ) -> Dict[str, Any]: + logits = self.session.run(x) + + out: Dict[str, Any] = {} + if return_model_output: + out["out_map"] = logits + + # Post-process + out["preds"] = self.postprocessor(logits) + + return out + + +def _crnn( + arch: str, + model_path: str, + **kwargs: Any, +) -> CRNN: + kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"]) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs["vocab"] + _cfg["input_shape"] = kwargs["input_shape"] + + # Build the model + return CRNN(model_path, cfg=_cfg, **kwargs) + + +def crnn_vgg16_bn(model_path: str = default_cfgs["crnn_vgg16_bn"], **kwargs: Any) -> CRNN: + """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + >>> import numpy as np + >>> from onnxtr.models import crnn_vgg16_bn + >>> model = crnn_vgg16_bn() + >>> input_tensor = np.random.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn("crnn_vgg16_bn", model_path, **kwargs) + + +def crnn_mobilenet_v3_small(model_path: str = default_cfgs["crnn_mobilenet_v3_small"], **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. 
+ + >>> import numpy as np + >>> from onnxtr.models import crnn_mobilenet_v3_small + >>> model = crnn_mobilenet_v3_small() + >>> input_tensor = np.random.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn("crnn_mobilenet_v3_small", model_path, **kwargs) + + +def crnn_mobilenet_v3_large(model_path: str = default_cfgs["crnn_mobilenet_v3_large"], **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + >>> import numpy as np + >>> from onnxtr.models import crnn_mobilenet_v3_large + >>> model = crnn_mobilenet_v3_large() + >>> input_tensor = np.random.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn("crnn_mobilenet_v3_large", model_path, **kwargs) diff --git a/onnxtr/models/recognition/models/master.py b/onnxtr/models/recognition/models/master.py new file mode 100644 index 0000000..76c30ef --- /dev/null +++ b/onnxtr/models/recognition/models/master.py @@ -0,0 +1,144 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from scipy.special import softmax + +from onnxtr.utils import VOCABS + +from ...engine import Engine +from ..core import RecognitionPostProcessor + +__all__ = ["MASTER", "master"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "master": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/master-fde31e4a.pt&src=0", + }, +} + + +class MASTER(Engine): + """MASTER Onnx loader + + Args: + ---- + model_path: path or url to onnx model file + vocab: vocabulary, (without EOS, SOS, PAD) + cfg: dictionary containing information about the model + """ + + def __init__( + self, + model_path: str, + vocab: str, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(url=model_path) + + self.vocab = vocab + self.cfg = cfg + self.postprocessor = MASTERPostProcessor(vocab=self.vocab) + + def __call__( + self, + x: np.ndarray, + return_model_output: bool = False, + ) -> Dict[str, Any]: + """Call function + + Args: + ---- + x: images + return_model_output: if True, return logits + + Returns: + ------- + A dictionnary containing eventually logits and predictions. 
+ """ + logits = self.session.run(x) + out: Dict[str, Any] = {} + + if return_model_output: + out["out_map"] = logits + + out["preds"] = self.postprocessor(logits) + + return out + + +class MASTERPostProcessor(RecognitionPostProcessor): + """Post-processor for the MASTER model + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + super().__init__(vocab) + self._embedding = list(vocab) + [""] + [""] + [""] + + def __call__(self, logits: np.ndarray) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = np.argmax(logits, axis=-1) + # N x L + probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1) + # Take the minimum confidence of the sequence + probs = np.min(probs, axis=1) + + word_values = [ + "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] for encoded_seq in out_idxs + ] + + return list(zip(word_values, np.clip(probs, 0, 1).tolist())) + + +def _master( + arch: str, + model_path: str, + **kwargs: Any, +) -> MASTER: + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + return MASTER(model_path, cfg=_cfg, **kwargs) + + +def master(model_path: str = default_cfgs["master"], **kwargs: Any) -> MASTER: + """MASTER as described in paper: `_. + + >>> import numpy as np + >>> from onnxtr.models import master + >>> model = master() + >>> input_tensor = np.random.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keywoard arguments passed to the MASTER architecture + + Returns: + ------- + text recognition architecture + """ + return _master("master", model_path, **kwargs) diff --git a/onnxtr/models/recognition/models/parseq.py b/onnxtr/models/recognition/models/parseq.py new file mode 100644 index 0000000..95729d4 --- /dev/null +++ b/onnxtr/models/recognition/models/parseq.py @@ -0,0 +1,130 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +from copy import deepcopy +from typing import Any, Dict, Optional + +import numpy as np +from scipy.special import softmax + +from onnxtr.utils import VOCABS + +from ...engine import Engine +from ..core import RecognitionPostProcessor + +__all__ = ["PARSeq", "parseq"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "parseq": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/parseq-56125471.pt&src=0", + }, +} + + +class PARSeq(Engine): + """PARSeq Onnx loader + + Args: + ---- + vocab: vocabulary used for encoding + cfg: dictionary containing information about the model + """ + + def __init__( + self, + model_path: str, + vocab: str, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(url=model_path) + self.vocab = vocab + self.cfg = cfg + self.postprocessor = PARSeqPostProcessor(vocab=self.vocab) + + def __call__( + self, + x: np.ndarray, + return_model_output: bool = False, + ) -> Dict[str, Any]: + logits = self.session.run(x) + out: Dict[str, Any] = {} + + if return_model_output: + out["out_map"] = logits + + out["preds"] = self.postprocessor(logits) + return out + + +class PARSeqPostProcessor(RecognitionPostProcessor): + """Post processor for PARSeq architecture + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + super().__init__(vocab) + self._embedding = list(vocab) + ["", "", ""] + + def __call__(self, logits): + # compute pred with argmax for attention models + out_idxs = np.argmax(logits, axis=-1) + preds_prob = softmax(logits, axis=-1).max(axis=-1) + + word_values = [ + "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] for encoded_seq in out_idxs + ] + # compute probabilties for each word up to the EOS token + probs = [preds_prob[i, : len(word)].clip(0, 1).mean() if word else 0.0 for i, word in enumerate(word_values)] + + return list(zip(word_values, probs)) + + +def _parseq( + arch: str, + model_path: str, + **kwargs: Any, +) -> PARSeq: + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Build the model + return PARSeq(model_path, cfg=_cfg, **kwargs) + + +def parseq(model_path: str = default_cfgs["parseq"], **kwargs: Any) -> PARSeq: + """PARSeq architecture from + `"Scene Text Recognition with Permuted Autoregressive Sequence Models" `_. + + >>> import numpy as np + >>> from onnxtr.models import parseq + >>> model = parseq() + >>> input_tensor = np.random.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the PARSeq architecture + + Returns: + ------- + text recognition architecture + """ + return _parseq("parseq", model_path, **kwargs) diff --git a/onnxtr/models/recognition/models/sar.py b/onnxtr/models/recognition/models/sar.py new file mode 100644 index 0000000..ed537e9 --- /dev/null +++ b/onnxtr/models/recognition/models/sar.py @@ -0,0 +1,133 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
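PARSeq above (and ViTSTR, which follows) differs from that scheme only in the confidence aggregation: instead of the minimum over the whole sequence, the best-class probabilities are averaged over the characters kept before the end-of-sequence position. A tiny numeric sketch:

import numpy as np

preds_prob = np.array([[0.90, 0.80, 0.95, 0.20, 0.10]])  # N x L best-class probabilities
word_values = ["ab"]  # decoded words, already cut at the end-of-sequence token
probs = [
    preds_prob[i, : len(word)].clip(0, 1).mean() if word else 0.0
    for i, word in enumerate(word_values)
]
# probs[0] == (0.90 + 0.80) / 2 == 0.85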
+ +from copy import deepcopy +from typing import Any, Dict, Optional + +import numpy as np +from scipy.special import softmax + +from onnxtr.utils import VOCABS + +from ...engine import Engine +from ..core import RecognitionPostProcessor + +__all__ = ["SAR", "sar_resnet31"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "sar_resnet31": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0", + }, +} + + +class SAR(Engine): + """SAR Onnx loader + + Args: + ---- + model_path: path to onnx model file + vocab: vocabulary used for encoding + cfg: dictionary containing information about the model + """ + + def __init__( + self, + model_path: str, + vocab: str, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(url=model_path) + self.vocab = vocab + self.cfg = cfg + self.postprocessor = SARPostProcessor(self.vocab) + + def __call__( + self, + x: np.ndarray, + return_model_output: bool = False, + ) -> Dict[str, Any]: + logits = self.session.run(x) + + out: Dict[str, Any] = {} + if return_model_output: + out["out_map"] = logits + + out["preds"] = self.postprocessor(logits) + + return out + + +class SARPostProcessor(RecognitionPostProcessor): + """Post processor for SAR architectures + + Args: + ---- + embedding: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + super().__init__(vocab) + self._embedding = list(self.vocab) + [""] + + def __call__(self, logits): + # compute pred with argmax for attention models + out_idxs = np.argmax(logits, axis=-1) + # N x L + probs = np.take_along_axis(softmax(logits, axis=-1), out_idxs[..., None], axis=-1).squeeze(-1) + # Take the minimum confidence of the sequence + probs = np.min(probs, axis=1) + + word_values = [ + "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] for encoded_seq in out_idxs + ] + + return list(zip(word_values, np.clip(probs, 0, 1).tolist())) + + +def _sar( + arch: str, + model_path: str, + **kwargs: Any, +) -> SAR: + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Build the model + return SAR(model_path, cfg=_cfg, **kwargs) + + +def sar_resnet31(model_path: str = default_cfgs["sar_resnet31"], **kwargs: Any) -> SAR: + """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong + Baseline for Irregular Text Recognition" `_. + + >>> import numpy as np + >>> from onnxtr.models import sar_resnet31 + >>> model = sar_resnet31() + >>> input_tensor = np.random.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + **kwargs: keyword arguments of the SAR architecture + + Returns: + ------- + text recognition architecture + """ + return _sar("sar_resnet31", model_path, **kwargs) diff --git a/onnxtr/models/recognition/models/vitstr.py b/onnxtr/models/recognition/models/vitstr.py new file mode 100644 index 0000000..7b08234 --- /dev/null +++ b/onnxtr/models/recognition/models/vitstr.py @@ -0,0 +1,162 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. 
+# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Dict, Optional + +import numpy as np +from scipy.special import softmax + +from onnxtr.utils import VOCABS + +from ...engine import Engine +from ..core import RecognitionPostProcessor + +__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "vitstr_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/vitstr_small-fcd12655.pt&src=0", + }, + "vitstr_base": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/vitstr_base-50b21df2.pt&src=0", + }, +} + + +class ViTSTR(Engine): + """ViTSTR Onnx loader + + Args: + ---- + model_path: path to onnx model file + vocab: vocabulary used for encoding + cfg: dictionary containing information about the model + """ + + def __init__( + self, + model_path: str, + vocab: str, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(url=model_path) + self.vocab = vocab + self.cfg = cfg + + self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab) + + def __call__( + self, + x: np.ndarray, + return_model_output: bool = False, + ) -> Dict[str, Any]: + logits = self.session.run(x) + + out: Dict[str, Any] = {} + if return_model_output: + out["out_map"] = logits + + out["preds"] = self.postprocessor(logits) + + return out + + +class ViTSTRPostProcessor(RecognitionPostProcessor): + """Post processor for ViTSTR architecture + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + super().__init__(vocab) + self._embedding = list(vocab) + ["", ""] + + def __call__(self, logits): + # compute pred with argmax for attention models + out_idxs = np.argmax(logits, axis=-1) + preds_prob = softmax(logits, axis=-1).max(axis=-1) + + word_values = [ + "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] for encoded_seq in out_idxs + ] + # compute probabilties for each word up to the EOS token + probs = [preds_prob[i, : len(word)].clip(0, 1).mean() if word else 0.0 for i, word in enumerate(word_values)] + + return list(zip(word_values, probs)) + + +def _vitstr( + arch: str, + model_path: str, + **kwargs: Any, +) -> ViTSTR: + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Build the model + return ViTSTR(model_path, cfg=_cfg, **kwargs) + + +def vitstr_small(model_path: str = default_cfgs["vitstr_small"], **kwargs: Any) -> ViTSTR: + """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" + `_. 
+ + >>> import numpy as np + >>> from onnxtr.models import vitstr_small + >>> model = vitstr_small() + >>> input_tensor = np.random.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + kwargs: keyword arguments of the ViTSTR architecture + + Returns: + ------- + text recognition architecture + """ + return _vitstr("vitstr_small", model_path**kwargs) + + +def vitstr_base(model_path: str = default_cfgs["vitstr_base"], **kwargs: Any) -> ViTSTR: + """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" + `_. + + >>> import numpy as np + >>> from onnxtr.models import vitstr_base + >>> model = vitstr_base() + >>> input_tensor = np.random.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + model_path: path to onnx model file, defaults to url in default_cfgs + kwargs: keyword arguments of the ViTSTR architecture + + Returns: + ------- + text recognition architecture + """ + return _vitstr("vitstr_base", model_path, **kwargs) diff --git a/onnxtr/models/recognition/predictor/__init__.py b/onnxtr/models/recognition/predictor/__init__.py new file mode 100644 index 0000000..9b5ed21 --- /dev/null +++ b/onnxtr/models/recognition/predictor/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/onnxtr/models/recognition/predictor/_utils.py b/onnxtr/models/recognition/predictor/_utils.py new file mode 100644 index 0000000..4998556 --- /dev/null +++ b/onnxtr/models/recognition/predictor/_utils.py @@ -0,0 +1,86 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Tuple, Union + +import numpy as np + +from ..utils import merge_multi_strings + +__all__ = ["split_crops", "remap_preds"] + + +def split_crops( + crops: List[np.ndarray], + max_ratio: float, + target_ratio: int, + dilation: float, + channels_last: bool = True, +) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]: + """Chunk crops horizontally to match a given aspect ratio + + Args: + ---- + crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise + max_ratio: the maximum aspect ratio that won't trigger the chunk + target_ratio: when crops are chunked, they will be chunked to match this aspect ratio + dilation: the width dilation of final chunks (to provide some overlaps) + channels_last: whether the numpy array has dimensions in channels last order + + Returns: + ------- + a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required + """ + _remap_required = False + crop_map: List[Union[int, Tuple[int, int]]] = [] + new_crops: List[np.ndarray] = [] + for crop in crops: + h, w = crop.shape[:2] if channels_last else crop.shape[-2:] + aspect_ratio = w / h + if aspect_ratio > max_ratio: + # Determine the number of crops, reference aspect ratio = 4 = 128 / 32 + num_subcrops = int(aspect_ratio // target_ratio) + # Find the new widths, additional dilation factor to overlap crops + width = dilation * w / num_subcrops + centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)] + # Get the crops + if channels_last: + _crops = [ + crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :] + for center in centers + ] + else: + _crops = [ + crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 
2)))] + for center in centers + ] + # Avoid sending zero-sized crops + _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)] + # Record the slice of crops + crop_map.append((len(new_crops), len(new_crops) + len(_crops))) + new_crops.extend(_crops) + # At least one crop will require merging + _remap_required = True + else: + crop_map.append(len(new_crops)) + new_crops.append(crop) + + return new_crops, crop_map, _remap_required + + +def remap_preds( + preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float +) -> List[Tuple[str, float]]: + remapped_out = [] + for _idx in crop_map: + # Crop hasn't been split + if isinstance(_idx, int): + remapped_out.append(preds[_idx]) + else: + # unzip + vals, probs = zip(*preds[_idx[0] : _idx[1]]) + # Merge the string values + remapped_out.append((merge_multi_strings(vals, dilation), min(probs))) # type: ignore[arg-type] + return remapped_out diff --git a/onnxtr/models/recognition/predictor/base.py b/onnxtr/models/recognition/predictor/base.py new file mode 100644 index 0000000..b9838d3 --- /dev/null +++ b/onnxtr/models/recognition/predictor/base.py @@ -0,0 +1,80 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List, Sequence, Tuple + +import numpy as np + +from onnxtr.models.engine import Engine +from onnxtr.models.preprocessor import PreProcessor +from onnxtr.utils.repr import NestedObject + +from ._utils import remap_preds, split_crops + +__all__ = ["RecognitionPredictor"] + + +class RecognitionPredictor(NestedObject): + """Implements an object able to identify character sequences in images + + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core recognition architecture + split_wide_crops: wether to use crop splitting for high aspect ratio crops + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: Engine, + split_wide_crops: bool = True, + ) -> None: + super().__init__() + self.pre_processor = pre_processor + self.model = model + self.split_wide_crops = split_wide_crops + self.critical_ar = 8 # Critical aspect ratio + self.dil_factor = 1.4 # Dilation factor to overlap the crops + self.target_ar = 6 # Target aspect ratio + + def __call__( + self, + crops: Sequence[np.ndarray], + **kwargs: Any, + ) -> List[Tuple[str, float]]: + if len(crops) == 0: + return [] + # Dimension check + if any(crop.ndim != 3 for crop in crops): + raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") + + # Split crops that are too wide + remapped = False + if self.split_wide_crops: + new_crops, crop_map, remapped = split_crops( + crops, # type: ignore[arg-type] + self.critical_ar, + self.target_ar, + self.dil_factor, + True, + ) + if remapped: + crops = new_crops + + # Resize & batch them + processed_batches = self.pre_processor(crops) + + # Forward it + raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches] + + # Process outputs + out = [charseq for batch in raw for charseq in batch] + + # Remap crops + if self.split_wide_crops and remapped: + out = remap_preds(out, crop_map, self.dil_factor) + + return out diff --git a/onnxtr/models/recognition/utils.py b/onnxtr/models/recognition/utils.py new file mode 100644 index 0000000..99b6307 --- /dev/null +++ b/onnxtr/models/recognition/utils.py @@ -0,0 +1,89 @@ +# Copyright (C) 2021-2024, Mindee | 
Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+from typing import List
+
+from rapidfuzz.distance import Levenshtein
+
+__all__ = ["merge_strings", "merge_multi_strings"]
+
+
+def merge_strings(a: str, b: str, dil_factor: float) -> str:
+    """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
+
+    Args:
+    ----
+        a: first char seq, suffix should be similar to b's prefix.
+        b: second char seq, prefix should be similar to a's suffix.
+        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
+            only used when the mother sequence is split on a character repetition
+
+    Returns:
+    -------
+        A merged character sequence.
+
+    Example::
+        >>> from onnxtr.models.recognition.utils import merge_strings
+        >>> merge_strings('abcd', 'cdefgh', 1.4)
+        'abcdefgh'
+        >>> merge_strings('abcdi', 'cdefgh', 1.4)
+        'abcdefgh'
+    """
+    seq_len = min(len(a), len(b))
+    if seq_len == 0:  # One sequence is empty, return the other
+        return b if len(a) == 0 else a
+
+    # Initialize merging index and corresponding score (mean Levenshtein)
+    min_score, index = 1.0, 0  # No overlap, just concatenate
+
+    scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
+
+    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
+    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
+        # Compute n_overlap (number of overlapping chars, geometrically determined)
+        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
+        # Find the number of consecutive zeros in the scores list
+        # Impossible to have a zero after a non-zero score in that case
+        n_zeros = sum(val == 0 for val in scores)
+        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
+        min_score, index = 0, min(n_zeros, n_overlap)
+
+    else:  # Common case: choose the min score index
+        for i, score in enumerate(scores):
+            if score < min_score:
+                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
+
+    # Merge with correct overlap
+    if index == 0:
+        return a + b
+    return a[:-1] + b[index - 1 :]
+
+
+def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
+    """Recursively merges consecutive string sequences with overlapping characters.
+
+    Args:
+    ----
+        seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
+        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
+            only used when the mother sequence is split on a character repetition
+
+    Returns:
+    -------
+        A merged character sequence
+
+    Example::
+        >>> from onnxtr.models.recognition.utils import merge_multi_strings
+        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
+        'abcdefghijkl'
+    """
+
+    def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str:
+        # Recursive version of compute_overlap
+        if len(seq_list) == 1:
+            return merge_strings(a, seq_list[0], dil_factor)
+        return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor)
+
+    return _recursive_merge("", seq_list, dil_factor)
diff --git a/onnxtr/models/recognition/zoo.py b/onnxtr/models/recognition/zoo.py
new file mode 100644
index 0000000..98f0f21
--- /dev/null
+++ b/onnxtr/models/recognition/zoo.py
@@ -0,0 +1,69 @@
+# Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details. + +from typing import Any, List + +from onnxtr.models.preprocessor import PreProcessor + +from .. import recognition +from .predictor import RecognitionPredictor + +__all__ = ["recognition_predictor"] + + +ARCHS: List[str] = [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + "master", + "vitstr_small", + "vitstr_base", + "parseq", +] + + +def _predictor(arch: Any, **kwargs: Any) -> RecognitionPredictor: + if isinstance(arch, str): + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + _model = recognition.__dict__[arch]() + else: + if not isinstance( + arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq) + ): + raise ValueError(f"unknown architecture: {type(arch)}") + _model = arch + + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) + kwargs["std"] = kwargs.get("std", _model.cfg["std"]) + kwargs["batch_size"] = kwargs.get("batch_size", 128) + input_shape = _model.cfg["input_shape"][:2] + predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model) + + return predictor + + +def recognition_predictor(arch: Any = "crnn_vgg16_bn", **kwargs: Any) -> RecognitionPredictor: + """Text recognition architecture. + + Example:: + >>> import numpy as np + >>> from onnxtr.models import recognition_predictor + >>> model = recognition_predictor() + >>> input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn') + **kwargs: optional parameters to be passed to the architecture + + Returns: + ------- + Recognition predictor + """ + return _predictor(arch, **kwargs) diff --git a/onnxtr/models/zoo.py b/onnxtr/models/zoo.py new file mode 100644 index 0000000..c2d5ad6 --- /dev/null +++ b/onnxtr/models/zoo.py @@ -0,0 +1,114 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
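+
+# This module exposes the end-to-end `ocr_predictor` factory: it chains a detection predictor
+# (word localization, batched by `det_bs` pages) with a recognition predictor (transcription of
+# the detected crops, batched by `reco_bs` crops) inside a single `OCRPredictor` object.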
+ +from typing import Any + +from .detection.zoo import detection_predictor +from .predictor import OCRPredictor +from .recognition.zoo import recognition_predictor + +__all__ = ["ocr_predictor"] + + +def _predictor( + det_arch: Any, + reco_arch: Any, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + det_bs: int = 2, + reco_bs: int = 128, + detect_orientation: bool = False, + straighten_pages: bool = False, + detect_language: bool = False, + **kwargs, +) -> OCRPredictor: + # Detection + det_predictor = detection_predictor( + det_arch, + batch_size=det_bs, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + ) + + # Recognition + reco_predictor = recognition_predictor( + reco_arch, + batch_size=reco_bs, + ) + + return OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + detect_orientation=detect_orientation, + straighten_pages=straighten_pages, + detect_language=detect_language, + **kwargs, + ) + + +def ocr_predictor( + det_arch: Any = "db_resnet50", + reco_arch: Any = "crnn_vgg16_bn", + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + export_as_straight_boxes: bool = False, + detect_orientation: bool = False, + straighten_pages: bool = False, + detect_language: bool = False, + **kwargs: Any, +) -> OCRPredictor: + """End-to-end OCR architecture using one model for localization, and another for text recognition. + + >>> import numpy as np + >>> from onnxtr.models import ocr_predictor + >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn') + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + det_arch: name of the detection architecture or the model itself to use + (e.g. 'db_resnet50', 'db_mobilenet_v3_large') + reco_arch: name of the recognition architecture or the model itself to use + (e.g. 'crnn_vgg16_bn', 'sar_resnet31') + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before + running the detection model on it. + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right. + export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions + (potentially rotated) as straight bounding boxes. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + straighten_pages: if True, estimates the page general orientation + based on the segmentation map median line orientation. + Then, rotates page before passing it again to the deep learning detection module. + Doing so will improve performances for documents with page-uniform rotations. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. 
+ kwargs: keyword args of `OCRPredictor` + + Returns: + ------- + OCR predictor + """ + return _predictor( + det_arch, + reco_arch, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + export_as_straight_boxes=export_as_straight_boxes, + detect_orientation=detect_orientation, + straighten_pages=straighten_pages, + detect_language=detect_language, + **kwargs, + ) diff --git a/onnxtr/transforms/__init__.py b/onnxtr/transforms/__init__.py new file mode 100644 index 0000000..9b5ed21 --- /dev/null +++ b/onnxtr/transforms/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/onnxtr/transforms/base.py b/onnxtr/transforms/base.py new file mode 100644 index 0000000..854ff9f --- /dev/null +++ b/onnxtr/transforms/base.py @@ -0,0 +1,102 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from typing import Tuple, Union + +import cv2 +import numpy as np + +__all__ = ["Resize", "Normalize"] + + +class Resize: + """Resize the input image to the given size""" + + def __init__( + self, + size: Union[int, Tuple[int, int]], + interpolation=cv2.INTER_LINEAR, + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = False, + ) -> None: + super().__init__() + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + self.interpolation = interpolation + + if not isinstance(size, (int, tuple, list)): + raise AssertionError("size should be either a tuple, a list or an int") + self.size = size + + def __call__( + self, + img: np.ndarray, + ) -> np.ndarray: + if isinstance(self.size, int): + target_ratio = img.shape[1] / img.shape[0] + else: + target_ratio = self.size[0] / self.size[1] + actual_ratio = img.shape[1] / img.shape[0] + + # Resize + if isinstance(self.size, (tuple, list)): + if actual_ratio > target_ratio: + tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1)) + else: + tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1]) + elif isinstance(self.size, int): # self.size is the longest side, infer the other + if img.shape[0] <= img.shape[1]: + tmp_size = (max(int(self.size * actual_ratio), 1), self.size) + else: + tmp_size = (self.size, max(int(self.size / actual_ratio), 1)) + + # Scale image + img = cv2.resize(img, tmp_size, interpolation=self.interpolation) + + if isinstance(self.size, (tuple, list)): + # Pad + _pad = (0, self.size[1] - img.shape[0], 0, self.size[0] - img.shape[1]) + if self.symmetric_pad: + half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) + _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) + img = np.pad(img, ((_pad[0], _pad[1]), (_pad[2], _pad[3]), (0, 0)), mode="constant") + + return img + + def __repr__(self) -> str: + interpolate_str = self.interpolation.value + _repr = f"output_size={self.size}, interpolation='{interpolate_str}'" + if self.preserve_aspect_ratio: + _repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}" + return f"{self.__class__.__name__}({_repr})" + + +class Normalize: + """Normalize the input image""" + + def __init__( + self, + mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406), + std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225), + ) -> None: + self.mean = mean + self.std = std + + if not isinstance(self.mean, (float, tuple, list)): + raise AssertionError("mean 
should be either a tuple, a list or a float") + if not isinstance(self.std, (float, tuple, list)): + raise AssertionError("std should be either a tuple, a list or a float") + + def __call__( + self, + img: np.ndarray, + ) -> np.ndarray: + # Normalize image + return (img - self.mean) / self.std + + def __repr__(self) -> str: + _repr = f"mean={self.mean}, std={self.std}" + return f"{self.__class__.__name__}({_repr})" diff --git a/onnxtr/utils/__init__.py b/onnxtr/utils/__init__.py new file mode 100644 index 0000000..f3aa32a --- /dev/null +++ b/onnxtr/utils/__init__.py @@ -0,0 +1,4 @@ +from .common_types import * +from .data import * +from .geometry import * +from .vocabs import * diff --git a/onnxtr/utils/common_types.py b/onnxtr/utils/common_types.py new file mode 100644 index 0000000..eb9a28a --- /dev/null +++ b/onnxtr/utils/common_types.py @@ -0,0 +1,18 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from pathlib import Path +from typing import List, Tuple, Union + +__all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox"] + + +Point2D = Tuple[float, float] +BoundingBox = Tuple[Point2D, Point2D] +Polygon4P = Tuple[Point2D, Point2D, Point2D, Point2D] +Polygon = List[Point2D] +AbstractPath = Union[str, Path] +AbstractFile = Union[AbstractPath, bytes] +Bbox = Tuple[float, float, float, float] diff --git a/onnxtr/utils/data.py b/onnxtr/utils/data.py new file mode 100644 index 0000000..80e3646 --- /dev/null +++ b/onnxtr/utils/data.py @@ -0,0 +1,126 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py + +import hashlib +import logging +import os +import re +import urllib +import urllib.error +import urllib.request +from pathlib import Path +from typing import Optional, Union + +from tqdm.auto import tqdm + +__all__ = ["download_from_url"] + + +# matches bfd8deac from resnet18-bfd8deac.ckpt +HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") +USER_AGENT = "felixdittrich92/OnnxTR" + + +def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool: + with open(file_path, "rb") as f: + sha_hash = hashlib.sha256(f.read()).hexdigest() + + return sha_hash[: len(hash_prefix)] == hash_prefix + + +def download_from_url( + url: str, + file_name: Optional[str] = None, + hash_prefix: Optional[str] = None, + cache_dir: Optional[str] = None, + cache_subdir: Optional[str] = None, +) -> Path: + """Download a file using its URL + + >>> from onnxtr.models import download_from_url + >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip") + + Args: + ---- + url: the URL of the file to download + file_name: optional name of the file once downloaded + hash_prefix: optional expected SHA256 hash of the file + cache_dir: cache directory + cache_subdir: subfolder to use in the cache + + Returns: + ------- + the location of the downloaded file + + Note: + ---- + 
You can change cache directory location by using `ONNXTR_CACHE_DIR` environment variable. + """ + if not isinstance(file_name, str): + file_name = url.rpartition("/")[-1].split("&")[0] + + cache_dir = ( + str(os.environ.get("ONNXTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "onnxtr"))) + if cache_dir is None + else cache_dir + ) + + # Check hash in file name + if hash_prefix is None: + r = HASH_REGEX.search(file_name) + hash_prefix = r.group(1) if r else None + + folder_path = Path(cache_dir) if cache_subdir is None else Path(cache_dir, cache_subdir) + file_path = folder_path.joinpath(file_name) + # Check file existence + if file_path.is_file() and (hash_prefix is None or _check_integrity(file_path, hash_prefix)): + logging.info(f"Using downloaded & verified file: {file_path}") + return file_path + + try: + # Create folder hierarchy + folder_path.mkdir(parents=True, exist_ok=True) + except OSError: + error_message = f"Failed creating cache direcotry at {folder_path}" + if os.environ.get("ONNXTR_CACHE_DIR", ""): + error_message += " using path from 'ONNXTR_CACHE_DIR' environment variable." + else: + error_message += ( + ". You can change default cache directory using 'ONNXTR_CACHE_DIR' environment variable if needed." + ) + logging.error(error_message) + raise + # Download the file + try: + print(f"Downloading {url} to {file_path}") + _urlretrieve(url, file_path) + except (urllib.error.URLError, IOError) as e: + if url[:5] == "https": + url = url.replace("https:", "http:") + print("Failed download. Trying https -> http instead." f" Downloading {url} to {file_path}") + _urlretrieve(url, file_path) + else: + raise e + + # Remove corrupted files + if isinstance(hash_prefix, str) and not _check_integrity(file_path, hash_prefix): + # Remove file + os.remove(file_path) + raise ValueError(f"corrupted download, the hash of {url} does not match its expected value") + + return file_path diff --git a/onnxtr/utils/fonts.py b/onnxtr/utils/fonts.py new file mode 100644 index 0000000..c3e86c1 --- /dev/null +++ b/onnxtr/utils/fonts.py @@ -0,0 +1,41 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import logging +import platform +from typing import Optional + +from PIL import ImageFont + +__all__ = ["get_font"] + + +def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont: + """Resolves a compatible ImageFont for the system + + Args: + ---- + font_family: the font family to use + font_size: the size of the font upon rendering + + Returns: + ------- + the Pillow font + """ + # Font selection + if font_family is None: + try: + font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size) + except OSError: + font = ImageFont.load_default() + logging.warning( + "unable to load recommended font family. Loading default PIL font," + "font size issues may be expected." + "To prevent this, it is recommended to specify the value of 'font_family'." + ) + else: + font = ImageFont.truetype(font_family, font_size) + + return font diff --git a/onnxtr/utils/geometry.py b/onnxtr/utils/geometry.py new file mode 100644 index 0000000..9c77de1 --- /dev/null +++ b/onnxtr/utils/geometry.py @@ -0,0 +1,456 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
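+
+# Conventions used below: images are numpy arrays in (H, W, C) order unless stated otherwise,
+# straight boxes are (xmin, ymin, xmax, ymax), polygons are (4, 2) point arrays, and "relative"
+# coordinates are expressed in [0, 1] with respect to the page height and width.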
+ +from copy import deepcopy +from math import ceil +from typing import List, Optional, Tuple, Union + +import cv2 +import numpy as np + +from .common_types import BoundingBox, Polygon4P + +__all__ = [ + "bbox_to_polygon", + "polygon_to_bbox", + "resolve_enclosing_bbox", + "resolve_enclosing_rbbox", + "rotate_boxes", + "compute_expanded_shape", + "rotate_image", + "estimate_page_angle", + "convert_to_relative_coords", + "rotate_abs_geoms", + "extract_crops", + "extract_rcrops", +] + + +def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P: + """Convert a bounding box to a polygon + + Args: + ---- + bbox: a bounding box + + Returns: + ------- + a polygon + """ + return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1] + + +def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox: + """Convert a polygon to a bounding box + + Args: + ---- + polygon: a polygon + + Returns: + ------- + a bounding box + """ + x, y = zip(*polygon) + return (min(x), min(y)), (max(x), max(y)) + + +def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Union[BoundingBox, np.ndarray]: + """Compute enclosing bbox either from: + + Args: + ---- + bboxes: boxes in one of the following formats: + + - an array of boxes: (*, 5), where boxes have this shape: + (xmin, ymin, xmax, ymax, score) + + - a list of BoundingBox + + Returns: + ------- + a (1, 5) array (enclosing boxarray), or a BoundingBox + """ + if isinstance(bboxes, np.ndarray): + xmin, ymin, xmax, ymax, score = np.split(bboxes, 5, axis=1) + return np.array([xmin.min(), ymin.min(), xmax.max(), ymax.max(), score.mean()]) + else: + x, y = zip(*[point for box in bboxes for point in box]) + return (min(x), min(y)), (max(x), max(y)) + + +def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024) -> np.ndarray: + """Compute enclosing rotated bbox either from: + + Args: + ---- + rbboxes: boxes in one of the following formats: + + - an array of boxes: (*, 5), where boxes have this shape: + (xmin, ymin, xmax, ymax, score) + + - a list of BoundingBox + intermed_size: size of the intermediate image + + Returns: + ------- + a (1, 5) array (enclosing boxarray), or a BoundingBox + """ + cloud: np.ndarray = np.concatenate(rbboxes, axis=0) + # Convert to absolute for minAreaRect + cloud *= intermed_size + rect = cv2.minAreaRect(cloud.astype(np.int32)) + return cv2.boxPoints(rect) / intermed_size # type: ignore[operator] + + +def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray: + """Rotate points counter-clockwise. 
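+
+    Illustrative example (output shown up to floating point error):
+    >>> import numpy as np
+    >>> rotate_abs_points(np.array([[10.0, 0.0]]), 90)  # ~ array([[0., 10.]])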
+ + Args: + ---- + points: array of size (N, 2) + angle: angle between -90 and +90 degrees + + Returns: + ------- + Rotated points + """ + angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions + rotation_mat = np.array( + [[np.cos(angle_rad), -np.sin(angle_rad)], [np.sin(angle_rad), np.cos(angle_rad)]], dtype=points.dtype + ) + return np.matmul(points, rotation_mat.T) + + +def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[int, int]: + """Compute the shape of an expanded rotated image + + Args: + ---- + img_shape: the height and width of the image + angle: angle between -90 and +90 degrees + + Returns: + ------- + the height and width of the rotated image + """ + points: np.ndarray = np.array([ + [img_shape[1] / 2, img_shape[0] / 2], + [-img_shape[1] / 2, img_shape[0] / 2], + ]) + + rotated_points = rotate_abs_points(points, angle) + + wh_shape = 2 * np.abs(rotated_points).max(axis=0) + return wh_shape[1], wh_shape[0] + + +def rotate_abs_geoms( + geoms: np.ndarray, + angle: float, + img_shape: Tuple[int, int], + expand: bool = True, +) -> np.ndarray: + """Rotate a batch of bounding boxes or polygons by an angle around the + image center. + + Args: + ---- + geoms: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes + angle: anti-clockwise rotation angle in degrees + img_shape: the height and width of the image + expand: whether the image should be padded to avoid information loss + + Returns: + ------- + A batch of rotated polygons (N, 4, 2) + """ + # Switch to polygons + polys = ( + np.stack([geoms[:, [0, 1]], geoms[:, [2, 1]], geoms[:, [2, 3]], geoms[:, [0, 3]]], axis=1) + if geoms.ndim == 2 + else geoms + ) + polys = polys.astype(np.float32) + + # Switch to image center as referential + polys[..., 0] -= img_shape[1] / 2 + polys[..., 1] = img_shape[0] / 2 - polys[..., 1] + + # Rotated them around image center + rotated_polys = rotate_abs_points(polys.reshape(-1, 2), angle).reshape(-1, 4, 2) + # Switch back to top-left corner as referential + target_shape = compute_expanded_shape(img_shape, angle) if expand else img_shape + # Clip coords to fit since there is no expansion + rotated_polys[..., 0] = (rotated_polys[..., 0] + target_shape[1] / 2).clip(0, target_shape[1]) + rotated_polys[..., 1] = (target_shape[0] / 2 - rotated_polys[..., 1]).clip(0, target_shape[0]) + + return rotated_polys + + +def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape: Tuple[int, int]) -> np.ndarray: + """Remaps a batch of rotated locpred (N, 4, 2) expressed for an origin_shape to a destination_shape. + This does not impact the absolute shape of the boxes, but allow to calculate the new relative RotatedBbox + coordinates after a resizing of the image. 
+ + Args: + ---- + loc_preds: (N, 4, 2) array of RELATIVE loc_preds + orig_shape: shape of the origin image + dest_shape: shape of the destination image + + Returns: + ------- + A batch of rotated loc_preds (N, 4, 2) expressed in the destination referencial + """ + if len(dest_shape) != 2: + raise ValueError(f"Mask length should be 2, was found at: {len(dest_shape)}") + if len(orig_shape) != 2: + raise ValueError(f"Image_shape length should be 2, was found at: {len(orig_shape)}") + orig_height, orig_width = orig_shape + dest_height, dest_width = dest_shape + mboxes = loc_preds.copy() + mboxes[:, :, 0] = ((loc_preds[:, :, 0] * orig_width) + (dest_width - orig_width) / 2) / dest_width + mboxes[:, :, 1] = ((loc_preds[:, :, 1] * orig_height) + (dest_height - orig_height) / 2) / dest_height + + return mboxes + + +def rotate_boxes( + loc_preds: np.ndarray, + angle: float, + orig_shape: Tuple[int, int], + min_angle: float = 1.0, + target_shape: Optional[Tuple[int, int]] = None, +) -> np.ndarray: + """Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax, c) or rotated bounding boxes + (4, 2) of an angle, if angle > min_angle, around the center of the page. + If target_shape is specified, the boxes are remapped to the target shape after the rotation. This + is done to remove the padding that is created by rotate_page(expand=True) + + Args: + ---- + loc_preds: (N, 5) or (N, 4, 2) array of RELATIVE boxes + angle: angle between -90 and +90 degrees + orig_shape: shape of the origin image + min_angle: minimum angle to rotate boxes + target_shape: shape of the destination image + + Returns: + ------- + A batch of rotated boxes (N, 4, 2): or a batch of straight bounding boxes + """ + # Change format of the boxes to rotated boxes + _boxes = loc_preds.copy() + if _boxes.ndim == 2: + _boxes = np.stack( + [ + _boxes[:, [0, 1]], + _boxes[:, [2, 1]], + _boxes[:, [2, 3]], + _boxes[:, [0, 3]], + ], + axis=1, + ) + # If small angle, return boxes (no rotation) + if abs(angle) < min_angle or abs(angle) > 90 - min_angle: + return _boxes + # Compute rotation matrix + angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions + rotation_mat = np.array( + [[np.cos(angle_rad), -np.sin(angle_rad)], [np.sin(angle_rad), np.cos(angle_rad)]], dtype=_boxes.dtype + ) + # Rotate absolute points + points: np.ndarray = np.stack((_boxes[:, :, 0] * orig_shape[1], _boxes[:, :, 1] * orig_shape[0]), axis=-1) + image_center = (orig_shape[1] / 2, orig_shape[0] / 2) + rotated_points = image_center + np.matmul(points - image_center, rotation_mat) + rotated_boxes: np.ndarray = np.stack( + (rotated_points[:, :, 0] / orig_shape[1], rotated_points[:, :, 1] / orig_shape[0]), axis=-1 + ) + + # Apply a mask if requested + if target_shape is not None: + rotated_boxes = remap_boxes(rotated_boxes, orig_shape=orig_shape, dest_shape=target_shape) + + return rotated_boxes + + +def rotate_image( + image: np.ndarray, + angle: float, + expand: bool = False, + preserve_origin_shape: bool = False, +) -> np.ndarray: + """Rotate an image counterclockwise by an given angle. + + Args: + ---- + image: numpy tensor to rotate + angle: rotation angle in degrees, between -90 and +90 + expand: whether the image should be padded before the rotation + preserve_origin_shape: if expand is set to True, resizes the final output to the original image size + + Returns: + ------- + Rotated array, padded by 0 by default. 
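+
+    Example (illustrative, with a dummy image):
+    >>> import numpy as np
+    >>> rotated = rotate_image(np.ones((32, 64, 3), dtype=np.uint8), 30.0, expand=True)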
+ """ + # Compute the expanded padding + exp_img: np.ndarray + if expand: + exp_shape = compute_expanded_shape(image.shape[:2], angle) # type: ignore[arg-type] + h_pad, w_pad = ( + int(max(0, ceil(exp_shape[0] - image.shape[0]))), + int(max(0, ceil(exp_shape[1] - image.shape[1]))), + ) + exp_img = np.pad(image, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) + else: + exp_img = image + + height, width = exp_img.shape[:2] + rot_mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0) + rot_img = cv2.warpAffine(exp_img, rot_mat, (width, height)) + if expand: + # Pad to get the same aspect ratio + if (image.shape[0] / image.shape[1]) != (rot_img.shape[0] / rot_img.shape[1]): + # Pad width + if (rot_img.shape[0] / rot_img.shape[1]) > (image.shape[0] / image.shape[1]): + h_pad, w_pad = 0, int(rot_img.shape[0] * image.shape[1] / image.shape[0] - rot_img.shape[1]) + # Pad height + else: + h_pad, w_pad = int(rot_img.shape[1] * image.shape[0] / image.shape[1] - rot_img.shape[0]), 0 + rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) + if preserve_origin_shape: + # rescale + rot_img = cv2.resize(rot_img, image.shape[:-1][::-1], interpolation=cv2.INTER_LINEAR) + + return rot_img + + +def estimate_page_angle(polys: np.ndarray) -> float: + """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the + estimated angle ccw in degrees + """ + # Compute mean left points and mean right point with respect to the reading direction (oriented polygon) + xleft = polys[:, 0, 0] + polys[:, 3, 0] + yleft = polys[:, 0, 1] + polys[:, 3, 1] + xright = polys[:, 1, 0] + polys[:, 2, 0] + yright = polys[:, 1, 1] + polys[:, 2, 1] + with np.errstate(divide="raise", invalid="raise"): + try: + return float( + np.median(np.arctan((yleft - yright) / (xright - xleft)) * 180 / np.pi) # Y axis from top to bottom! 
+ ) + except FloatingPointError: + return 0.0 + + +def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray: + """Convert a geometry to relative coordinates + + Args: + ---- + geoms: a set of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4) + img_shape: the height and width of the image + + Returns: + ------- + the updated geometry + """ + # Polygon + if geoms.ndim == 3 and geoms.shape[1:] == (4, 2): + polygons: np.ndarray = np.empty(geoms.shape, dtype=np.float32) + polygons[..., 0] = geoms[..., 0] / img_shape[1] + polygons[..., 1] = geoms[..., 1] / img_shape[0] + return polygons.clip(0, 1) + if geoms.ndim == 2 and geoms.shape[1] == 4: + boxes: np.ndarray = np.empty(geoms.shape, dtype=np.float32) + boxes[:, ::2] = geoms[:, ::2] / img_shape[1] + boxes[:, 1::2] = geoms[:, 1::2] / img_shape[0] + return boxes.clip(0, 1) + + raise ValueError(f"invalid format for arg `geoms`: {geoms.shape}") + + +def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]: + """Created cropped images from list of bounding boxes + + Args: + ---- + img: input image + boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative + coordinates (xmin, ymin, xmax, ymax) + channels_last: whether the channel dimensions is the last one instead of the last one + + Returns: + ------- + list of cropped images + """ + if boxes.shape[0] == 0: + return [] + if boxes.shape[1] != 4: + raise AssertionError("boxes are expected to be relative and in order (xmin, ymin, xmax, ymax)") + + # Project relative coordinates + _boxes = boxes.copy() + h, w = img.shape[:2] if channels_last else img.shape[-2:] + if not np.issubdtype(_boxes.dtype, np.integer): + _boxes[:, [0, 2]] *= w + _boxes[:, [1, 3]] *= h + _boxes = _boxes.round().astype(int) + # Add last index + _boxes[2:] += 1 + if channels_last: + return deepcopy([img[box[1] : box[3], box[0] : box[2]] for box in _boxes]) + + return deepcopy([img[:, box[1] : box[3], box[0] : box[2]] for box in _boxes]) + + +def extract_rcrops( + img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True +) -> List[np.ndarray]: + """Created cropped images from list of rotated bounding boxes + + Args: + ---- + img: input image + polys: bounding boxes of shape (N, 4, 2) + dtype: target data type of bounding boxes + channels_last: whether the channel dimensions is the last one instead of the last one + + Returns: + ------- + list of cropped images + """ + if polys.shape[0] == 0: + return [] + if polys.shape[1:] != (4, 2): + raise AssertionError("polys are expected to be quadrilateral, of shape (N, 4, 2)") + + # Project relative coordinates + _boxes = polys.copy() + height, width = img.shape[:2] if channels_last else img.shape[-2:] + if not np.issubdtype(_boxes.dtype, np.integer): + _boxes[:, :, 0] *= width + _boxes[:, :, 1] *= height + + src_pts = _boxes[:, :3].astype(np.float32) + # Preserve size + d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1) + d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1) + # (N, 3, 2) + dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype) + dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1 + dst_pts[:, 2, 1] = d2 - 1 + # Use a warp transformation to extract the crop + crops = [ + cv2.warpAffine( + img if channels_last else img.transpose(1, 2, 0), + # Transformation matrix + cv2.getAffineTransform(src_pts[idx], dst_pts[idx]), + (int(d1[idx]), int(d2[idx])), + ) + for idx in range(_boxes.shape[0]) + ] + return crops diff --git 
a/onnxtr/utils/multithreading.py b/onnxtr/utils/multithreading.py new file mode 100644 index 0000000..adb5adb --- /dev/null +++ b/onnxtr/utils/multithreading.py @@ -0,0 +1,50 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +import multiprocessing as mp +import os +from multiprocessing.pool import ThreadPool +from typing import Any, Callable, Iterable, Iterator, Optional + +from onnxtr.file_utils import ENV_VARS_TRUE_VALUES + +__all__ = ["multithread_exec"] + + +def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Optional[int] = None) -> Iterator[Any]: + """Execute a given function in parallel for each element of a given sequence + + >>> from onnxtr.utils.multithreading import multithread_exec + >>> entries = [1, 4, 8] + >>> results = multithread_exec(lambda x: x ** 2, entries) + + Args: + ---- + func: function to be executed on each element of the iterable + seq: iterable + threads: number of workers to be used for multiprocessing + + Returns: + ------- + iterator of the function's results using the iterable as inputs + + Notes: + ----- + This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory. + If you do not have write permissions for this directory (if you run `onnxtr` on AWS Lambda for instance), + you might want to disable multiprocessing. To achieve that, set 'ONNXTR_MULTIPROCESSING_DISABLE' to 'TRUE'. + """ + threads = threads if isinstance(threads, int) else min(16, mp.cpu_count()) + # Single-thread + if threads < 2 or os.environ.get("ONNXTR_MULTIPROCESSING_DISABLE", "").upper() in ENV_VARS_TRUE_VALUES: + results = map(func, seq) + # Multi-threading + else: + with ThreadPool(threads) as tp: + # ThreadPool's map function returns a list, but seq could be of a different type + # That's why wrapping result in map to return iterator + results = map(lambda x: x, tp.map(func, seq)) # noqa: C417 + return results diff --git a/onnxtr/utils/reconstitution.py b/onnxtr/utils/reconstitution.py new file mode 100644 index 0000000..f66b136 --- /dev/null +++ b/onnxtr/utils/reconstitution.py @@ -0,0 +1,126 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. +from typing import Any, Dict, Optional + +import numpy as np +from anyascii import anyascii +from PIL import Image, ImageDraw + +from .fonts import get_font + +__all__ = ["synthesize_page", "synthesize_kie_page"] + + +def synthesize_page( + page: Dict[str, Any], + draw_proba: bool = False, + font_family: Optional[str] = None, +) -> np.ndarray: + """Draw a the content of the element page (OCR response) on a blank page. + + Args: + ---- + page: exported Page object to represent + draw_proba: if True, draw words in colors to represent confidence. 
Blue: p=1, red: p=0 + font_size: size of the font, default font = 13 + font_family: family of the font + + Returns: + ------- + the synthesized page + """ + # Draw template + h, w = page["dimensions"] + response = 255 * np.ones((h, w, 3), dtype=np.int32) + + # Draw each word + for block in page["blocks"]: + for line in block["lines"]: + for word in line["words"]: + # Get absolute word geometry + (xmin, ymin), (xmax, ymax) = word["geometry"] + xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) + ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) + + # White drawing context adapted to font size, 0.75 factor to convert pts --> pix + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) + d = ImageDraw.Draw(img) + # Draw in black the value of the word + try: + d.text((0, 0), word["value"], font=font, fill=(0, 0, 0)) + except UnicodeEncodeError: + # When character cannot be encoded, use its anyascii version + d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0)) + + # Colorize if draw_proba + if draw_proba: + p = int(255 * word["confidence"]) + mask = np.where(np.array(img) == 0, 1, 0) + proba: np.ndarray = np.array([255 - p, 0, p]) + color = mask * proba[np.newaxis, np.newaxis, :] + white_mask = 255 * (1 - mask) + img = color + white_mask + + # Write to response page + response[ymin:ymax, xmin:xmax, :] = np.array(img) + + return response + + +def synthesize_kie_page( + page: Dict[str, Any], + draw_proba: bool = False, + font_family: Optional[str] = None, +) -> np.ndarray: + """Draw a the content of the element page (OCR response) on a blank page. + + Args: + ---- + page: exported Page object to represent + draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 + font_size: size of the font, default font = 13 + font_family: family of the font + + Returns: + ------- + the synthesized page + """ + # Draw template + h, w = page["dimensions"] + response = 255 * np.ones((h, w, 3), dtype=np.int32) + + # Draw each word + for predictions in page["predictions"].values(): + for prediction in predictions: + # Get aboslute word geometry + (xmin, ymin), (xmax, ymax) = prediction["geometry"] + xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) + ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) + + # White drawing context adapted to font size, 0.75 factor to convert pts --> pix + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) + d = ImageDraw.Draw(img) + # Draw in black the value of the word + try: + d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0)) + except UnicodeEncodeError: + # When character cannot be encoded, use its anyascii version + d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0)) + + # Colorize if draw_proba + if draw_proba: + p = int(255 * prediction["confidence"]) + mask = np.where(np.array(img) == 0, 1, 0) + proba: np.ndarray = np.array([255 - p, 0, p]) + color = mask * proba[np.newaxis, np.newaxis, :] + white_mask = 255 * (1 - mask) + img = color + white_mask + + # Write to response page + response[ymin:ymax, xmin:xmax, :] = np.array(img) + + return response diff --git a/onnxtr/utils/repr.py b/onnxtr/utils/repr.py new file mode 100644 index 0000000..775fdd6 --- /dev/null +++ b/onnxtr/utils/repr.py @@ -0,0 +1,64 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. 
+# See LICENSE or go to for full license details. + +# Adapted from https://github.com/pytorch/torch/blob/master/torch/nn/modules/module.py + +from typing import List + +__all__ = ["NestedObject"] + + +def _addindent(s_, num_spaces): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +class NestedObject: + """Base class for all nested objects in onnxtr""" + + _children_names: List[str] + + def extra_repr(self) -> str: + return "" + + def __repr__(self): + # We treat the extra repr like the sub-object, one item per line + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + if hasattr(self, "_children_names"): + for key in self._children_names: + child = getattr(self, key) + if isinstance(child, list) and len(child) > 0: + child_str = ",\n".join([repr(subchild) for subchild in child]) + if len(child) > 1: + child_str = _addindent(f"\n{child_str},", 2) + "\n" + child_str = f"[{child_str}]" + else: + child_str = repr(child) + child_str = _addindent(child_str, 2) + child_lines.append("(" + key + "): " + child_str) + lines = extra_lines + child_lines + + main_str = self.__class__.__name__ + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str diff --git a/onnxtr/utils/visualization.py b/onnxtr/utils/visualization.py new file mode 100644 index 0000000..6c4cbfb --- /dev/null +++ b/onnxtr/utils/visualization.py @@ -0,0 +1,388 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
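+
+# Note: these helpers rely on the optional visualization dependencies (`matplotlib`, and
+# `mplcursors` for interactive mode), provided by the `viz` extra declared in pyproject.toml.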
+import colorsys +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +import cv2 +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.figure import Figure + +from .common_types import BoundingBox, Polygon4P + +__all__ = ["visualize_page", "visualize_kie_page", "draw_boxes"] + + +def rect_patch( + geometry: BoundingBox, + page_dimensions: Tuple[int, int], + label: Optional[str] = None, + color: Tuple[float, float, float] = (0, 0, 0), + alpha: float = 0.3, + linewidth: int = 2, + fill: bool = True, + preserve_aspect_ratio: bool = False, +) -> patches.Rectangle: + """Create a matplotlib rectangular patch for the element + + Args: + ---- + geometry: bounding box of the element + page_dimensions: dimensions of the Page in format (height, width) + label: label to display when hovered + color: color to draw box + alpha: opacity parameter to fill the boxes, 0 = transparent + linewidth: line width + fill: whether the patch should be filled + preserve_aspect_ratio: pass True if you passed True to the predictor + + Returns: + ------- + a rectangular Patch + """ + if len(geometry) != 2 or any(not isinstance(elt, tuple) or len(elt) != 2 for elt in geometry): + raise ValueError("invalid geometry format") + + # Unpack + height, width = page_dimensions + (xmin, ymin), (xmax, ymax) = geometry + # Switch to absolute coords + if preserve_aspect_ratio: + width = height = max(height, width) + xmin, w = xmin * width, (xmax - xmin) * width + ymin, h = ymin * height, (ymax - ymin) * height + + return patches.Rectangle( + (xmin, ymin), + w, + h, + fill=fill, + linewidth=linewidth, + edgecolor=(*color, alpha), + facecolor=(*color, alpha), + label=label, + ) + + +def polygon_patch( + geometry: np.ndarray, + page_dimensions: Tuple[int, int], + label: Optional[str] = None, + color: Tuple[float, float, float] = (0, 0, 0), + alpha: float = 0.3, + linewidth: int = 2, + fill: bool = True, + preserve_aspect_ratio: bool = False, +) -> patches.Polygon: + """Create a matplotlib polygon patch for the element + + Args: + ---- + geometry: bounding box of the element + page_dimensions: dimensions of the Page in format (height, width) + label: label to display when hovered + color: color to draw box + alpha: opacity parameter to fill the boxes, 0 = transparent + linewidth: line width + fill: whether the patch should be filled + preserve_aspect_ratio: pass True if you passed True to the predictor + + Returns: + ------- + a polygon Patch + """ + if not geometry.shape == (4, 2): + raise ValueError("invalid geometry format") + + # Unpack + height, width = page_dimensions + geometry[:, 0] = geometry[:, 0] * (max(width, height) if preserve_aspect_ratio else width) + geometry[:, 1] = geometry[:, 1] * (max(width, height) if preserve_aspect_ratio else height) + + return patches.Polygon( + geometry, + fill=fill, + linewidth=linewidth, + edgecolor=(*color, alpha), + facecolor=(*color, alpha), + label=label, + ) + + +def create_obj_patch( + geometry: Union[BoundingBox, Polygon4P, np.ndarray], + page_dimensions: Tuple[int, int], + **kwargs: Any, +) -> patches.Patch: + """Create a matplotlib patch for the element + + Args: + ---- + geometry: bounding box (straight or rotated) of the element + page_dimensions: dimensions of the page in format (height, width) + **kwargs: keyword arguments for the patch + + Returns: + ------- + a matplotlib Patch + """ + if isinstance(geometry, tuple): + if len(geometry) == 2: # straight word BB (2 pts) + return 
rect_patch(geometry, page_dimensions, **kwargs) + elif len(geometry) == 4: # rotated word BB (4 pts) + return polygon_patch(np.asarray(geometry), page_dimensions, **kwargs) + elif isinstance(geometry, np.ndarray) and geometry.shape == (4, 2): # rotated line + return polygon_patch(geometry, page_dimensions, **kwargs) + raise ValueError("invalid geometry format") + + +def get_colors(num_colors: int) -> List[Tuple[float, float, float]]: + """Generate num_colors color for matplotlib + + Args: + ---- + num_colors: number of colors to generate + + Returns: + ------- + colors: list of generated colors + """ + colors = [] + for i in np.arange(0.0, 360.0, 360.0 / num_colors): + hue = i / 360.0 + lightness = (50 + np.random.rand() * 10) / 100.0 + saturation = (90 + np.random.rand() * 10) / 100.0 + colors.append(colorsys.hls_to_rgb(hue, lightness, saturation)) + return colors + + +def visualize_page( + page: Dict[str, Any], + image: np.ndarray, + words_only: bool = True, + display_artefacts: bool = True, + scale: float = 10, + interactive: bool = True, + add_labels: bool = True, + **kwargs: Any, +) -> Figure: + """Visualize a full page with predicted blocks, lines and words + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from onnxtr.utils.visualization import visualize_page + >>> from onnxtr.models import ocr_db_crnn + >>> model = ocr_db_crnn(pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([[input_page]]) + >>> visualize_page(out[0].pages[0].export(), input_page) + >>> plt.show() + + Args: + ---- + page: the exported Page of a Document + image: np array of the page, needs to have the same shape than page['dimensions'] + words_only: whether only words should be displayed + display_artefacts: whether artefacts should be displayed + scale: figsize of the largest windows side + interactive: whether the plot should be interactive + add_labels: for static plot, adds text labels on top of bounding box + **kwargs: keyword arguments for the polygon patch + + Returns: + ------- + the matplotlib figure + """ + # Get proper scale and aspect ratio + h, w = image.shape[:2] + size = (scale * w / h, scale) if h > w else (scale, h / w * scale) + fig, ax = plt.subplots(figsize=size) + # Display the image + ax.imshow(image) + # hide both axis + ax.axis("off") + + if interactive: + artists: List[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + + for block in page["blocks"]: + if not words_only: + rect = create_obj_patch( + block["geometry"], page["dimensions"], label="block", color=(0, 1, 0), linewidth=1, **kwargs + ) + # add patch on figure + ax.add_patch(rect) + if interactive: + # add patch to cursor's artists + artists.append(rect) + + for line in block["lines"]: + if not words_only: + rect = create_obj_patch( + line["geometry"], page["dimensions"], label="line", color=(1, 0, 0), linewidth=1, **kwargs + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + + for word in line["words"]: + rect = create_obj_patch( + word["geometry"], + page["dimensions"], + label=f"{word['value']} (confidence: {word['confidence']:.2%})", + color=(0, 0, 1), + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + elif add_labels: + if len(word["geometry"]) == 5: + text_loc = ( + int(page["dimensions"][1] * (word["geometry"][0] - word["geometry"][2] / 2)), + int(page["dimensions"][0] * (word["geometry"][1] - word["geometry"][3] / 2)), + ) + else: + text_loc = ( + 
int(page["dimensions"][1] * word["geometry"][0][0]), + int(page["dimensions"][0] * word["geometry"][0][1]), + ) + + if len(word["geometry"]) == 2: + # We draw only if boxes are in straight format + ax.text( + *text_loc, + word["value"], + size=10, + alpha=0.5, + color=(0, 0, 1), + ) + + if display_artefacts: + for artefact in block["artefacts"]: + rect = create_obj_patch( + artefact["geometry"], + page["dimensions"], + label="artefact", + color=(0.5, 0.5, 0.5), + linewidth=1, + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + + if interactive: + import mplcursors + + # Create mlp Cursor to hover patches in artists + mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label())) + fig.tight_layout(pad=0.0) + + return fig + + +def visualize_kie_page( + page: Dict[str, Any], + image: np.ndarray, + words_only: bool = False, + display_artefacts: bool = True, + scale: float = 10, + interactive: bool = True, + add_labels: bool = True, + **kwargs: Any, +) -> Figure: + """Visualize a full page with predicted blocks, lines and words + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from onnxtr.utils.visualization import visualize_page + >>> from onnxtr.models import ocr_db_crnn + >>> model = ocr_db_crnn(pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([[input_page]]) + >>> visualize_kie_page(out[0].pages[0].export(), input_page) + >>> plt.show() + + Args: + ---- + page: the exported Page of a Document + image: np array of the page, needs to have the same shape than page['dimensions'] + words_only: whether only words should be displayed + display_artefacts: whether artefacts should be displayed + scale: figsize of the largest windows side + interactive: whether the plot should be interactive + add_labels: for static plot, adds text labels on top of bounding box + **kwargs: keyword arguments for the polygon patch + + Returns: + ------- + the matplotlib figure + """ + # Get proper scale and aspect ratio + h, w = image.shape[:2] + size = (scale * w / h, scale) if h > w else (scale, h / w * scale) + fig, ax = plt.subplots(figsize=size) + # Display the image + ax.imshow(image) + # hide both axis + ax.axis("off") + + if interactive: + artists: List[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + + colors = {k: color for color, k in zip(get_colors(len(page["predictions"])), page["predictions"])} + for key, value in page["predictions"].items(): + for prediction in value: + if not words_only: + rect = create_obj_patch( + prediction["geometry"], + page["dimensions"], + label=f"{key} \n {prediction['value']} (confidence: {prediction['confidence']:.2%}", + color=colors[key], + linewidth=1, + **kwargs, + ) + # add patch on figure + ax.add_patch(rect) + if interactive: + # add patch to cursor's artists + artists.append(rect) + + if interactive: + import mplcursors + + # Create mlp Cursor to hover patches in artists + mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label())) + fig.tight_layout(pad=0.0) + + return fig + + +def draw_boxes(boxes: np.ndarray, image: np.ndarray, color: Optional[Tuple[int, int, int]] = None, **kwargs) -> None: + """Draw an array of relative straight boxes on an image + + Args: + ---- + boxes: array of relative boxes, of shape (*, 4) + image: np array, float32 or uint8 + color: color to use for bounding box edges + **kwargs: keyword arguments 
from `matplotlib.pyplot.plot` + """ + h, w = image.shape[:2] + # Convert boxes to absolute coords + _boxes = deepcopy(boxes) + _boxes[:, [0, 2]] *= w + _boxes[:, [1, 3]] *= h + _boxes = _boxes.astype(np.int32) + for box in _boxes.tolist(): + xmin, ymin, xmax, ymax = box + image = cv2.rectangle( + image, (xmin, ymin), (xmax, ymax), color=color if isinstance(color, tuple) else (0, 0, 255), thickness=2 + ) + plt.imshow(image) + plt.plot(**kwargs) diff --git a/onnxtr/utils/vocabs.py b/onnxtr/utils/vocabs.py new file mode 100644 index 0000000..7899e19 --- /dev/null +++ b/onnxtr/utils/vocabs.py @@ -0,0 +1,71 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import string +from typing import Dict + +__all__ = ["VOCABS"] + + +VOCABS: Dict[str, str] = { + "digits": string.digits, + "ascii_letters": string.ascii_letters, + "punctuation": string.punctuation, + "currency": "£€¥¢฿", + "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ", + "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي", + "persian_letters": "پچڢڤگ", + "hindi_digits": "٠١٢٣٤٥٦٧٨٩", + "arabic_diacritics": "ًٌٍَُِّْ", + "arabic_punctuation": "؟؛«»—", +} + +VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"] +VOCABS["english"] = VOCABS["latin"] + "°" + VOCABS["currency"] +VOCABS["legacy_french"] = VOCABS["latin"] + "°" + "àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ" + VOCABS["currency"] +VOCABS["french"] = VOCABS["english"] + "àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ" +VOCABS["portuguese"] = VOCABS["english"] + "áàâãéêíïóôõúüçÁÀÂÃÉÊÍÏÓÔÕÚÜÇ" +VOCABS["spanish"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ" + "¡¿" +VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙÚ" +VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ" +VOCABS["arabic"] = ( + VOCABS["digits"] + + VOCABS["hindi_digits"] + + VOCABS["arabic_letters"] + + VOCABS["persian_letters"] + + VOCABS["arabic_diacritics"] + + VOCABS["arabic_punctuation"] + + VOCABS["punctuation"] +) +VOCABS["czech"] = VOCABS["english"] + "áčďéěíňóřšťúůýžÁČĎÉĚÍŇÓŘŠŤÚŮÝŽ" +VOCABS["polish"] = VOCABS["english"] + "ąćęłńóśźżĄĆĘŁŃÓŚŹŻ" +VOCABS["dutch"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ" +VOCABS["norwegian"] = VOCABS["english"] + "æøåÆØÅ" +VOCABS["danish"] = VOCABS["english"] + "æøåÆØÅ" +VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ" +VOCABS["swedish"] = VOCABS["english"] + "åäöÅÄÖ" +VOCABS["vietnamese"] = ( + VOCABS["english"] + + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ" + + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ" +) +VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪" +VOCABS["multilingual"] = "".join( + dict.fromkeys( + VOCABS["french"] + + VOCABS["portuguese"] + + VOCABS["spanish"] + + VOCABS["german"] + + VOCABS["czech"] + + VOCABS["polish"] + + VOCABS["dutch"] + + VOCABS["italian"] + + VOCABS["norwegian"] + + VOCABS["danish"] + + VOCABS["finnish"] + + VOCABS["swedish"] + + "§" + ) +) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e2ae960 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,181 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "onnxtr" +description = "Onnx Text Recognition (OnnxTR): Deep Learning for high-performance OCR on documents." 
+authors = [{name = "Felix Dittrich", email = "felixdittrich92@gmail.com"}] +maintainers = [ + {name = "Felix Dittrich"}, +] +readme = "README.md" +requires-python = ">=3.9.0,<4" +license = {file = "LICENSE"} +keywords=["OCR", "deep learning", "computer vision", "onnx", "text detection", "text recognition"] +classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dynamic = ["version"] +dependencies = [ + # For proper typing, mypy needs numpy>=1.20.0 (cf. https://github.com/numpy/numpy/pull/16515) + # Additional typing support is brought by numpy>=1.22.4, but core build sticks to >=1.16.0 + "numpy>=1.16.0,<2.0.0", + "scipy>=1.4.0,<2.0.0", + "onnx>=1.12.0,<2.0.0", + "onnxruntime>=1.11.0", + "opencv-python>=4.5.0,<5.0.0", + "pypdfium2>=4.0.0,<5.0.0", + "pyclipper>=1.2.0,<2.0.0", + "shapely>=1.6.0,<3.0.0", + "rapidfuzz>=3.0.0,<4.0.0", + "langdetect>=1.0.9,<2.0.0", + "Pillow>=9.2.0", + "defusedxml>=0.7.0", + "anyascii>=0.3.2", + "tqdm>=4.30.0", +] + +[project.optional-dependencies] +gpu = [ + "onnxruntime-gpu>=1.11.0", +] +html = [ + "weasyprint>=55.0", +] +viz = [ + "matplotlib>=3.1.0", + "mplcursors>=0.3", +] +testing = [ + "pytest>=5.3.2", + "coverage[toml]>=4.5.4", + "requests>=2.20.0", +] +quality = [ + "ruff>=0.1.5", + "mypy>=0.812", + "pre-commit>=2.17.0", +] +docs = [ + "sphinx>=3.0.0,!=3.5.0", + "sphinxemoji>=0.1.8", + "sphinx-copybutton>=0.3.1", + "docutils<0.22", + "recommonmark>=0.7.1", + "sphinx-markdown-tables>=0.0.15", + "sphinx-tabs>=3.3.0", + "furo>=2022.3.4", +] +dev = [ + # HTML + "weasyprint>=55.0", + # Visualization + "matplotlib>=3.1.0", + "mplcursors>=0.3", + # Testing + "pytest>=5.3.2", + "coverage[toml]>=4.5.4", + "requests>=2.20.0", + # Quality + "ruff>=0.1.5", + "mypy>=0.812", + "pre-commit>=2.17.0", + # Documentation + "sphinx>=3.0.0,!=3.5.0", + "sphinxemoji>=0.1.8", + "sphinx-copybutton>=0.3.1", + "docutils<0.22", + "recommonmark>=0.7.1", + "sphinx-markdown-tables>=0.0.15", + "sphinx-tabs>=3.3.0", + "furo>=2022.3.4", +] + +[project.urls] +#documentation = "https://mindee.github.io/doctr" +repository = "https://github.com/felixdittrich92/OnnxTR" +tracker = "https://github.com/felixdittrich92/OnnxTR/issues" +#changelog = "https://mindee.github.io/doctr/changelog.html" + +[tool.setuptools] +zip-safe = true + +[tool.setuptools.packages.find] +exclude = ["api*", "demo*", "docs*", "tests*"] + +[tool.mypy] +files = "onnxtr/" +show_error_codes = true +pretty = true +warn_unused_ignores = true +warn_redundant_casts = true +no_implicit_optional = true +check_untyped_defs = true +implicit_reexport = false + +[[tool.mypy.overrides]] +module = [ + "onnxruntime.*", + "PIL.*", + "scipy.*", + "cv2.*", + "matplotlib.*", + "numpy.*", + "onnx.*", + "pyclipper.*", + "shapely.*", + "mplcursors.*", + "defusedxml.*", + "weasyprint.*", + "huggingface_hub.*", + "pypdfium2.*", + "langdetect.*", + "rapidfuzz.*", + "anyascii.*", + "tqdm.*", +] +ignore_missing_imports = true + +[tool.ruff] +exclude = [".git", "venv*", "build", "**/__init__.py"] +line-length = 120 +target-version = "py39" +preview=true + +[tool.ruff.lint] 
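+# Enabled rule families: E/W (pycodestyle), F (pyflakes), I (isort), N (pep8-naming), Q (quotes),
+# C4 (comprehensions), T10 (debugger), LOG (logging), plus the pydocstyle (D*) rules listed below.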
+select = [ + # https://docs.astral.sh/ruff/rules/ + "E", "W", "F", "I", "N", "Q", "C4", "T10", "LOG", + "D101", "D103", "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", "D207" # pydocstyle +] +ignore = ["E402", "E203", "F403", "E731", "N812", "N817", "C408"] + +[tool.ruff.lint.isort] +known-first-party = ["onnxtr", "app", "utils"] +known-third-party = ["fastapi", "onnxruntime", "cv2"] + +[tool.ruff.lint.per-file-ignores] +"onnxtr/models/**.py" = ["N806", "F841"] +"onnxtr/datasets/**.py" = ["N806"] +"tests/**.py" = ["D"] +"docs/**.py" = ["D"] +".github/**.py" = ["D"] + + +[tool.ruff.lint.flake8-quotes] +docstring-quotes = "double" + +[tool.coverage.run] +source = ["onnxtr"] diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..31702f8 --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +# Copyright (C) 2021-2024, Mindee | Felix Dittrich. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import os +from pathlib import Path + +from setuptools import setup + +PKG_NAME = "onnxtr" +VERSION = os.getenv("BUILD_VERSION", "0.0.1a0") + + +if __name__ == "__main__": + print(f"Building wheel {PKG_NAME}-{VERSION}") + + # Dynamically set the __version__ attribute + cwd = Path(__file__).parent.absolute() + with open(cwd.joinpath("onnxtr", "version.py"), "w", encoding="utf-8") as f: + f.write(f"__version__ = '{VERSION}'\n") + + setup(name=PKG_NAME, version=VERSION) diff --git a/tests/common/test_contrib.py b/tests/common/test_contrib.py new file mode 100644 index 0000000..e71ebca --- /dev/null +++ b/tests/common/test_contrib.py @@ -0,0 +1,37 @@ +import numpy as np +import pytest + +from onnxtr.contrib import artefacts +from onnxtr.contrib.base import _BasePredictor +from onnxtr.io import DocumentFile + + +def test_base_predictor(): + # check that we need to provide either a url or a model_path + with pytest.raises(ValueError): + _ = _BasePredictor(batch_size=2) + + predictor = _BasePredictor(batch_size=2, url=artefacts.default_cfgs["yolov8_artefact"]["url"]) + # check that we need to implement preprocess and postprocess + with pytest.raises(NotImplementedError): + predictor.preprocess(np.zeros((10, 10, 3))) + with pytest.raises(NotImplementedError): + predictor.postprocess([np.zeros((10, 10, 3))], [[np.zeros((10, 10, 3))]]) + + +def test_artefact_detector(mock_artefact_image_stream): + doc = DocumentFile.from_images([mock_artefact_image_stream]) + detector = artefacts.ArtefactDetector(batch_size=2, conf_threshold=0.5, iou_threshold=0.5) + results = detector(doc) + assert isinstance(results, list) and len(results) == 1 and isinstance(results[0], list) + assert all(isinstance(artefact, dict) for artefact in results[0]) + # check result keys + assert all(key in results[0][0] for key in ["label", "confidence", "box"]) + assert all(len(artefact["box"]) == 4 for artefact in results[0]) + assert all(isinstance(coord, int) for box in results[0] for coord in box["box"]) + assert all(isinstance(artefact["confidence"], float) for artefact in results[0]) + assert all(isinstance(artefact["label"], str) for artefact in results[0]) + # check results for the mock image are 9 artefacts + assert len(results[0]) == 9 + # test visualization non-blocking for tests + detector.show(block=False) diff --git a/tests/common/test_core.py b/tests/common/test_core.py new file mode 100644 index 0000000..04e6ebb --- /dev/null +++ b/tests/common/test_core.py @@ -0,0 +1,14 @@ +import pytest + +import onnxtr +from onnxtr.file_utils import 
requires_package + + +def test_version(): + assert len(onnxtr.__version__.split(".")) == 3 + + +def test_requires_package(): + requires_package("numpy") # availbable + with pytest.raises(ImportError): # not available + requires_package("non_existent_package") diff --git a/tests/common/test_headers.py b/tests/common/test_headers.py new file mode 100644 index 0000000..cf7ecf3 --- /dev/null +++ b/tests/common/test_headers.py @@ -0,0 +1,23 @@ +"""Test for python files copyright headers.""" + +from datetime import datetime +from pathlib import Path + + +def test_copyright_header(): + copyright_header = "".join([ + f"# Copyright (C) {2021}-{datetime.now().year}, Mindee | Felix Dittrich.\n\n", + "# This program is licensed under the Apache License 2.0.\n", + "# See LICENSE or go to for full license details.\n", + ]) + excluded_files = ["__init__.py", "version.py"] + invalid_files = [] + locations = [".github", "onnxtr"] + + for location in locations: + for source_path in Path(__file__).parent.parent.parent.joinpath(location).rglob("*.py"): + if source_path.name not in excluded_files: + source_path_content = source_path.read_text() + if copyright_header not in source_path_content: + invalid_files.append(source_path) + assert len(invalid_files) == 0, f"Invalid copyright header in the following files: {invalid_files}" diff --git a/tests/common/test_io.py b/tests/common/test_io.py new file mode 100644 index 0000000..2e8d2ce --- /dev/null +++ b/tests/common/test_io.py @@ -0,0 +1,99 @@ +from io import BytesIO +from pathlib import Path + +import numpy as np +import pytest +import requests + +from onnxtr import io + + +def _check_doc_content(doc_tensors, num_pages): + # 1 doc of 8 pages + assert len(doc_tensors) == num_pages + assert all(isinstance(page, np.ndarray) for page in doc_tensors) + assert all(page.dtype == np.uint8 for page in doc_tensors) + + +def test_read_pdf(mock_pdf): + doc = io.read_pdf(mock_pdf) + _check_doc_content(doc, 2) + + # Test with Path + doc = io.read_pdf(Path(mock_pdf)) + _check_doc_content(doc, 2) + + with open(mock_pdf, "rb") as f: + doc = io.read_pdf(f.read()) + _check_doc_content(doc, 2) + + # Wrong input type + with pytest.raises(TypeError): + _ = io.read_pdf(123) + + # Wrong path + with pytest.raises(FileNotFoundError): + _ = io.read_pdf("my_imaginary_file.pdf") + + +def test_read_img_as_numpy(tmpdir_factory, mock_pdf): + # Wrong input type + with pytest.raises(TypeError): + _ = io.read_img_as_numpy(123) + + # Non-existing file + with pytest.raises(FileNotFoundError): + io.read_img_as_numpy("my_imaginary_file.jpg") + + # Invalid image + with pytest.raises(ValueError): + io.read_img_as_numpy(str(mock_pdf)) + + # From path + url = "https://doctr-static.mindee.com/models?id=v0.2.1/Grace_Hopper.jpg&src=0" + file = BytesIO(requests.get(url).content) + tmp_path = str(tmpdir_factory.mktemp("data").join("mock_img_file.jpg")) + with open(tmp_path, "wb") as f: + f.write(file.getbuffer()) + + # Path & stream + with open(tmp_path, "rb") as f: + page_stream = io.read_img_as_numpy(f.read()) + + for page in (io.read_img_as_numpy(tmp_path), page_stream): + # Data type + assert isinstance(page, np.ndarray) + assert page.dtype == np.uint8 + # Shape + assert page.shape == (606, 517, 3) + + # RGB + bgr_page = io.read_img_as_numpy(tmp_path, rgb_output=False) + assert np.all(page == bgr_page[..., ::-1]) + + # Resize + target_size = (200, 150) + resized_page = io.read_img_as_numpy(tmp_path, target_size) + assert resized_page.shape[:2] == target_size + + +def test_read_html(): + url = 
"https://www.google.com" + pdf_stream = io.read_html(url) + assert isinstance(pdf_stream, bytes) + + +def test_document_file(mock_pdf, mock_image_stream): + pages = io.DocumentFile.from_images(mock_image_stream) + _check_doc_content(pages, 1) + + assert isinstance(io.DocumentFile.from_pdf(mock_pdf), list) + assert isinstance(io.DocumentFile.from_url("https://www.google.com"), list) + + +def test_pdf(mock_pdf): + pages = io.DocumentFile.from_pdf(mock_pdf) + + # As images + num_pages = 2 + _check_doc_content(pages, num_pages) diff --git a/tests/common/test_io_elements.py b/tests/common/test_io_elements.py new file mode 100644 index 0000000..9a638d8 --- /dev/null +++ b/tests/common/test_io_elements.py @@ -0,0 +1,283 @@ +from xml.etree.ElementTree import ElementTree + +import numpy as np +import pytest + +from onnxtr.io import elements + + +def _mock_words(size=(1.0, 1.0), offset=(0, 0), confidence=0.9): + return [ + elements.Word( + "hello", + confidence, + ((offset[0], offset[1]), (size[0] / 2 + offset[0], size[1] / 2 + offset[1])), + {"value": 0, "confidence": None}, + ), + elements.Word( + "world", + confidence, + ((size[0] / 2 + offset[0], size[1] / 2 + offset[1]), (size[0] + offset[0], size[1] + offset[1])), + {"value": 0, "confidence": None}, + ), + ] + + +def _mock_artefacts(size=(1, 1), offset=(0, 0), confidence=0.8): + sub_size = (size[0] / 2, size[1] / 2) + return [ + elements.Artefact( + "qr_code", confidence, ((offset[0], offset[1]), (sub_size[0] + offset[0], sub_size[1] + offset[1])) + ), + elements.Artefact( + "qr_code", + confidence, + ((sub_size[0] + offset[0], sub_size[1] + offset[1]), (size[0] + offset[0], size[1] + offset[1])), + ), + ] + + +def _mock_lines(size=(1, 1), offset=(0, 0)): + sub_size = (size[0] / 2, size[1] / 2) + return [ + elements.Line(_mock_words(size=sub_size, offset=offset)), + elements.Line(_mock_words(size=sub_size, offset=(offset[0] + sub_size[0], offset[1] + sub_size[1]))), + ] + + +def _mock_blocks(size=(1, 1), offset=(0, 0)): + sub_size = (size[0] / 4, size[1] / 4) + return [ + elements.Block( + _mock_lines(size=sub_size, offset=offset), + _mock_artefacts(size=sub_size, offset=(offset[0] + sub_size[0], offset[1] + sub_size[1])), + ), + elements.Block( + _mock_lines(size=sub_size, offset=(offset[0] + 2 * sub_size[0], offset[1] + 2 * sub_size[1])), + _mock_artefacts(size=sub_size, offset=(offset[0] + 3 * sub_size[0], offset[1] + 3 * sub_size[1])), + ), + ] + + +def _mock_pages(block_size=(1, 1), block_offset=(0, 0)): + return [ + elements.Page( + np.random.randint(0, 255, (300, 200, 3), dtype=np.uint8), + _mock_blocks(block_size, block_offset), + 0, + (300, 200), + {"value": 0.0, "confidence": 1.0}, + {"value": "EN", "confidence": 0.8}, + ), + elements.Page( + np.random.randint(0, 255, (500, 1000, 3), dtype=np.uint8), + _mock_blocks(block_size, block_offset), + 1, + (500, 1000), + {"value": 0.15, "confidence": 0.8}, + {"value": "FR", "confidence": 0.7}, + ), + ] + + +def test_element(): + with pytest.raises(KeyError): + elements.Element(sub_elements=[1]) + + +def test_word(): + word_str = "hello" + conf = 0.8 + geom = ((0, 0), (1, 1)) + crop_orientation = {"value": 0, "confidence": None} + word = elements.Word(word_str, conf, geom, crop_orientation) + + # Attribute checks + assert word.value == word_str + assert word.confidence == conf + assert word.geometry == geom + assert word.crop_orientation == crop_orientation + + # Render + assert word.render() == word_str + + # Export + assert word.export() == { + "value": word_str, + "confidence": conf, + 
"geometry": geom, + "crop_orientation": crop_orientation, + } + + # Repr + assert word.__repr__() == f"Word(value='hello', confidence={conf:.2})" + + # Class method + state_dict = { + "value": "there", + "confidence": 0.1, + "geometry": ((0, 0), (0.5, 0.5)), + "crop_orientation": crop_orientation, + } + word = elements.Word.from_dict(state_dict) + assert word.export() == state_dict + + +def test_line(): + geom = ((0, 0), (0.5, 0.5)) + words = _mock_words(size=geom[1], offset=geom[0]) + line = elements.Line(words) + + # Attribute checks + assert len(line.words) == len(words) + assert all(isinstance(w, elements.Word) for w in line.words) + assert line.geometry == geom + + # Render + assert line.render() == "hello world" + + # Export + assert line.export() == {"words": [w.export() for w in words], "geometry": geom} + + # Repr + words_str = " " * 4 + ",\n ".join(repr(word) for word in words) + "," + assert line.__repr__() == f"Line(\n (words): [\n{words_str}\n ]\n)" + + # Ensure that words repr does't span on several lines when there are none + assert repr(elements.Line([], ((0, 0), (1, 1)))) == "Line(\n (words): []\n)" + + # from dict + state_dict = { + "words": [ + { + "value": "there", + "confidence": 0.1, + "geometry": ((0, 0), (1.0, 1.0)), + "crop_orientation": {"value": 0, "confidence": None}, + } + ], + "geometry": ((0, 0), (1.0, 1.0)), + } + line = elements.Line.from_dict(state_dict) + assert line.export() == state_dict + + +def test_artefact(): + artefact_type = "qr_code" + conf = 0.8 + geom = ((0, 0), (1, 1)) + artefact = elements.Artefact(artefact_type, conf, geom) + + # Attribute checks + assert artefact.type == artefact_type + assert artefact.confidence == conf + assert artefact.geometry == geom + + # Render + assert artefact.render() == "[QR_CODE]" + + # Export + assert artefact.export() == {"type": artefact_type, "confidence": conf, "geometry": geom} + + # Repr + assert artefact.__repr__() == f"Artefact(type='{artefact_type}', confidence={conf:.2})" + + +def test_block(): + geom = ((0, 0), (1, 1)) + sub_size = (geom[1][0] / 2, geom[1][0] / 2) + lines = _mock_lines(size=sub_size, offset=geom[0]) + artefacts = _mock_artefacts(size=sub_size, offset=sub_size) + block = elements.Block(lines, artefacts) + + # Attribute checks + assert len(block.lines) == len(lines) + assert len(block.artefacts) == len(artefacts) + assert all(isinstance(w, elements.Line) for w in block.lines) + assert all(isinstance(a, elements.Artefact) for a in block.artefacts) + assert block.geometry == geom + + # Render + assert block.render() == "hello world\nhello world" + + # Export + assert block.export() == { + "lines": [line.export() for line in lines], + "artefacts": [artefact.export() for artefact in artefacts], + "geometry": geom, + } + + +def test_page(): + page = np.zeros((300, 200, 3), dtype=np.uint8) + page_idx = 0 + page_size = (300, 200) + orientation = {"value": 0.0, "confidence": 0.0} + language = {"value": "EN", "confidence": 0.8} + blocks = _mock_blocks() + page = elements.Page(page, blocks, page_idx, page_size, orientation, language) + + # Attribute checks + assert len(page.blocks) == len(blocks) + assert all(isinstance(b, elements.Block) for b in page.blocks) + assert isinstance(page.page, np.ndarray) + assert page.page_idx == page_idx + assert page.dimensions == page_size + assert page.orientation == orientation + assert page.language == language + + # Render + assert page.render() == "hello world\nhello world\n\nhello world\nhello world" + + # Export + assert page.export() == { + "blocks": 
[b.export() for b in blocks], + "page_idx": page_idx, + "dimensions": page_size, + "orientation": orientation, + "language": language, + } + + # Export XML + assert ( + isinstance(page.export_as_xml(), tuple) + and isinstance(page.export_as_xml()[0], (bytes, bytearray)) + and isinstance(page.export_as_xml()[1], ElementTree) + ) + + # Repr + assert "\n".join(repr(page).split("\n")[:2]) == f"Page(\n dimensions={page_size!r}" + + # Show + page.show(block=False) + + # Synthesize + img = page.synthesize() + assert isinstance(img, np.ndarray) + assert img.shape == (*page_size, 3) + + +def test_document(): + pages = _mock_pages() + doc = elements.Document(pages) + + # Attribute checks + assert len(doc.pages) == len(pages) + assert all(isinstance(p, elements.Page) for p in doc.pages) + + # Render + page_export = "hello world\nhello world\n\nhello world\nhello world" + assert doc.render() == f"{page_export}\n\n\n\n{page_export}" + + # Export + assert doc.export() == {"pages": [p.export() for p in pages]} + + # Export XML + assert isinstance(doc.export_as_xml(), list) and len(doc.export_as_xml()) == len(pages) + + # Show + doc.show(block=False) + + # Synthesize + img_list = doc.synthesize() + assert isinstance(img_list, list) and len(img_list) == len(pages) diff --git a/tests/common/test_models.py b/tests/common/test_models.py new file mode 100644 index 0000000..be7bfab --- /dev/null +++ b/tests/common/test_models.py @@ -0,0 +1,71 @@ +from io import BytesIO + +import cv2 +import numpy as np +import pytest +import requests + +from onnxtr.io import reader +from onnxtr.models._utils import estimate_orientation, get_language +from onnxtr.utils import geometry + + +@pytest.fixture(scope="function") +def mock_image(tmpdir_factory): + url = "https://doctr-static.mindee.com/models?id=v0.2.1/bitmap30.png&src=0" + file = BytesIO(requests.get(url).content) + tmp_path = str(tmpdir_factory.mktemp("data").join("mock_bitmap.jpg")) + with open(tmp_path, "wb") as f: + f.write(file.getbuffer()) + image = reader.read_img_as_numpy(tmp_path) + return image + + +@pytest.fixture(scope="function") +def mock_bitmap(mock_image): + bitmap = np.squeeze(cv2.cvtColor(mock_image, cv2.COLOR_BGR2GRAY) / 255.0) + bitmap = np.expand_dims(bitmap, axis=-1) + return bitmap + + +def test_estimate_orientation(mock_image, mock_bitmap, mock_tilted_payslip): + assert estimate_orientation(mock_image * 0) == 0 + + # test binarized image + angle = estimate_orientation(mock_bitmap) + assert abs(angle - 30.0) < 1.0 + + angle = estimate_orientation(mock_bitmap * 255) + assert abs(angle - 30.0) < 1.0 + + angle = estimate_orientation(mock_image) + assert abs(angle - 30.0) < 1.0 + + rotated = geometry.rotate_image(mock_image, -angle) + angle_rotated = estimate_orientation(rotated) + assert abs(angle_rotated) < 1.0 + + mock_tilted_payslip = reader.read_img_as_numpy(mock_tilted_payslip) + assert (estimate_orientation(mock_tilted_payslip) - 30.0) < 1.0 + + rotated = geometry.rotate_image(mock_tilted_payslip, -30, expand=True) + angle_rotated = estimate_orientation(rotated) + assert abs(angle_rotated) < 1.0 + + with pytest.raises(AssertionError): + estimate_orientation(np.ones((10, 10, 10))) + + +def test_get_lang(): + sentence = "This is a test sentence." 
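+    # get_language is expected to report "en" with a probability above the 0.99
+    # threshold for a plain English sentence, and to fall back to ("unknown", 0.0)
+    # for inputs that are too short to classify (see the single-character case below).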
+ expected_lang = "en" + threshold_prob = 0.99 + + lang = get_language(sentence) + + assert lang[0] == expected_lang + assert lang[1] > threshold_prob + + lang = get_language("a") + assert lang[0] == "unknown" + assert lang[1] == 0.0 diff --git a/tests/common/test_models_builder.py b/tests/common/test_models_builder.py new file mode 100644 index 0000000..a06dcdb --- /dev/null +++ b/tests/common/test_models_builder.py @@ -0,0 +1,123 @@ +import numpy as np +import pytest + +from onnxtr.io import Document +from onnxtr.models import builder + +words_per_page = 10 + + +def test_documentbuilder(): + num_pages = 2 + + # Don't resolve lines + doc_builder = builder.DocumentBuilder(resolve_lines=False, resolve_blocks=False) + pages = [np.zeros((100, 200, 3))] * num_pages + boxes = np.random.rand(words_per_page, 6) # array format + boxes[:2] *= boxes[2:4] + # Arg consistency check + with pytest.raises(ValueError): + doc_builder( + pages, + [boxes, boxes], + [("hello", 1.0)] * 3, + [(100, 200), (100, 200)], + [{"value": 0, "confidence": None}] * 3, + ) + out = doc_builder( + pages, + [boxes, boxes], + [[("hello", 1.0)] * words_per_page] * num_pages, + [(100, 200), (100, 200)], + [[{"value": 0, "confidence": None}] * words_per_page] * num_pages, + ) + assert isinstance(out, Document) + assert len(out.pages) == num_pages + assert all(isinstance(page.page, np.ndarray) for page in out.pages) and all( + page.page.shape == (100, 200, 3) for page in out.pages + ) + # 1 Block & 1 line per page + assert len(out.pages[0].blocks) == 1 and len(out.pages[0].blocks[0].lines) == 1 + assert len(out.pages[0].blocks[0].lines[0].words) == words_per_page + + # Resolve lines + doc_builder = builder.DocumentBuilder(resolve_lines=True, resolve_blocks=True) + out = doc_builder( + pages, + [boxes, boxes], + [[("hello", 1.0)] * words_per_page] * num_pages, + [(100, 200), (100, 200)], + [[{"value": 0, "confidence": None}] * words_per_page] * num_pages, + ) + + # No detection + boxes = np.zeros((0, 5)) + out = doc_builder(pages, [boxes, boxes], [[], []], [(100, 200), (100, 200)], [[]]) + assert len(out.pages[0].blocks) == 0 + + # Rotated boxes to export as straight boxes + boxes = np.array([ + [[0.1, 0.1], [0.2, 0.2], [0.15, 0.25], [0.05, 0.15]], + [[0.5, 0.5], [0.6, 0.6], [0.55, 0.65], [0.45, 0.55]], + ]) + doc_builder_2 = builder.DocumentBuilder(resolve_blocks=False, resolve_lines=False, export_as_straight_boxes=True) + out = doc_builder_2( + [np.zeros((100, 100, 3))], + [boxes], + [[("hello", 0.99), ("word", 0.99)]], + [(100, 100)], + [[{"value": 0, "confidence": None}] * 2], + ) + assert out.pages[0].blocks[0].lines[0].words[-1].geometry == ((0.45, 0.5), (0.6, 0.65)) + + # Repr + assert ( + repr(doc_builder) == "DocumentBuilder(resolve_lines=True, " + "resolve_blocks=True, paragraph_break=0.035, export_as_straight_boxes=False)" + ) + + +@pytest.mark.parametrize( + "input_boxes, sorted_idxs", + [ + [[[0, 0.5, 0.1, 0.6], [0, 0.3, 0.2, 0.4], [0, 0, 0.1, 0.1]], [2, 1, 0]], # vertical + [[[0.7, 0.5, 0.85, 0.6], [0.2, 0.3, 0.4, 0.4], [0, 0, 0.1, 0.1]], [2, 1, 0]], # diagonal + [[[0, 0.5, 0.1, 0.6], [0.15, 0.5, 0.25, 0.6], [0.5, 0.5, 0.6, 0.6]], [0, 1, 2]], # same line, 2p + [[[0, 0.5, 0.1, 0.6], [0.2, 0.49, 0.35, 0.59], [0.8, 0.52, 0.9, 0.63]], [0, 1, 2]], # ~same line + [[[0, 0.3, 0.4, 0.45], [0.5, 0.28, 0.75, 0.42], [0, 0.45, 0.1, 0.55]], [0, 1, 2]], # 2 lines + [[[0, 0.3, 0.4, 0.35], [0.75, 0.28, 0.95, 0.42], [0, 0.45, 0.1, 0.55]], [0, 1, 2]], # 2 lines + [ + [ + [[0.1, 0.1], [0.2, 0.2], [0.15, 0.25], [0.05, 0.15]], + [[0.5, 
0.5], [0.6, 0.6], [0.55, 0.65], [0.45, 0.55]], + ], + [0, 1], + ], # rot + ], +) +def test_sort_boxes(input_boxes, sorted_idxs): + doc_builder = builder.DocumentBuilder() + assert doc_builder._sort_boxes(np.asarray(input_boxes))[0].tolist() == sorted_idxs + + +@pytest.mark.parametrize( + "input_boxes, lines", + [ + [[[0, 0.5, 0.1, 0.6], [0, 0.3, 0.2, 0.4], [0, 0, 0.1, 0.1]], [[2], [1], [0]]], # vertical + [[[0.7, 0.5, 0.85, 0.6], [0.2, 0.3, 0.4, 0.4], [0, 0, 0.1, 0.1]], [[2], [1], [0]]], # diagonal + [[[0, 0.5, 0.14, 0.6], [0.15, 0.5, 0.25, 0.6], [0.5, 0.5, 0.6, 0.6]], [[0, 1], [2]]], # same line, 2p + [[[0, 0.5, 0.18, 0.6], [0.2, 0.48, 0.35, 0.58], [0.8, 0.52, 0.9, 0.63]], [[0, 1], [2]]], # ~same line + [[[0, 0.3, 0.48, 0.45], [0.5, 0.28, 0.75, 0.42], [0, 0.45, 0.1, 0.55]], [[0, 1], [2]]], # 2 lines + [[[0, 0.3, 0.4, 0.35], [0.75, 0.28, 0.95, 0.42], [0, 0.45, 0.1, 0.55]], [[0], [1], [2]]], # 2 lines + [ + [ + [[0.1, 0.1], [0.2, 0.2], [0.15, 0.25], [0.05, 0.15]], + [[0.5, 0.5], [0.6, 0.6], [0.55, 0.65], [0.45, 0.55]], + ], + [[0], [1]], + ], # rot + ], +) +def test_resolve_lines(input_boxes, lines): + doc_builder = builder.DocumentBuilder() + assert doc_builder._resolve_lines(np.asarray(input_boxes)) == lines diff --git a/tests/common/test_models_detection.py b/tests/common/test_models_detection.py new file mode 100644 index 0000000..1dc8621 --- /dev/null +++ b/tests/common/test_models_detection.py @@ -0,0 +1,60 @@ +import numpy as np +import pytest + +from onnxtr.models.detection.postprocessor.base import GeneralDetectionPostProcessor + + +def test_postprocessor(): + postprocessor = GeneralDetectionPostProcessor(assume_straight_pages=True) + r_postprocessor = GeneralDetectionPostProcessor(assume_straight_pages=False) + with pytest.raises(AssertionError): + postprocessor(np.random.rand(2, 512, 512).astype(np.float32)) + mock_batch = np.random.rand(2, 512, 512, 1).astype(np.float32) + out = postprocessor(mock_batch) + r_out = r_postprocessor(mock_batch) + # Batch composition + assert isinstance(out, list) + assert len(out) == 2 + assert all(isinstance(sample, list) and all(isinstance(v, np.ndarray) for v in sample) for sample in out) + assert all(all(v.shape[1] == 5 for v in sample) for sample in out) + assert all(all(v.shape[1] == 4 and v.shape[2] == 2 for v in sample) for sample in r_out) + # Relative coords + assert all(all(np.all(np.logical_and(v[:, :4] >= 0, v[:, :4] <= 1)) for v in sample) for sample in out) + assert all(all(np.all(np.logical_and(v[:, :4] >= 0, v[:, :4] <= 1)) for v in sample) for sample in r_out) + # Repr + assert repr(postprocessor) == "GeneralDetectionPostProcessor(bin_thresh=0.1, box_thresh=0.1)" + # Edge case when the expanded points of the polygon has two lists + issue_points = np.array( + [ + [869, 561], + [923, 581], + [925, 595], + [915, 583], + [889, 583], + [905, 593], + [882, 601], + [901, 595], + [904, 604], + [876, 608], + [915, 614], + [911, 605], + [925, 601], + [930, 616], + [911, 617], + [900, 636], + [931, 637], + [904, 649], + [932, 649], + [932, 628], + [918, 627], + [934, 624], + [935, 573], + [909, 569], + [934, 562], + ], + dtype=np.int32, + ) + out = postprocessor.polygon_to_box(issue_points) + r_out = r_postprocessor.polygon_to_box(issue_points) + assert isinstance(out, tuple) and len(out) == 4 + assert isinstance(r_out, np.ndarray) and r_out.shape == (4, 2) diff --git a/tests/common/test_models_recognition_predictor.py b/tests/common/test_models_recognition_predictor.py new file mode 100644 index 0000000..734239a --- /dev/null +++ 
b/tests/common/test_models_recognition_predictor.py @@ -0,0 +1,39 @@ +import numpy as np +import pytest + +from onnxtr.models.recognition.predictor._utils import remap_preds, split_crops + + +@pytest.mark.parametrize( + "crops, max_ratio, target_ratio, dilation, channels_last, num_crops", + [ + # No split required + [[np.zeros((32, 128, 3), dtype=np.uint8)], 8, 4, 1.4, True, 1], + [[np.zeros((3, 32, 128), dtype=np.uint8)], 8, 4, 1.4, False, 1], + # Split required + [[np.zeros((32, 1024, 3), dtype=np.uint8)], 8, 6, 1.4, True, 5], + [[np.zeros((3, 32, 1024), dtype=np.uint8)], 8, 6, 1.4, False, 5], + ], +) +def test_split_crops(crops, max_ratio, target_ratio, dilation, channels_last, num_crops): + new_crops, crop_map, should_remap = split_crops(crops, max_ratio, target_ratio, dilation, channels_last) + assert len(new_crops) == num_crops + assert len(crop_map) == len(crops) + assert should_remap == (len(crops) != len(new_crops)) + + +@pytest.mark.parametrize( + "preds, crop_map, dilation, pred", + [ + # Nothing to remap + [[("hello", 0.5)], [0], 1.4, [("hello", 0.5)]], + # Merge + [[("hellowo", 0.5), ("loworld", 0.6)], [(0, 2)], 1.4, [("helloworld", 0.5)]], + ], +) +def test_remap_preds(preds, crop_map, dilation, pred): + preds = remap_preds(preds, crop_map, dilation) + assert len(preds) == len(pred) + assert preds == pred + assert all(isinstance(pred, tuple) for pred in preds) + assert all(isinstance(pred[0], str) and isinstance(pred[1], float) for pred in preds) diff --git a/tests/common/test_models_recognition_utils.py b/tests/common/test_models_recognition_utils.py new file mode 100644 index 0000000..b844d99 --- /dev/null +++ b/tests/common/test_models_recognition_utils.py @@ -0,0 +1,31 @@ +import pytest + +from onnxtr.models.recognition.utils import merge_multi_strings, merge_strings + + +@pytest.mark.parametrize( + "a, b, merged", + [ + ["abc", "def", "abcdef"], + ["abcd", "def", "abcdef"], + ["abcde", "def", "abcdef"], + ["abcdef", "def", "abcdef"], + ["abcccc", "cccccc", "abcccccccc"], + ["abc", "", "abc"], + ["", "abc", "abc"], + ], +) +def test_merge_strings(a, b, merged): + assert merged == merge_strings(a, b, 1.4) + + +@pytest.mark.parametrize( + "seq_list, merged", + [ + [["abc", "def"], "abcdef"], + [["abcd", "def", "efgh", "ijk"], "abcdefghijk"], + [["abcdi", "defk", "efghi", "aijk"], "abcdefghijk"], + ], +) +def test_merge_multi_strings(seq_list, merged): + assert merged == merge_multi_strings(seq_list, 1.4) diff --git a/tests/common/test_transforms.py b/tests/common/test_transforms.py new file mode 100644 index 0000000..4640904 --- /dev/null +++ b/tests/common/test_transforms.py @@ -0,0 +1 @@ +# TODO diff --git a/tests/common/test_utils_data.py b/tests/common/test_utils_data.py new file mode 100644 index 0000000..0000b18 --- /dev/null +++ b/tests/common/test_utils_data.py @@ -0,0 +1,46 @@ +import os +from pathlib import PosixPath +from unittest.mock import patch + +import pytest + +from onnxtr.utils.data import download_from_url + + +@patch("onnxtr.utils.data._urlretrieve") +@patch("pathlib.Path.mkdir") +@patch.dict(os.environ, {"HOME": "/"}, clear=True) +def test_download_from_url(mkdir_mock, urlretrieve_mock): + download_from_url("test_url") + urlretrieve_mock.assert_called_with("test_url", PosixPath("/.cache/onnxtr/test_url")) + + +@patch.dict(os.environ, {"ONNXTR_CACHE_DIR": "/test"}, clear=True) +@patch("onnxtr.utils.data._urlretrieve") +@patch("pathlib.Path.mkdir") +def test_download_from_url_customizing_cache_dir(mkdir_mock, urlretrieve_mock): + 
download_from_url("test_url") + urlretrieve_mock.assert_called_with("test_url", PosixPath("/test/test_url")) + + +@patch.dict(os.environ, {"HOME": "/"}, clear=True) +@patch("pathlib.Path.mkdir", side_effect=OSError) +@patch("logging.error") +def test_download_from_url_error_creating_directory(logging_mock, mkdir_mock): + with pytest.raises(OSError): + download_from_url("test_url") + logging_mock.assert_called_with( + "Failed creating cache direcotry at /.cache/onnxtr." + " You can change default cache directory using 'ONNXTR_CACHE_DIR' environment variable if needed." + ) + + +@patch.dict(os.environ, {"HOME": "/", "ONNXTR_CACHE_DIR": "/test"}, clear=True) +@patch("pathlib.Path.mkdir", side_effect=OSError) +@patch("logging.error") +def test_download_from_url_error_creating_directory_with_env_var(logging_mock, mkdir_mock): + with pytest.raises(OSError): + download_from_url("test_url") + logging_mock.assert_called_with( + "Failed creating cache direcotry at /test using path from 'ONNXTR_CACHE_DIR' environment variable." + ) diff --git a/tests/common/test_utils_fonts.py b/tests/common/test_utils_fonts.py new file mode 100644 index 0000000..e8ea807 --- /dev/null +++ b/tests/common/test_utils_fonts.py @@ -0,0 +1,10 @@ +from PIL.ImageFont import FreeTypeFont, ImageFont + +from onnxtr.utils.fonts import get_font + + +def test_get_font(): + # Attempts to load recommended OS font + font = get_font() + + assert isinstance(font, (ImageFont, FreeTypeFont)) diff --git a/tests/common/test_utils_geometry.py b/tests/common/test_utils_geometry.py new file mode 100644 index 0000000..548b83c --- /dev/null +++ b/tests/common/test_utils_geometry.py @@ -0,0 +1,247 @@ +from copy import deepcopy +from math import hypot + +import numpy as np +import pytest + +from onnxtr.io import DocumentFile +from onnxtr.utils import geometry + + +def test_bbox_to_polygon(): + assert geometry.bbox_to_polygon(((0, 0), (1, 1))) == ((0, 0), (1, 0), (0, 1), (1, 1)) + + +def test_polygon_to_bbox(): + assert geometry.polygon_to_bbox(((0, 0), (1, 0), (0, 1), (1, 1))) == ((0, 0), (1, 1)) + + +def test_resolve_enclosing_bbox(): + assert geometry.resolve_enclosing_bbox([((0, 0.5), (1, 0)), ((0.5, 0), (1, 0.25))]) == ((0, 0), (1, 0.5)) + pred = geometry.resolve_enclosing_bbox(np.array([[0.1, 0.1, 0.2, 0.2, 0.9], [0.15, 0.15, 0.2, 0.2, 0.8]])) + assert pred.all() == np.array([0.1, 0.1, 0.2, 0.2, 0.85]).all() + + +def test_resolve_enclosing_rbbox(): + pred = geometry.resolve_enclosing_rbbox([ + np.asarray([[0.1, 0.1], [0.2, 0.2], [0.15, 0.25], [0.05, 0.15]]), + np.asarray([[0.5, 0.5], [0.6, 0.6], [0.55, 0.65], [0.45, 0.55]]), + ]) + target1 = np.asarray([[0.55, 0.65], [0.05, 0.15], [0.1, 0.1], [0.6, 0.6]]) + target2 = np.asarray([[0.05, 0.15], [0.1, 0.1], [0.6, 0.6], [0.55, 0.65]]) + assert np.all(target1 - pred <= 1e-3) or np.all(target2 - pred <= 1e-3) + + +def test_remap_boxes(): + pred = geometry.remap_boxes( + np.asarray([[[0.25, 0.25], [0.25, 0.75], [0.75, 0.25], [0.75, 0.75]]]), (10, 10), (20, 20) + ) + target = np.asarray([[[0.375, 0.375], [0.375, 0.625], [0.625, 0.375], [0.625, 0.625]]]) + assert np.all(pred == target) + + pred = geometry.remap_boxes( + np.asarray([[[0.25, 0.25], [0.25, 0.75], [0.75, 0.25], [0.75, 0.75]]]), (10, 10), (20, 10) + ) + target = np.asarray([[[0.25, 0.375], [0.25, 0.625], [0.75, 0.375], [0.75, 0.625]]]) + assert np.all(pred == target) + + with pytest.raises(ValueError): + geometry.remap_boxes( + np.asarray([[[0.25, 0.25], [0.25, 0.75], [0.75, 0.25], [0.75, 0.75]]]), (80, 40, 150), (160, 40) + ) + + with 
pytest.raises(ValueError): + geometry.remap_boxes(np.asarray([[[0.25, 0.25], [0.25, 0.75], [0.75, 0.25], [0.75, 0.75]]]), (80, 40), (160,)) + + orig_dimension = (100, 100) + dest_dimensions = (200, 100) + # Unpack dimensions + height_o, width_o = orig_dimension + height_d, width_d = dest_dimensions + + orig_box = np.asarray([[[0.25, 0.25], [0.25, 0.25], [0.75, 0.75], [0.75, 0.75]]]) + + pred = geometry.remap_boxes(orig_box, orig_dimension, dest_dimensions) + + # Switch to absolute coords + orig = np.stack((orig_box[:, :, 0] * width_o, orig_box[:, :, 1] * height_o), axis=2)[0] + dest = np.stack((pred[:, :, 0] * width_d, pred[:, :, 1] * height_d), axis=2)[0] + + len_orig = hypot(orig[0][0] - orig[2][0], orig[0][1] - orig[2][1]) + len_dest = hypot(dest[0][0] - dest[2][0], dest[0][1] - dest[2][1]) + assert len_orig == len_dest + + alpha_orig = np.rad2deg(np.arctan((orig[0][1] - orig[2][1]) / (orig[0][0] - orig[2][0]))) + alpha_dest = np.rad2deg(np.arctan((dest[0][1] - dest[2][1]) / (dest[0][0] - dest[2][0]))) + assert alpha_orig == alpha_dest + + +def test_rotate_boxes(): + boxes = np.array([[0.1, 0.1, 0.8, 0.3, 0.5]]) + rboxes = np.array([[0.1, 0.1], [0.8, 0.1], [0.8, 0.3], [0.1, 0.3]]) + # Angle = 0 + rotated = geometry.rotate_boxes(boxes, angle=0.0, orig_shape=(1, 1)) + assert np.all(rotated == rboxes) + # Angle < 1: + rotated = geometry.rotate_boxes(boxes, angle=0.5, orig_shape=(1, 1)) + assert np.all(rotated == rboxes) + # Angle = 30 + rotated = geometry.rotate_boxes(boxes, angle=30, orig_shape=(1, 1)) + assert rotated.shape == (1, 4, 2) + + boxes = np.array([[0.0, 0.0, 0.6, 0.2, 0.5]]) + # Angle = -90: + rotated = geometry.rotate_boxes(boxes, angle=-90, orig_shape=(1, 1), min_angle=0) + assert np.allclose(rotated, np.array([[[1, 0.0], [1, 0.6], [0.8, 0.6], [0.8, 0.0]]])) + # Angle = 90 + rotated = geometry.rotate_boxes(boxes, angle=+90, orig_shape=(1, 1), min_angle=0) + assert np.allclose(rotated, np.array([[[0, 1.0], [0, 0.4], [0.2, 0.4], [0.2, 1.0]]])) + + +def test_rotate_image(): + img = np.ones((32, 64, 3), dtype=np.float32) + rotated = geometry.rotate_image(img, 30.0) + assert rotated.shape[:-1] == (32, 64) + assert rotated[0, 0, 0] == 0 + assert rotated[0, :, 0].sum() > 1 + + # Expand + rotated = geometry.rotate_image(img, 30.0, expand=True) + assert rotated.shape[:-1] == (60, 120) + assert rotated[0, :, 0].sum() <= 1 + + # Expand + rotated = geometry.rotate_image(img, 30.0, expand=True, preserve_origin_shape=True) + assert rotated.shape[:-1] == (32, 64) + assert rotated[0, :, 0].sum() <= 1 + + # Expand with 90° rotation + rotated = geometry.rotate_image(img, 90.0, expand=True) + assert rotated.shape[:-1] == (64, 128) + assert rotated[0, :, 0].sum() <= 1 + + +@pytest.mark.parametrize( + "abs_geoms, img_size, rel_geoms", + [ + # Full image (boxes) + [np.array([[0, 0, 32, 32]]), (32, 32), np.array([[0, 0, 1, 1]], dtype=np.float32)], + # Full image (polygons) + [ + np.array([[[0, 0], [32, 0], [32, 32], [0, 32]]]), + (32, 32), + np.array([[[0, 0], [1, 0], [1, 1], [0, 1]]], dtype=np.float32), + ], + # Quarter image (boxes) + [np.array([[0, 0, 16, 16]]), (32, 32), np.array([[0, 0, 0.5, 0.5]], dtype=np.float32)], + # Quarter image (polygons) + [ + np.array([[[0, 0], [16, 0], [16, 16], [0, 16]]]), + (32, 32), + np.array([[[0, 0], [0.5, 0], [0.5, 0.5], [0, 0.5]]], dtype=np.float32), + ], + ], +) +def test_convert_to_relative_coords(abs_geoms, img_size, rel_geoms): + assert np.all(geometry.convert_to_relative_coords(abs_geoms, img_size) == rel_geoms) + + # Wrong format + with 
pytest.raises(ValueError): + geometry.convert_to_relative_coords(np.zeros((3, 5)), (32, 32)) + + +def test_estimate_page_angle(): + straight_polys = np.array([ + [[0.3, 0.3], [0.4, 0.3], [0.4, 0.4], [0.3, 0.4]], + [[0.4, 0.4], [0.5, 0.4], [0.5, 0.5], [0.4, 0.5]], + [[0.5, 0.5], [0.6, 0.5], [0.6, 0.6], [0.5, 0.6]], + ]) + rotated_polys = geometry.rotate_boxes(straight_polys, angle=20, orig_shape=(512, 512)) + angle = geometry.estimate_page_angle(rotated_polys) + assert np.isclose(angle, 20) + # Test divide by zero / NaN + invalid_poly = np.array([[[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]]) + angle = geometry.estimate_page_angle(invalid_poly) + assert angle == 0.0 + + +def test_extract_crops(mock_pdf): + doc_img = DocumentFile.from_pdf(mock_pdf)[0] + num_crops = 2 + rel_boxes = np.array( + [[idx / num_crops, idx / num_crops, (idx + 1) / num_crops, (idx + 1) / num_crops] for idx in range(num_crops)], + dtype=np.float32, + ) + abs_boxes = np.array( + [ + [ + int(idx * doc_img.shape[1] / num_crops), + int(idx * doc_img.shape[0]) / num_crops, + int((idx + 1) * doc_img.shape[1] / num_crops), + int((idx + 1) * doc_img.shape[0] / num_crops), + ] + for idx in range(num_crops) + ], + dtype=np.float32, + ) + + with pytest.raises(AssertionError): + geometry.extract_crops(doc_img, np.zeros((1, 5))) + + for boxes in (rel_boxes, abs_boxes): + croped_imgs = geometry.extract_crops(doc_img, boxes) + # Number of crops + assert len(croped_imgs) == num_crops + # Data type and shape + assert all(isinstance(crop, np.ndarray) for crop in croped_imgs) + assert all(crop.ndim == 3 for crop in croped_imgs) + + # Identity + assert np.all( + doc_img == geometry.extract_crops(doc_img, np.array([[0, 0, 1, 1]], dtype=np.float32), channels_last=True)[0] + ) + torch_img = np.transpose(doc_img, axes=(-1, 0, 1)) + assert np.all( + torch_img + == np.transpose( + geometry.extract_crops(doc_img, np.array([[0, 0, 1, 1]], dtype=np.float32), channels_last=False)[0], + axes=(-1, 0, 1), + ) + ) + + # No box + assert geometry.extract_crops(doc_img, np.zeros((0, 4))) == [] + + +def test_extract_rcrops(mock_pdf): + doc_img = DocumentFile.from_pdf(mock_pdf)[0] + num_crops = 2 + rel_boxes = np.array( + [ + [ + [idx / num_crops, idx / num_crops], + [idx / num_crops + 0.1, idx / num_crops], + [idx / num_crops + 0.1, idx / num_crops + 0.1], + [idx / num_crops, idx / num_crops], + ] + for idx in range(num_crops) + ], + dtype=np.float32, + ) + abs_boxes = deepcopy(rel_boxes) + abs_boxes[:, :, 0] *= doc_img.shape[1] + abs_boxes[:, :, 1] *= doc_img.shape[0] + abs_boxes = abs_boxes.astype(np.int64) + + with pytest.raises(AssertionError): + geometry.extract_rcrops(doc_img, np.zeros((1, 8))) + for boxes in (rel_boxes, abs_boxes): + croped_imgs = geometry.extract_rcrops(doc_img, boxes) + # Number of crops + assert len(croped_imgs) == num_crops + # Data type and shape + assert all(isinstance(crop, np.ndarray) for crop in croped_imgs) + assert all(crop.ndim == 3 for crop in croped_imgs) + + # No box + assert geometry.extract_rcrops(doc_img, np.zeros((0, 4, 2))) == [] diff --git a/tests/common/test_utils_multithreading.py b/tests/common/test_utils_multithreading.py new file mode 100644 index 0000000..8241eb9 --- /dev/null +++ b/tests/common/test_utils_multithreading.py @@ -0,0 +1,31 @@ +import os +from multiprocessing.pool import ThreadPool +from unittest.mock import patch + +import pytest + +from onnxtr.utils.multithreading import multithread_exec + + +@pytest.mark.parametrize( + "input_seq, func, output_seq", + [ + [[1, 2, 3], lambda x: 2 * x, 
[2, 4, 6]], + [[1, 2, 3], lambda x: x**2, [1, 4, 9]], + [ + ["this is", "show me", "I know"], + lambda x: x + " the way", + ["this is the way", "show me the way", "I know the way"], + ], + ], +) +def test_multithread_exec(input_seq, func, output_seq): + assert list(multithread_exec(func, input_seq)) == output_seq + assert list(multithread_exec(func, input_seq, 0)) == output_seq + + +@patch.dict(os.environ, {"ONNXTR_MULTIPROCESSING_DISABLE": "TRUE"}, clear=True) +def test_multithread_exec_multiprocessing_disable(): + with patch.object(ThreadPool, "map") as mock_tp_map: + multithread_exec(lambda x: x, [1, 2]) + assert not mock_tp_map.called diff --git a/tests/common/test_utils_reconstitution.py b/tests/common/test_utils_reconstitution.py new file mode 100644 index 0000000..20abefe --- /dev/null +++ b/tests/common/test_utils_reconstitution.py @@ -0,0 +1,12 @@ +import numpy as np +from test_io_elements import _mock_pages + +from onnxtr.utils import reconstitution + + +def test_synthesize_page(): + pages = _mock_pages() + reconstitution.synthesize_page(pages[0].export(), draw_proba=False) + render = reconstitution.synthesize_page(pages[0].export(), draw_proba=True) + assert isinstance(render, np.ndarray) + assert render.shape == (*pages[0].dimensions, 3) diff --git a/tests/common/test_utils_visualization.py b/tests/common/test_utils_visualization.py new file mode 100644 index 0000000..e291f34 --- /dev/null +++ b/tests/common/test_utils_visualization.py @@ -0,0 +1,32 @@ +import numpy as np +import pytest +from test_io_elements import _mock_pages + +from onnxtr.utils import visualization + + +def test_visualize_page(): + pages = _mock_pages() + image = np.ones((300, 200, 3)) + visualization.visualize_page(pages[0].export(), image, words_only=False) + visualization.visualize_page(pages[0].export(), image, words_only=True, interactive=False) + # geometry checks + with pytest.raises(ValueError): + visualization.create_obj_patch([1, 2], (100, 100)) + + with pytest.raises(ValueError): + visualization.create_obj_patch((1, 2), (100, 100)) + + with pytest.raises(ValueError): + visualization.create_obj_patch((1, 2, 3, 4, 5), (100, 100)) + + +def test_draw_boxes(): + image = np.ones((256, 256, 3), dtype=np.float32) + boxes = [ + [0.1, 0.1, 0.2, 0.2], + [0.15, 0.15, 0.19, 0.2], # to suppress + [0.5, 0.5, 0.6, 0.55], + [0.55, 0.5, 0.7, 0.55], # to suppress + ] + visualization.draw_boxes(boxes=np.array(boxes), image=image, block=False) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..aca8889 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,132 @@ +from io import BytesIO + +import cv2 +import pytest +import requests +from PIL import Image, ImageDraw + +from onnxtr.io import reader +from onnxtr.utils import geometry +from onnxtr.utils.fonts import get_font + + +def synthesize_text_img( + text: str, + font_size: int = 32, + font_family=None, + background_color=None, + text_color=None, +) -> Image.Image: + background_color = (0, 0, 0) if background_color is None else background_color + text_color = (255, 255, 255) if text_color is None else text_color + + font = get_font(font_family, font_size) + left, top, right, bottom = font.getbbox(text) + text_w, text_h = right - left, bottom - top + h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w)) + # If single letter, make the image square, otherwise expand to meet the text size + img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w)) + + img = Image.new("RGB", img_size[::-1], color=background_color) + d = 
ImageDraw.Draw(img) + + # Offset so that the text is centered + text_pos = (int(round((img_size[1] - text_w) / 2)), int(round((img_size[0] - text_h) / 2))) + # Draw the text + d.text(text_pos, text, font=font, fill=text_color) + return img + + +@pytest.fixture(scope="session") +def mock_vocab(): + return ( + "3K}7eé;5àÎYho]QwV6qU~W\"XnbBvcADfËmy.9ÔpÛ*{CôïE%M4#ÈR:g@T$x?0î£|za1ù8,OG€P-kçHëÀÂ2É/ûIJ'j" + "(LNÙFut[)èZs+&°Sd=Ï!<â_Ç>rêi`l" + ) + + +@pytest.fixture(scope="session") +def mock_pdf(tmpdir_factory): + # Page 1 + text_img = synthesize_text_img("I am a jedi!", background_color=(255, 255, 255), text_color=(0, 0, 0)) + page = Image.new(text_img.mode, (1240, 1754), (255, 255, 255)) + page.paste(text_img, (50, 100)) + + # Page 2 + text_img = synthesize_text_img("No, I am your father.", background_color=(255, 255, 255), text_color=(0, 0, 0)) + _page = Image.new(text_img.mode, (1240, 1754), (255, 255, 255)) + _page.paste(text_img, (40, 300)) + + # Save the PDF + fn = tmpdir_factory.mktemp("data").join("mock_pdf_file.pdf") + page.save(str(fn), "PDF", save_all=True, append_images=[_page]) + + return str(fn) + + +@pytest.fixture(scope="session") +def mock_payslip(tmpdir_factory): + url = "https://3.bp.blogspot.com/-Es0oHTCrVEk/UnYA-iW9rYI/AAAAAAAAAFI/hWExrXFbo9U/s1600/003.jpg" + file = BytesIO(requests.get(url).content) + folder = tmpdir_factory.mktemp("data") + fn = str(folder.join("mock_payslip.jpeg")) + with open(fn, "wb") as f: + f.write(file.getbuffer()) + return fn + + +@pytest.fixture(scope="session") +def mock_tilted_payslip(mock_payslip, tmpdir_factory): + image = reader.read_img_as_numpy(mock_payslip) + image = geometry.rotate_image(image, 30, expand=True) + tmp_path = str(tmpdir_factory.mktemp("data").join("mock_tilted_payslip.jpg")) + cv2.imwrite(tmp_path, image) + return tmp_path + + +@pytest.fixture(scope="session") +def mock_text_box_stream(): + url = "https://doctr-static.mindee.com/models?id=v0.5.1/word-crop.png&src=0" + return requests.get(url).content + + +@pytest.fixture(scope="session") +def mock_text_box(mock_text_box_stream, tmpdir_factory): + file = BytesIO(mock_text_box_stream) + fn = tmpdir_factory.mktemp("data").join("mock_text_box_file.png") + with open(fn, "wb") as f: + f.write(file.getbuffer()) + return str(fn) + + +@pytest.fixture(scope="session") +def mock_image_stream(): + url = "https://miro.medium.com/max/3349/1*mk1-6aYaf_Bes1E3Imhc0A.jpeg" + return requests.get(url).content + + +@pytest.fixture(scope="session") +def mock_artefact_image_stream(): + url = "https://github.com/mindee/doctr/releases/download/v0.8.1/artefact_dummy.jpg" + return requests.get(url).content + + +@pytest.fixture(scope="session") +def mock_image_path(mock_image_stream, tmpdir_factory): + file = BytesIO(mock_image_stream) + folder = tmpdir_factory.mktemp("images") + fn = folder.join("mock_image_file.jpeg") + with open(fn, "wb") as f: + f.write(file.getbuffer()) + return str(fn) + + +@pytest.fixture(scope="session") +def mock_image_folder(mock_image_stream, tmpdir_factory): + file = BytesIO(mock_image_stream) + folder = tmpdir_factory.mktemp("images") + for i in range(5): + fn = folder.join("mock_image_file_" + str(i) + ".jpeg") + with open(fn, "wb") as f: + f.write(file.getbuffer()) + return str(folder) diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py new file mode 100644 index 0000000..09e46d9 --- /dev/null +++ b/tests/pytorch/test_datasets_pt.py @@ -0,0 +1,623 @@ +import os +from shutil import move + +import numpy as np +import pytest +import torch +from 
doctr import datasets +from doctr.file_utils import CLASS_NAME +from doctr.transforms import Resize +from torch.utils.data import DataLoader, RandomSampler + + +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False): + # Fetch one sample + img, target = ds[0] + + assert isinstance(img, torch.Tensor) + assert img.shape == (3, *input_size) + assert img.dtype == torch.float32 + assert isinstance(target, dict) + assert isinstance(target["boxes"], np.ndarray) and target["boxes"].dtype == np.float32 + if is_polygons: + assert target["boxes"].ndim == 3 and target["boxes"].shape[1:] == (4, 2) + else: + assert target["boxes"].ndim == 2 and target["boxes"].shape[1:] == (4,) + assert np.all(np.logical_and(target["boxes"] <= 1, target["boxes"] >= 0)) + if class_indices: + assert isinstance(target["labels"], np.ndarray) and target["labels"].dtype == np.int64 + else: + assert isinstance(target["labels"], list) and all(isinstance(s, str) for s in target["labels"]) + assert len(target["labels"]) == len(target["boxes"]) + + # Check batching + loader = DataLoader( + ds, + batch_size=batch_size, + drop_last=True, + sampler=RandomSampler(ds), + num_workers=0, + pin_memory=True, + collate_fn=ds.collate_fn, + ) + + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (batch_size, 3, *input_size) + assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets) + + +def _validate_dataset_recognition_part(ds, input_size, batch_size=2): + # Fetch one sample + img, label = ds[0] + + assert isinstance(img, torch.Tensor) + assert img.shape == (3, *input_size) + assert img.dtype == torch.float32 + assert isinstance(label, str) + + # Check batching + loader = DataLoader( + ds, + batch_size=batch_size, + drop_last=True, + sampler=RandomSampler(ds), + num_workers=0, + pin_memory=True, + collate_fn=ds.collate_fn, + ) + + images, labels = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (batch_size, 3, *input_size) + assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) + + +def test_visiondataset(): + url = "https://github.com/mindee/doctr/releases/download/v0.6.0/mnist.zip" + with pytest.raises(ValueError): + datasets.datasets.VisionDataset(url, download=False) + + dataset = datasets.datasets.VisionDataset(url, download=True, extract_archive=True) + assert len(dataset) == 0 + assert repr(dataset) == "VisionDataset()" + + +def test_rotation_dataset(mock_image_folder): + input_size = (1024, 1024) + + ds = datasets.OrientationDataset(img_folder=mock_image_folder, img_transforms=Resize(input_size)) + assert len(ds) == 5 + img, target = ds[0] + assert isinstance(img, torch.Tensor) + assert img.dtype == torch.float32 + assert img.shape[-2:] == input_size + # Prefilled rotation targets + assert isinstance(target, np.ndarray) and target.dtype == np.int64 + # check that all prefilled targets are 0 (degrees) + assert np.all(target == 0) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets) + + +def test_detection_dataset(mock_image_folder, mock_detection_label): + input_size = (1024, 1024) + + ds = datasets.DetectionDataset( + img_folder=mock_image_folder, + label_path=mock_detection_label, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 
5 + img, target_dict = ds[0] + target = target_dict[CLASS_NAME] + assert isinstance(img, torch.Tensor) + assert img.dtype == torch.float32 + assert img.shape[-2:] == input_size + # Bounding boxes + assert isinstance(target_dict, dict) + assert isinstance(target, np.ndarray) and target.dtype == np.float32 + assert np.all(np.logical_and(target[:, :4] >= 0, target[:, :4] <= 1)) + assert target.shape[1] == 4 + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(targets, list) and all( + isinstance(elt, np.ndarray) for target in targets for elt in target.values() + ) + # Rotated DS + rotated_ds = datasets.DetectionDataset( + img_folder=mock_image_folder, + label_path=mock_detection_label, + img_transforms=Resize(input_size), + use_polygons=True, + ) + _, r_target = rotated_ds[0] + assert r_target[CLASS_NAME].shape[1:] == (4, 2) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.DetectionDataset(mock_image_folder, mock_detection_label) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +def test_recognition_dataset(mock_image_folder, mock_recognition_label): + input_size = (32, 128) + ds = datasets.RecognitionDataset( + img_folder=mock_image_folder, + labels_path=mock_recognition_label, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + assert len(ds) == 5 + image, label = ds[0] + assert isinstance(image, torch.Tensor) + assert image.shape[-2:] == input_size + assert image.dtype == torch.float32 + assert isinstance(label, str) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, labels = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.RecognitionDataset(mock_image_folder, mock_recognition_label) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +@pytest.mark.parametrize( + "use_polygons", + [False, True], +) +def test_ocrdataset(mock_ocrdataset, use_polygons): + input_size = (512, 512) + + ds = datasets.OCRDataset( + *mock_ocrdataset, + img_transforms=Resize(input_size), + use_polygons=use_polygons, + ) + + assert len(ds) == 3 + _validate_dataset(ds, input_size, is_polygons=use_polygons) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.OCRDataset(*mock_ocrdataset) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +def test_charactergenerator(): + input_size = (32, 32) + vocab = "abcdef" + + ds = datasets.CharacterGenerator( + vocab=vocab, + num_samples=10, + cache_samples=True, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 10 + image, label = ds[0] + assert isinstance(image, torch.Tensor) + assert image.shape[-2:] == input_size + assert image.dtype == torch.float32 + assert isinstance(label, int) and label < len(vocab) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + 
assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(targets, torch.Tensor) and targets.shape == (2,) + assert targets.dtype == torch.int64 + + +def test_wordgenerator(): + input_size = (32, 128) + wordlen_range = (1, 10) + vocab = "abcdef" + + ds = datasets.WordGenerator( + vocab=vocab, + min_chars=wordlen_range[0], + max_chars=wordlen_range[1], + num_samples=10, + cache_samples=True, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 10 + image, target = ds[0] + assert isinstance(image, torch.Tensor) + assert image.shape[-2:] == input_size + assert image.dtype == torch.float32 + assert isinstance(target, str) and len(target) >= wordlen_range[0] and len(target) <= wordlen_range[1] + assert all(char in vocab for char in target) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(targets, list) and len(targets) == 2 and all(isinstance(t, str) for t in targets) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples", + [ + [[512, 512], 3], # Actual set has 2700 training samples and 300 test samples + ], +) +def test_artefact_detection(input_size, num_samples, rotate, mock_doc_artefacts): + # monkeypatch the path to temporary dataset + datasets.DocArtefacts.URL = mock_doc_artefacts + datasets.DocArtefacts.SHA256 = None + + ds = datasets.DocArtefacts( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + cache_dir="/".join(mock_doc_artefacts.split("/")[:-2]), + cache_subdir=mock_doc_artefacts.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"DocArtefacts(train={True})" + _validate_dataset(ds, input_size, class_indices=True, is_polygons=rotate) + + +# NOTE: following datasets support also recognition task + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 626 training samples and 360 test samples + [[32, 128], 15, True], # recognition + ], +) +def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset): + # monkeypatch the path to temporary dataset + datasets.SROIE.TRAIN = (mock_sroie_dataset, None, "sroie2019_train_task1.zip") + + ds = datasets.SROIE( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), + cache_subdir=mock_sroie_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SROIE(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 5, False], # Actual set has 229 train and 233 test samples + [[32, 128], 25, True], # recognition + ], +) +def test_ic13_dataset(input_size, num_samples, rotate, recognition, mock_ic13): + ds = datasets.IC13( + *mock_ic13, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + ) + + assert len(ds) == num_samples + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + 
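+# NOTE: the dataset tests below run against mocked fixtures or monkeypatched download
+# URLs, so the asserted sample counts are much smaller than the real dataset sizes
+# mentioned in the parametrize comments.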
+@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 7149 train and 796 test samples + [[32, 128], 5, True], # recognition + ], +) +def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, mock_imgur5k): + ds = datasets.IMGUR5K( + *mock_imgur5k, + train=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + ) + + assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split + assert repr(ds) == f"IMGUR5K(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[32, 128], 3, False], # Actual set has 33402 training samples and 13068 test samples + [[32, 128], 12, True], # recognition + ], +) +def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): + # monkeypatch the path to temporary dataset + datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar") + + ds = datasets.SVHN( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), + cache_subdir=mock_svhn_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SVHN(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 149 training samples and 50 test samples + [[32, 128], 9, True], # recognition + ], +) +def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset): + # monkeypatch the path to temporary dataset + datasets.FUNSD.URL = mock_funsd_dataset + datasets.FUNSD.SHA256 = None + datasets.FUNSD.FILE_NAME = "funsd.zip" + + ds = datasets.FUNSD( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), + cache_subdir=mock_funsd_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"FUNSD(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 800 training samples and 100 test samples + [[32, 128], 9, True], # recognition + ], +) +def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): + # monkeypatch the path to temporary dataset + datasets.CORD.TRAIN = (mock_cord_dataset, None, "cord_train.zip") + + ds = datasets.CORD( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), + cache_subdir=mock_cord_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"CORD(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, 
is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 2, False], # Actual set has 772875 training samples and 85875 test samples + [[32, 128], 10, True], # recognition + ], +) +def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_dataset): + # monkeypatch the path to temporary dataset + datasets.SynthText.URL = mock_synthtext_dataset + datasets.SynthText.SHA256 = None + + ds = datasets.SynthText( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), + cache_subdir=mock_synthtext_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SynthText(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[32, 128], 1, False], # Actual set has 2000 training samples and 3000 test samples + [[32, 128], 1, True], # recognition + ], +) +def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_dataset): + # monkeypatch the path to temporary dataset + datasets.IIIT5K.URL = mock_iiit5k_dataset + datasets.IIIT5K.SHA256 = None + + ds = datasets.IIIT5K( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), + cache_subdir=mock_iiit5k_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"IIIT5K(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size, batch_size=1) + else: + _validate_dataset(ds, input_size, batch_size=1, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 100 training samples and 249 test samples + [[32, 128], 3, True], # recognition + ], +) +def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): + # monkeypatch the path to temporary dataset + datasets.SVT.URL = mock_svt_dataset + datasets.SVT.SHA256 = None + + ds = datasets.SVT( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), + cache_subdir=mock_svt_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SVT(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 246 training samples and 249 test samples + [[32, 128], 3, True], # recognition + ], +) +def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): + # monkeypatch the path to temporary dataset + datasets.IC03.TRAIN = (mock_ic03_dataset, None, "ic03_train.zip") + + ds = datasets.IC03( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), + 
cache_subdir=mock_ic03_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"IC03(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 2, False], + [[32, 128], 5, True], + ], +) +def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, mock_wildreceipt_dataset): + ds = datasets.WILDRECEIPT( + *mock_wildreceipt_dataset, + train=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + ) + assert len(ds) == num_samples + assert repr(ds) == f"WILDRECEIPT(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +# NOTE: following datasets are only for recognition task + + +def test_mjsynth_dataset(mock_mjsynth_dataset): + input_size = (32, 128) + ds = datasets.MJSynth( + *mock_mjsynth_dataset, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + + assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples + assert repr(ds) == f"MJSynth(train={True})" + _validate_dataset_recognition_part(ds, input_size) + + +def test_iiithws_dataset(mock_iiithws_dataset): + input_size = (32, 128) + ds = datasets.IIITHWS( + *mock_iiithws_dataset, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + + assert len(ds) == 4 # Actual set has 7141797 train and 793533 test samples + assert repr(ds) == f"IIITHWS(train={True})" + _validate_dataset_recognition_part(ds, input_size) diff --git a/tests/pytorch/test_file_utils_pt.py b/tests/pytorch/test_file_utils_pt.py new file mode 100644 index 0000000..7b36789 --- /dev/null +++ b/tests/pytorch/test_file_utils_pt.py @@ -0,0 +1,5 @@ +from doctr.file_utils import is_torch_available + + +def test_file_utils(): + assert is_torch_available() diff --git a/tests/pytorch/test_io_image_pt.py b/tests/pytorch/test_io_image_pt.py new file mode 100644 index 0000000..ad8a44b --- /dev/null +++ b/tests/pytorch/test_io_image_pt.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest +import torch +from doctr.io import decode_img_as_tensor, read_img_as_tensor, tensor_from_numpy + + +def test_read_img_as_tensor(mock_image_path): + img = read_img_as_tensor(mock_image_path) + + assert isinstance(img, torch.Tensor) + assert img.dtype == torch.float32 + assert img.shape == (3, 900, 1200) + + img = read_img_as_tensor(mock_image_path, dtype=torch.float16) + assert img.dtype == torch.float16 + img = read_img_as_tensor(mock_image_path, dtype=torch.uint8) + assert img.dtype == torch.uint8 + + with pytest.raises(ValueError): + read_img_as_tensor(mock_image_path, dtype=torch.float64) + + +def test_decode_img_as_tensor(mock_image_stream): + img = decode_img_as_tensor(mock_image_stream) + + assert isinstance(img, torch.Tensor) + assert img.dtype == torch.float32 + assert img.shape == (3, 900, 1200) + + img = decode_img_as_tensor(mock_image_stream, dtype=torch.float16) + assert img.dtype == torch.float16 + img = decode_img_as_tensor(mock_image_stream, dtype=torch.uint8) + assert img.dtype == torch.uint8 + + with pytest.raises(ValueError): + decode_img_as_tensor(mock_image_stream, dtype=torch.float64) + + +def test_tensor_from_numpy(mock_image_stream): + with pytest.raises(ValueError): + tensor_from_numpy(np.zeros((256, 256, 3)), 
torch.int64) + + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8)) + + assert isinstance(out, torch.Tensor) + assert out.dtype == torch.float32 + assert out.shape == (3, 256, 256) + + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=torch.float16) + assert out.dtype == torch.float16 + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=torch.uint8) + assert out.dtype == torch.uint8 diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py new file mode 100644 index 0000000..ca3e7e2 --- /dev/null +++ b/tests/pytorch/test_models_classification_pt.py @@ -0,0 +1,194 @@ +import os +import tempfile + +import cv2 +import numpy as np +import onnxruntime +import pytest +import torch +from doctr.models import classification +from doctr.models.classification.predictor import OrientationPredictor +from doctr.models.utils import export_model_to_onnx + + +def _test_classification(model, input_shape, output_size, batch_size=2): + # Forward + with torch.no_grad(): + out = model(torch.rand((batch_size, *input_shape), dtype=torch.float32)) + # Output checks + assert isinstance(out, torch.Tensor) + assert out.dtype == torch.float32 + assert out.numpy().shape == (batch_size, *output_size) + # Check FP16 + if torch.cuda.is_available(): + model = model.half().cuda() + with torch.no_grad(): + out = model(torch.rand((batch_size, *input_shape), dtype=torch.float16).cuda()) + assert out.dtype == torch.float16 + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size", + [ + ["vgg16_bn_r", (3, 32, 32), (126,)], + ["resnet18", (3, 32, 32), (126,)], + ["resnet31", (3, 32, 32), (126,)], + ["resnet34", (3, 32, 32), (126,)], + ["resnet34_wide", (3, 32, 32), (126,)], + ["resnet50", (3, 32, 32), (126,)], + ["magc_resnet31", (3, 32, 32), (126,)], + ["mobilenet_v3_small", (3, 32, 32), (126,)], + ["mobilenet_v3_large", (3, 32, 32), (126,)], + ["textnet_tiny", (3, 32, 32), (126,)], + ["textnet_small", (3, 32, 32), (126,)], + ["textnet_base", (3, 32, 32), (126,)], + ["vit_s", (3, 32, 32), (126,)], + ["vit_b", (3, 32, 32), (126,)], + # Check that the interpolation of positional embeddings for vit models works correctly + ["vit_s", (3, 64, 64), (126,)], + ], +) +def test_classification_architectures(arch_name, input_shape, output_size): + # Model + model = classification.__dict__[arch_name](pretrained=True).eval() + _test_classification(model, input_shape, output_size) + # Check that you can pretrained everything up until the last layer + classification.__dict__[arch_name](pretrained=True, num_classes=10) + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["mobilenet_v3_small_crop_orientation", (3, 256, 256)], + ["mobilenet_v3_small_page_orientation", (3, 512, 512)], + ], +) +def test_classification_models(arch_name, input_shape): + batch_size = 8 + model = classification.__dict__[arch_name](pretrained=False, input_shape=input_shape).eval() + assert isinstance(model, torch.nn.Module) + input_tensor = torch.rand((batch_size, *input_shape)) + + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + out = model(input_tensor) + assert isinstance(out, torch.Tensor) + assert out.shape == (8, 4) + + +@pytest.mark.parametrize( + "arch_name", + [ + "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", + ], +) +def test_classification_zoo(arch_name): + if "crop" in arch_name: + batch_size = 16 + input_tensor = torch.rand((batch_size, 3, 256, 256)) + # 
Model + predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False) + predictor.model.eval() + + with pytest.raises(ValueError): + predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False) + else: + batch_size = 2 + input_tensor = torch.rand((batch_size, 3, 512, 512)) + # Model + predictor = classification.zoo.page_orientation_predictor(arch_name, pretrained=False) + predictor.model.eval() + + with pytest.raises(ValueError): + predictor = classification.zoo.page_orientation_predictor(arch="wrong_model", pretrained=False) + # object check + assert isinstance(predictor, OrientationPredictor) + if torch.cuda.is_available(): + predictor.model.cuda() + input_tensor = input_tensor.cuda() + + with torch.no_grad(): + out = predictor(input_tensor) + out = predictor(input_tensor) + class_idxs, classes, confs = out[0], out[1], out[2] + assert isinstance(class_idxs, list) and len(class_idxs) == batch_size + assert isinstance(classes, list) and len(classes) == batch_size + assert isinstance(confs, list) and len(confs) == batch_size + assert all(isinstance(pred, int) for pred in class_idxs) + assert all(isinstance(pred, int) for pred in classes) and all(pred in [0, 90, 180, -90] for pred in classes) + assert all(isinstance(pred, float) for pred in confs) + + +def test_crop_orientation_model(mock_text_box): + text_box_0 = cv2.imread(mock_text_box) + # rotates counter-clockwise + text_box_270 = np.rot90(text_box_0, 1) + text_box_180 = np.rot90(text_box_0, 2) + text_box_90 = np.rot90(text_box_0, 3) + classifier = classification.crop_orientation_predictor("mobilenet_v3_small_crop_orientation", pretrained=True) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] + # 270 degrees is equivalent to -90 degrees + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] + assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + + +def test_page_orientation_model(mock_payslip): + text_box_0 = cv2.imread(mock_payslip) + # rotates counter-clockwise + text_box_270 = np.rot90(text_box_0, 1) + text_box_180 = np.rot90(text_box_0, 2) + text_box_90 = np.rot90(text_box_0, 3) + classifier = classification.crop_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] + # 270 degrees is equivalent to -90 degrees + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] + assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size", + [ + ["vgg16_bn_r", (3, 32, 32), (126,)], + ["resnet18", (3, 32, 32), (126,)], + ["resnet31", (3, 32, 32), (126,)], + ["resnet34", (3, 32, 32), (126,)], + ["resnet34_wide", (3, 32, 32), (126,)], + ["resnet50", (3, 32, 32), (126,)], + ["magc_resnet31", (3, 32, 32), (126,)], + ["mobilenet_v3_small", (3, 32, 32), (126,)], + ["mobilenet_v3_large", (3, 32, 32), (126,)], + ["mobilenet_v3_small_crop_orientation", (3, 256, 256), (4,)], + ["mobilenet_v3_small_page_orientation", (3, 512, 512), (4,)], + ["vit_s", (3, 32, 32), (126,)], + ["vit_b", (3, 32, 32), (126,)], + ["textnet_tiny", (3, 32, 32), (126,)], + ["textnet_small", (3, 32, 32), (126,)], + ["textnet_base", (3, 32, 32), (126,)], + ], +) +def 
test_models_onnx_export(arch_name, input_shape, output_size): + # Model + batch_size = 2 + model = classification.__dict__[arch_name](pretrained=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + pt_logits = model(dummy_input).detach().cpu().numpy() + with tempfile.TemporaryDirectory() as tmpdir: + # Export + model_path = export_model_to_onnx(model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input) + + assert os.path.exists(model_path) + # Inference + ort_session = onnxruntime.InferenceSession( + os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"] + ) + ort_outs = ort_session.run(["logits"], {"input": dummy_input.numpy()}) + + assert isinstance(ort_outs, list) and len(ort_outs) == 1 + assert ort_outs[0].shape == (batch_size, *output_size) + # Check that the output is close to the PyTorch output - only warn if not close + try: + assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py new file mode 100644 index 0000000..26c7a63 --- /dev/null +++ b/tests/pytorch/test_models_detection_pt.py @@ -0,0 +1,187 @@ +import math +import os +import tempfile + +import numpy as np +import onnxruntime +import pytest +import torch +from doctr.file_utils import CLASS_NAME +from doctr.models import detection +from doctr.models.detection._utils import dilate, erode +from doctr.models.detection.fast.pytorch import reparameterize +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.utils import export_model_to_onnx + + +@pytest.mark.parametrize("train_mode", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape, output_size, out_prob", + [ + ["db_resnet34", (3, 512, 512), (1, 512, 512), True], + ["db_resnet50", (3, 512, 512), (1, 512, 512), True], + ["db_mobilenet_v3_large", (3, 512, 512), (1, 512, 512), True], + ["linknet_resnet18", (3, 512, 512), (1, 512, 512), True], + ["linknet_resnet34", (3, 512, 512), (1, 512, 512), True], + ["linknet_resnet50", (3, 512, 512), (1, 512, 512), True], + ["fast_tiny", (3, 512, 512), (1, 512, 512), True], + ["fast_tiny_rep", (3, 512, 512), (1, 512, 512), True], # Reparameterized model + ["fast_small", (3, 512, 512), (1, 512, 512), True], + ["fast_base", (3, 512, 512), (1, 512, 512), True], + ], +) +def test_detection_models(arch_name, input_shape, output_size, out_prob, train_mode): + batch_size = 2 + if arch_name == "fast_tiny_rep": + model = reparameterize(detection.fast_tiny(pretrained=True).eval()) + train_mode = False # Reparameterized model is not trainable + else: + model = detection.__dict__[arch_name](pretrained=True) + model = model.train() if train_mode else model.eval() + assert isinstance(model, torch.nn.Module) + input_tensor = torch.rand((batch_size, *input_shape)) + target = [ + {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32)}, + {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.9]], dtype=np.float32)}, + ] + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + out = model(input_tensor, target, return_model_output=True, return_preds=not train_mode) + assert isinstance(out, dict) + assert len(out) == 3 if not train_mode else len(out) == 2 + # Check proba map + assert out["out_map"].shape == (batch_size, *output_size) + assert 
out["out_map"].dtype == torch.float32 + if out_prob: + assert torch.all((out["out_map"] >= 0) & (out["out_map"] <= 1)) + # Check boxes + if not train_mode: + for boxes_dict in out["preds"]: + for boxes in boxes_dict.values(): + assert boxes.shape[1] == 5 + assert np.all(boxes[:, :2] < boxes[:, 2:4]) + assert np.all(boxes[:, :4] >= 0) and np.all(boxes[:, :4] <= 1) + # Check loss + assert isinstance(out["loss"], torch.Tensor) + # Check the rotated case (same targets) + target = [ + { + CLASS_NAME: np.array( + [[[0.5, 0.5], [1, 0.5], [1, 1], [0.5, 1]], [[0.5, 0.5], [0.8, 0.5], [0.8, 0.8], [0.5, 0.8]]], + dtype=np.float32, + ) + }, + { + CLASS_NAME: np.array( + [[[0.5, 0.5], [1, 0.5], [1, 1], [0.5, 1]], [[0.5, 0.5], [0.8, 0.5], [0.8, 0.9], [0.5, 0.9]]], + dtype=np.float32, + ) + }, + ] + loss = model(input_tensor, target)["loss"] + assert isinstance(loss, torch.Tensor) and ((loss - out["loss"]).abs() / loss).item() < 1 + + +@pytest.mark.parametrize( + "arch_name", + [ + "db_resnet34", + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + "fast_tiny", + ], +) +def test_detection_zoo(arch_name): + # Model + predictor = detection.zoo.detection_predictor(arch_name, pretrained=False) + predictor.model.eval() + # object check + assert isinstance(predictor, DetectionPredictor) + input_tensor = torch.rand((2, 3, 1024, 1024)) + if torch.cuda.is_available(): + predictor.model.cuda() + input_tensor = input_tensor.cuda() + + with torch.no_grad(): + out, seq_maps = predictor(input_tensor, return_maps=True) + assert all(isinstance(boxes, dict) for boxes in out) + assert all(isinstance(boxes[CLASS_NAME], np.ndarray) and boxes[CLASS_NAME].shape[1] == 5 for boxes in out) + assert all(isinstance(seq_map, np.ndarray) for seq_map in seq_maps) + assert all(seq_map.shape[:2] == (1024, 1024) for seq_map in seq_maps) + # check that all values in the seq_maps are between 0 and 1 + assert all((seq_map >= 0).all() and (seq_map <= 1).all() for seq_map in seq_maps) + + +def test_fast_reparameterization(): + dummy_input = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + base_model = detection.fast_tiny(pretrained=True, exportable=True).eval() + base_model_params = sum(p.numel() for p in base_model.parameters()) + assert math.isclose(base_model_params, 13535296) # base model params + base_out = base_model(dummy_input)["logits"] + rep_model = reparameterize(base_model) + rep_model_params = sum(p.numel() for p in rep_model.parameters()) + assert math.isclose(rep_model_params, 8521920) # reparameterized model params + rep_out = rep_model(dummy_input)["logits"] + diff = base_out - rep_out + assert diff.mean() < 5e-2 + + +def test_erode(): + x = torch.zeros((1, 1, 3, 3)) + x[..., 1, 1] = 1 + expected = torch.zeros((1, 1, 3, 3)) + out = erode(x, 3) + assert torch.equal(out, expected) + + +def test_dilate(): + x = torch.zeros((1, 1, 3, 3)) + x[..., 1, 1] = 1 + expected = torch.ones((1, 1, 3, 3)) + out = dilate(x, 3) + assert torch.equal(out, expected) + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size", + [ + ["db_resnet34", (3, 512, 512), (1, 512, 512)], + ["db_resnet50", (3, 512, 512), (1, 512, 512)], + ["db_mobilenet_v3_large", (3, 512, 512), (1, 512, 512)], + ["linknet_resnet18", (3, 512, 512), (1, 512, 512)], + ["linknet_resnet34", (3, 512, 512), (1, 512, 512)], + ["linknet_resnet50", (3, 512, 512), (1, 512, 512)], + ["fast_tiny", (3, 512, 512), (1, 512, 512)], + ["fast_small", (3, 512, 512), (1, 512, 512)], + ["fast_base", (3, 512, 512), (1, 512, 512)], + ["fast_tiny_rep", (3, 512, 512), 
(1, 512, 512)], # Reparameterized model + ], +) +def test_models_onnx_export(arch_name, input_shape, output_size): + # Model + batch_size = 2 + if arch_name == "fast_tiny_rep": + model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True).eval()) + else: + model = detection.__dict__[arch_name](pretrained=True, exportable=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + pt_logits = model(dummy_input)["logits"].detach().cpu().numpy() + with tempfile.TemporaryDirectory() as tmpdir: + # Export + model_path = export_model_to_onnx(model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input) + assert os.path.exists(model_path) + # Inference + ort_session = onnxruntime.InferenceSession( + os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"] + ) + ort_outs = ort_session.run(["logits"], {"input": dummy_input.numpy()}) + + assert isinstance(ort_outs, list) and len(ort_outs) == 1 + assert ort_outs[0].shape == (batch_size, *output_size) + # Check that the output is close to the PyTorch output - only warn if not close + try: + assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py new file mode 100644 index 0000000..5c75582 --- /dev/null +++ b/tests/pytorch/test_models_factory.py @@ -0,0 +1,69 @@ +import json +import os +import tempfile + +import pytest +from doctr import models +from doctr.models.factory import _save_model_and_config_for_hf_hub, from_hub, push_to_hf_hub + + +def test_push_to_hf_hub(): + model = models.classification.resnet18(pretrained=False) + with pytest.raises(ValueError): + # run_config and/or arch must be specified + push_to_hf_hub(model, model_name="test", task="classification") + with pytest.raises(ValueError): + # task must be one of classification, detection, recognition, obj_detection + push_to_hf_hub(model, model_name="test", task="invalid_task", arch="mobilenet_v3_small") + with pytest.raises(ValueError): + # arch not in available architectures for task + push_to_hf_hub(model, model_name="test", task="detection", arch="crnn_mobilenet_v3_large") + + +@pytest.mark.parametrize( + "arch_name, task_name, dummy_model_id", + [ + ["vgg16_bn_r", "classification", "Felix92/doctr-dummy-torch-vgg16-bn-r"], + ["resnet18", "classification", "Felix92/doctr-dummy-torch-resnet18"], + ["resnet31", "classification", "Felix92/doctr-dummy-torch-resnet31"], + ["resnet34", "classification", "Felix92/doctr-dummy-torch-resnet34"], + ["resnet34_wide", "classification", "Felix92/doctr-dummy-torch-resnet34-wide"], + ["resnet50", "classification", "Felix92/doctr-dummy-torch-resnet50"], + ["magc_resnet31", "classification", "Felix92/doctr-dummy-torch-magc-resnet31"], + ["mobilenet_v3_small", "classification", "Felix92/doctr-dummy-torch-mobilenet-v3-small"], + ["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-torch-mobilenet-v3-large"], + ["vit_s", "classification", "Felix92/doctr-dummy-torch-vit-s"], + ["textnet_tiny", "classification", "Felix92/doctr-dummy-torch-textnet-tiny"], + ["db_resnet34", "detection", "Felix92/doctr-dummy-torch-db-resnet34"], + ["db_resnet50", "detection", "Felix92/doctr-dummy-torch-db-resnet50"], + ["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-torch-db-mobilenet-v3-large"], + ["linknet_resnet18", "detection", 
"Felix92/doctr-dummy-torch-linknet-resnet18"], + ["linknet_resnet34", "detection", "Felix92/doctr-dummy-torch-linknet-resnet34"], + ["linknet_resnet50", "detection", "Felix92/doctr-dummy-torch-linknet-resnet50"], + ["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-torch-crnn-vgg16-bn"], + ["crnn_mobilenet_v3_small", "recognition", "Felix92/doctr-dummy-torch-crnn-mobilenet-v3-small"], + ["crnn_mobilenet_v3_large", "recognition", "Felix92/doctr-dummy-torch-crnn-mobilenet-v3-large"], + ["sar_resnet31", "recognition", "Felix92/doctr-dummy-torch-sar-resnet31"], + ["master", "recognition", "Felix92/doctr-dummy-torch-master"], + ["vitstr_small", "recognition", "Felix92/doctr-dummy-torch-vitstr-small"], + ["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"], + ], +) +def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir): + with tempfile.TemporaryDirectory() as tmp_dir: + model = models.__dict__[task_name].__dict__[arch_name](pretrained=True).eval() + + _save_model_and_config_for_hf_hub(model, arch=arch_name, task=task_name, save_dir=tmp_dir) + + assert hasattr(model, "cfg") + assert len(os.listdir(tmp_dir)) == 2 + assert os.path.exists(tmp_dir + "/pytorch_model.bin") + assert os.path.exists(tmp_dir + "/config.json") + tmp_config = json.load(open(tmp_dir + "/config.json")) + assert arch_name == tmp_config["arch"] + assert task_name == tmp_config["task"] + assert all(key in model.cfg.keys() for key in tmp_config.keys()) + + # test from hub + hub_model = from_hub(repo_id=dummy_model_id) + assert isinstance(hub_model, type(model)) diff --git a/tests/pytorch/test_models_preprocessor_pt.py b/tests/pytorch/test_models_preprocessor_pt.py new file mode 100644 index 0000000..e3e2983 --- /dev/null +++ b/tests/pytorch/test_models_preprocessor_pt.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest +import torch +from doctr.models.preprocessor import PreProcessor + + +@pytest.mark.parametrize( + "batch_size, output_size, input_tensor, expected_batches, expected_value", + [ + [2, (128, 128), np.full((3, 256, 128, 3), 255, dtype=np.uint8), 1, 0.5], # numpy uint8 + [2, (128, 128), np.ones((3, 256, 128, 3), dtype=np.float32), 1, 0.5], # numpy fp32 + [2, (128, 128), torch.full((3, 3, 256, 128), 255, dtype=torch.uint8), 1, 0.5], # torch uint8 + [2, (128, 128), torch.ones((3, 3, 256, 128), dtype=torch.float32), 1, 0.5], # torch fp32 + [2, (128, 128), torch.ones((3, 3, 256, 128), dtype=torch.float16), 1, 0.5], # torch fp16 + [2, (128, 128), [np.full((256, 128, 3), 255, dtype=np.uint8)] * 3, 2, 0.5], # list of numpy uint8 + [2, (128, 128), [np.ones((256, 128, 3), dtype=np.float32)] * 3, 2, 0.5], # list of numpy fp32 + [2, (128, 128), [torch.full((3, 256, 128), 255, dtype=torch.uint8)] * 3, 2, 0.5], # list of torch uint8 + [2, (128, 128), [torch.ones((3, 256, 128), dtype=torch.float32)] * 3, 2, 0.5], # list of torch fp32 + [2, (128, 128), [torch.ones((3, 256, 128), dtype=torch.float16)] * 3, 2, 0.5], # list of torch fp32 + ], +) +def test_preprocessor(batch_size, output_size, input_tensor, expected_batches, expected_value): + processor = PreProcessor(output_size, batch_size) + + # Invalid input type + with pytest.raises(TypeError): + processor(42) + # 4D check + with pytest.raises(AssertionError): + processor(np.full((256, 128, 3), 255, dtype=np.uint8)) + with pytest.raises(TypeError): + processor(np.full((1, 256, 128, 3), 255, dtype=np.int32)) + # 3D check + with pytest.raises(AssertionError): + processor([np.full((3, 256, 128, 3), 255, dtype=np.uint8)]) + with 
pytest.raises(TypeError): + processor([np.full((256, 128, 3), 255, dtype=np.int32)]) + + with torch.no_grad(): + out = processor(input_tensor) + assert isinstance(out, list) and len(out) == expected_batches + assert all(isinstance(b, torch.Tensor) for b in out) + assert all(b.dtype == torch.float32 for b in out) + assert all(b.shape[-2:] == output_size for b in out) + assert all(torch.all(b == expected_value) for b in out) + assert len(repr(processor).split("\n")) == 4 diff --git a/tests/pytorch/test_models_recognition_pt.py b/tests/pytorch/test_models_recognition_pt.py new file mode 100644 index 0000000..64a0d70 --- /dev/null +++ b/tests/pytorch/test_models_recognition_pt.py @@ -0,0 +1,155 @@ +import os +import tempfile + +import numpy as np +import onnxruntime +import psutil +import pytest +import torch +from doctr.models import recognition +from doctr.models.recognition.crnn.pytorch import CTCPostProcessor +from doctr.models.recognition.master.pytorch import MASTERPostProcessor +from doctr.models.recognition.parseq.pytorch import PARSeqPostProcessor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.models.recognition.sar.pytorch import SARPostProcessor +from doctr.models.recognition.vitstr.pytorch import ViTSTRPostProcessor +from doctr.models.utils import export_model_to_onnx + +system_available_memory = int(psutil.virtual_memory().available / 1024**3) + + +@pytest.mark.parametrize("train_mode", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["crnn_vgg16_bn", (3, 32, 128)], + ["crnn_mobilenet_v3_small", (3, 32, 128)], + ["crnn_mobilenet_v3_large", (3, 32, 128)], + ["sar_resnet31", (3, 32, 128)], + ["master", (3, 32, 128)], + ["vitstr_small", (3, 32, 128)], + ["vitstr_base", (3, 32, 128)], + ["parseq", (3, 32, 128)], + ], +) +def test_recognition_models(arch_name, input_shape, train_mode, mock_vocab): + batch_size = 4 + model = recognition.__dict__[arch_name](vocab=mock_vocab, pretrained=True, input_shape=input_shape) + model = model.train() if train_mode else model.eval() + assert isinstance(model, torch.nn.Module) + input_tensor = torch.rand((batch_size, *input_shape)) + target = ["i", "am", "a", "jedi"] + + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + out = model(input_tensor, target, return_model_output=True, return_preds=not train_mode) + assert isinstance(out, dict) + assert len(out) == 3 if not train_mode else len(out) == 2 + if not train_mode: + assert isinstance(out["preds"], list) + assert len(out["preds"]) == batch_size + assert all(isinstance(word, str) and isinstance(conf, float) and 0 <= conf <= 1 for word, conf in out["preds"]) + assert isinstance(out["out_map"], torch.Tensor) + assert out["out_map"].dtype == torch.float32 + assert isinstance(out["loss"], torch.Tensor) + # test model in train mode needs targets + with pytest.raises(ValueError): + model.train() + model(input_tensor, None) + + +@pytest.mark.parametrize( + "post_processor, input_shape", + [ + [CTCPostProcessor, [2, 119, 30]], + [SARPostProcessor, [2, 119, 30]], + [ViTSTRPostProcessor, [2, 119, 30]], + [MASTERPostProcessor, [2, 119, 30]], + [PARSeqPostProcessor, [2, 119, 30]], + ], +) +def test_reco_postprocessors(post_processor, input_shape, mock_vocab): + processor = post_processor(mock_vocab) + decoded = processor(torch.rand(*input_shape)) + assert isinstance(decoded, list) + assert all(isinstance(word, str) and isinstance(conf, float) and 0 <= conf <= 1 for word, conf in decoded) + assert len(decoded) == 
input_shape[0] + assert all(char in mock_vocab for word, _ in decoded for char in word) + # Repr + assert repr(processor) == f"{post_processor.__name__}(vocab_size={len(mock_vocab)})" + + +@pytest.mark.parametrize( + "arch_name", + [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + "master", + "vitstr_small", + "vitstr_base", + "parseq", + ], +) +def test_recognition_zoo(arch_name): + batch_size = 2 + # Model + predictor = recognition.zoo.recognition_predictor(arch_name, pretrained=False) + predictor.model.eval() + # object check + assert isinstance(predictor, RecognitionPredictor) + input_tensor = torch.rand((batch_size, 3, 128, 128)) + if torch.cuda.is_available(): + predictor.model.cuda() + input_tensor = input_tensor.cuda() + + with torch.no_grad(): + out = predictor(input_tensor) + out = predictor(input_tensor) + assert isinstance(out, list) and len(out) == batch_size + assert all(isinstance(word, str) and isinstance(conf, float) for word, conf in out) + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["crnn_vgg16_bn", (3, 32, 128)], + ["crnn_mobilenet_v3_small", (3, 32, 128)], + ["crnn_mobilenet_v3_large", (3, 32, 128)], + pytest.param( + "sar_resnet31", + (3, 32, 128), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "master", (3, 32, 128), marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory") + ), + ["vitstr_small", (3, 32, 128)], # testing one vitstr version is enough + ["parseq", (3, 32, 128)], + ], +) +def test_models_onnx_export(arch_name, input_shape): + # Model + batch_size = 2 + model = recognition.__dict__[arch_name](pretrained=True, exportable=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + pt_logits = model(dummy_input)["logits"].detach().cpu().numpy() + with tempfile.TemporaryDirectory() as tmpdir: + # Export + model_path = export_model_to_onnx(model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input) + assert os.path.exists(model_path) + # Inference + ort_session = onnxruntime.InferenceSession( + os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"] + ) + ort_outs = ort_session.run(["logits"], {"input": dummy_input.numpy()}) + + assert isinstance(ort_outs, list) and len(ort_outs) == 1 + assert ort_outs[0].shape == pt_logits.shape + # Check that the output is close to the PyTorch output - only warn if not close + try: + assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") diff --git a/tests/pytorch/test_models_utils_pt.py b/tests/pytorch/test_models_utils_pt.py new file mode 100644 index 0000000..122978a --- /dev/null +++ b/tests/pytorch/test_models_utils_pt.py @@ -0,0 +1,65 @@ +import os + +import pytest +import torch +from doctr.models.utils import ( + _bf16_to_float32, + _copy_tensor, + conv_sequence_pt, + load_pretrained_params, + set_device_and_dtype, +) +from torch import nn + + +def test_copy_tensor(): + x = torch.rand(8) + m = _copy_tensor(x) + assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and torch.allclose(m, x) + + +def test_bf16_to_float32(): + x = torch.randn([2, 2], dtype=torch.bfloat16) + converted_x = _bf16_to_float32(x) + assert x.dtype == torch.bfloat16 and converted_x.dtype == torch.float32 and torch.equal(converted_x, x.float()) + + +def 
test_load_pretrained_params(tmpdir_factory): + model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4)) + # Retrieve this URL + url = "https://github.com/mindee/doctr/releases/download/v0.2.1/tmp_checkpoint-6f0ce0e6.pt" + # Temp cache dir + cache_dir = tmpdir_factory.mktemp("cache") + # Pass an incorrect hash + with pytest.raises(ValueError): + load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir)) + # Let it resolve the hash from the file name + load_pretrained_params(model, url, cache_dir=str(cache_dir)) + # Check that the file was downloaded & the archive extracted + assert os.path.exists(cache_dir.join("models").join(url.rpartition("/")[-1].split("&")[0])) + # Check ignore keys + load_pretrained_params(model, url, cache_dir=str(cache_dir), ignore_keys=["2.weight"]) + # non matching keys + model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1)) + with pytest.raises(ValueError): + load_pretrained_params(model, url, cache_dir=str(cache_dir), ignore_keys=["2.weight"]) + + +def test_conv_sequence(): + assert len(conv_sequence_pt(3, 8, kernel_size=3)) == 1 + assert len(conv_sequence_pt(3, 8, True, kernel_size=3)) == 2 + assert len(conv_sequence_pt(3, 8, False, True, kernel_size=3)) == 2 + assert len(conv_sequence_pt(3, 8, True, True, kernel_size=3)) == 3 + + +def test_set_device_and_dtype(): + model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4)) + batches = [torch.rand(8) for _ in range(2)] + model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float32) + assert model[0].weight.device == torch.device("cpu") + assert model[0].weight.dtype == torch.float32 + assert batches[0].device == torch.device("cpu") + assert batches[0].dtype == torch.float32 + model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float16) + assert model[0].weight.dtype == torch.float16 + assert batches[0].dtype == torch.float16 diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py new file mode 100644 index 0000000..2ff3cf4 --- /dev/null +++ b/tests/pytorch/test_models_zoo_pt.py @@ -0,0 +1,327 @@ +import numpy as np +import pytest +from doctr import models +from doctr.file_utils import CLASS_NAME +from doctr.io import Document, DocumentFile +from doctr.io.elements import KIEDocument +from doctr.models import detection, recognition +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.detection.zoo import detection_predictor +from doctr.models.kie_predictor import KIEPredictor +from doctr.models.predictor import OCRPredictor +from doctr.models.preprocessor import PreProcessor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.models.recognition.zoo import recognition_predictor +from torch import nn + + +# Create a dummy callback +class _DummyCallback: + def __call__(self, loc_preds): + return loc_preds + + +@pytest.mark.parametrize( + "assume_straight_pages, straighten_pages", + [ + [True, False], + [False, False], + [True, True], + ], +) +def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): + det_bsize = 4 + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=det_bsize), + detection.db_mobilenet_v3_large( + pretrained=False, + pretrained_backbone=False, + assume_straight_pages=assume_straight_pages, + ), + ) + + assert not det_predictor.model.training + + reco_bsize = 32 + reco_predictor = RecognitionPredictor( + 
PreProcessor(output_size=(32, 128), batch_size=reco_bsize, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab), + ) + + assert not reco_predictor.model.training + + doc = DocumentFile.from_pdf(mock_pdf) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + straighten_pages=straighten_pages, + detect_orientation=True, + detect_language=True, + ) + + if assume_straight_pages: + assert predictor.crop_orientation_predictor is None + else: + assert isinstance(predictor.crop_orientation_predictor, nn.Module) + + out = predictor(doc) + assert isinstance(out, Document) + assert len(out.pages) == 2 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + + +def test_trained_ocr_predictor(mock_payslip): + doc = DocumentFile.from_images(mock_payslip) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + symmetric_pad=True, + preserve_aspect_ratio=False, + ) + reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=False, + ) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." + geometry_mr = np.array([[0.1083984375, 0.0634765625], [0.1494140625, 0.0859375]]) + assert np.allclose(np.array(out.pages[0].blocks[0].lines[0].words[0].geometry), geometry_mr, rtol=0.05) + + assert out.pages[0].blocks[1].lines[0].words[-1].value == "revised" + geometry_revised = np.array([[0.7548828125, 0.126953125], [0.8388671875, 0.1484375]]) + assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised, rtol=0.05) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + # test hooks + predictor.add_hook(_DummyCallback()) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." 
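For context, the trained-predictor assertions above exercise the same high-level API a user would call directly. A minimal usage sketch of that flow outside the test harness; the image path is a placeholder:

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

predictor = ocr_predictor("db_resnet50", "crnn_vgg16_bn", pretrained=True)
doc = DocumentFile.from_images("path/to/page.png")  # placeholder path
result = predictor(doc)
print(result.render())       # plain-text reconstruction of the page
exported = result.export()   # nested dict: pages -> blocks -> lines -> words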
+ + +@pytest.mark.parametrize( + "assume_straight_pages, straighten_pages", + [ + [True, False], + [False, False], + [True, True], + ], +) +def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): + det_bsize = 4 + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=det_bsize), + detection.db_mobilenet_v3_large( + pretrained=False, + pretrained_backbone=False, + assume_straight_pages=assume_straight_pages, + ), + ) + + assert not det_predictor.model.training + + reco_bsize = 32 + reco_predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=reco_bsize, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab), + ) + + assert not reco_predictor.model.training + + doc = DocumentFile.from_pdf(mock_pdf) + + predictor = KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + straighten_pages=straighten_pages, + detect_orientation=True, + detect_language=True, + ) + + if assume_straight_pages: + assert predictor.crop_orientation_predictor is None + else: + assert isinstance(predictor.crop_orientation_predictor, nn.Module) + + out = predictor(doc) + assert isinstance(out, Document) + assert len(out.pages) == 2 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + + +def test_trained_kie_predictor(mock_payslip): + doc = DocumentFile.from_images(mock_payslip) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + symmetric_pad=True, + preserve_aspect_ratio=False, + ) + reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128) + + predictor = KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=False, + ) + # test hooks + predictor.add_hook(_DummyCallback()) + + out = predictor(doc) + + assert isinstance(out, KIEDocument) + assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr." + geometry_mr = np.array([[0.1083984375, 0.0634765625], [0.1494140625, 0.0859375]]) + assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][0].geometry), geometry_mr, rtol=0.05) + + assert out.pages[0].predictions[CLASS_NAME][4].value == "revised" + geometry_revised = np.array([[0.7548828125, 0.126953125], [0.8388671875, 0.1484375]]) + assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][4].geometry), geometry_revised, rtol=0.05) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + predictor = KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + out = predictor(doc) + + assert isinstance(out, KIEDocument) + assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr." 
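The KIE flavour returns a `KIEDocument` whose pages group words per detection class, which is what the `predictions[CLASS_NAME]` assertions above rely on. A short, illustrative iteration over that structure; attribute names mirror those assertions, so treat the snippet as a sketch rather than canonical API documentation:

from doctr.io import DocumentFile
from doctr.models import kie_predictor

kie = kie_predictor("db_resnet50", "crnn_vgg16_bn", pretrained=True)
doc = DocumentFile.from_images("path/to/page.png")  # placeholder path
out = kie(doc)
for page in out.pages:
    # With the default pretrained detector there is a single class, keyed by CLASS_NAME ("words")
    for class_name, words in page.predictions.items():
        for word in words:
            print(class_name, word.value, word.confidence, word.geometry)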
+ + +def _test_predictor(predictor): + # Output checks + assert isinstance(predictor, OCRPredictor) + + doc = [np.zeros((512, 512, 3), dtype=np.uint8)] + out = predictor(doc) + # Document + assert isinstance(out, Document) + + # The input doc has 1 page + assert len(out.pages) == 1 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + +def _test_kiepredictor(predictor): + # Output checks + assert isinstance(predictor, KIEPredictor) + + doc = [np.zeros((512, 512, 3), dtype=np.uint8)] + out = predictor(doc) + # Document + assert isinstance(out, KIEDocument) + + # The input doc has 1 page + assert len(out.pages) == 1 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + +@pytest.mark.parametrize( + "det_arch, reco_arch", + [ + ["db_mobilenet_v3_large", "crnn_mobilenet_v3_large"], + ], +) +def test_zoo_models(det_arch, reco_arch): + # Model + predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True) + _test_predictor(predictor) + + # passing model instance directly + det_model = detection.__dict__[det_arch](pretrained=True) + reco_model = recognition.__dict__[reco_arch](pretrained=True) + predictor = models.ocr_predictor(det_model, reco_model) + _test_predictor(predictor) + + # passing recognition model as detection model + with pytest.raises(ValueError): + models.ocr_predictor(det_arch=reco_model, pretrained=True) + + # passing detection model as recognition model + with pytest.raises(ValueError): + models.ocr_predictor(reco_arch=det_model, pretrained=True) + + # KIE predictor + predictor = models.kie_predictor(det_arch, reco_arch, pretrained=True) + _test_kiepredictor(predictor) + + # passing model instance directly + det_model = detection.__dict__[det_arch](pretrained=True) + reco_model = recognition.__dict__[reco_arch](pretrained=True) + predictor = models.kie_predictor(det_model, reco_model) + _test_kiepredictor(predictor) + + # passing recognition model as detection model + with pytest.raises(ValueError): + models.kie_predictor(det_arch=reco_model, pretrained=True) + + # passing detection model as recognition model + with pytest.raises(ValueError): + models.kie_predictor(reco_arch=det_model, pretrained=True) diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py new file mode 100644 index 0000000..76bc84e --- /dev/null +++ b/tests/pytorch/test_transforms_pt.py @@ -0,0 +1,351 @@ +import math + +import numpy as np +import pytest +import torch +from doctr.transforms import ( + ChannelShuffle, + ColorInversion, + GaussianNoise, + RandomCrop, + RandomHorizontalFlip, + RandomResize, + RandomRotate, + RandomShadow, + Resize, +) +from doctr.transforms.functional import crop_detection, rotate_sample + + +def test_resize(): + output_size = (32, 32) + transfo = Resize(output_size) + input_t = torch.ones((3, 64, 64), dtype=torch.float32) + out = transfo(input_t) + + assert torch.all(out == 1) + assert out.shape[-2:] == output_size + assert repr(transfo) == f"Resize(output_size={output_size}, interpolation='bilinear')" + + transfo = Resize(output_size, preserve_aspect_ratio=True) + input_t = torch.ones((3, 32, 64), dtype=torch.float32) + out = transfo(input_t) + + assert out.shape[-2:] == output_size + assert not torch.all(out == 1) + # Asymetric padding + assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 1) + + # Symetric padding + 
transfo = Resize(output_size, preserve_aspect_ratio=True, symmetric_pad=True) + assert repr(transfo) == ( + f"Resize(output_size={output_size}, interpolation='bilinear', " + f"preserve_aspect_ratio=True, symmetric_pad=True)" + ) + out = transfo(input_t) + assert out.shape[-2:] == output_size + # symetric padding + assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 0) + + # Inverse aspect ratio + input_t = torch.ones((3, 64, 32), dtype=torch.float32) + out = transfo(input_t) + + assert not torch.all(out == 1) + assert out.shape[-2:] == output_size + + # Same aspect ratio + output_size = (32, 128) + transfo = Resize(output_size, preserve_aspect_ratio=True) + out = transfo(torch.ones((3, 16, 64), dtype=torch.float32)) + assert out.shape[-2:] == output_size + + # FP16 + input_t = torch.ones((3, 64, 64), dtype=torch.float16) + out = transfo(input_t) + assert out.dtype == torch.float16 + + +@pytest.mark.parametrize( + "rgb_min", + [ + 0.2, + 0.4, + 0.6, + ], +) +def test_invert_colorize(rgb_min): + transfo = ColorInversion(min_val=rgb_min) + input_t = torch.ones((8, 3, 32, 32), dtype=torch.float32) + out = transfo(input_t) + assert torch.all(out <= 1 - rgb_min + 1e-4) + assert torch.all(out >= 0) + + input_t = torch.full((8, 3, 32, 32), 255, dtype=torch.uint8) + out = transfo(input_t) + assert torch.all(out <= int(math.ceil(255 * (1 - rgb_min + 1e-4)))) + assert torch.all(out >= 0) + + # FP16 + input_t = torch.ones((8, 3, 32, 32), dtype=torch.float16) + out = transfo(input_t) + assert out.dtype == torch.float16 + + +def test_rotate_sample(): + img = torch.ones((3, 200, 100), dtype=torch.float32) + boxes = np.array([0, 0, 100, 200])[None, ...] + polys = np.stack((boxes[..., [0, 1]], boxes[..., [2, 1]], boxes[..., [2, 3]], boxes[..., [0, 3]]), axis=1) + rel_boxes = np.array([0, 0, 1, 1], dtype=np.float32)[None, ...] + rel_polys = np.stack( + (rel_boxes[..., [0, 1]], rel_boxes[..., [2, 1]], rel_boxes[..., [2, 3]], rel_boxes[..., [0, 3]]), axis=1 + ) + + # No angle + rotated_img, rotated_geoms = rotate_sample(img, boxes, 0, False) + assert torch.all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, boxes, 0, True) + assert torch.all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 0, False) + assert torch.all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 0, True) + assert torch.all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + + # No expansion + expected_img = torch.zeros((3, 200, 100), dtype=torch.float32) + expected_img[:, 50:150] = 1 + expected_polys = np.array([[0, 0.75], [0, 0.25], [1, 0.25], [1, 0.75]])[None, ...] 
+ rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, False) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 90, False) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_boxes, 90, False) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_polys, 90, False) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + + # Expansion + expected_img = torch.ones((3, 100, 200), dtype=torch.float32) + expected_polys = np.array([[0, 1], [0, 0], [1, 0], [1, 1]], dtype=np.float32)[None, ...] + rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, True) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 90, True) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_boxes, 90, True) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_polys, 90, True) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + + with pytest.raises(AssertionError): + rotate_sample(img, boxes[None, ...], 90, False) + + +def test_random_rotate(): + rotator = RandomRotate(max_angle=10.0, expand=False) + input_t = torch.ones((3, 50, 50), dtype=torch.float32) + boxes = np.array([[15, 20, 35, 30]]) + r_img, _r_boxes = rotator(input_t, boxes) + assert r_img.shape == input_t.shape + + rotator = RandomRotate(max_angle=10.0, expand=True) + r_img, _r_boxes = rotator(input_t, boxes) + assert r_img.shape != input_t.shape + + # FP16 (only on GPU) + if torch.cuda.is_available(): + input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda() + r_img, _ = rotator(input_t, boxes) + assert r_img.dtype == torch.float16 + + +def test_crop_detection(): + img = torch.ones((3, 50, 50), dtype=torch.float32) + abs_boxes = np.array([ + [15, 20, 35, 30], + [5, 10, 10, 20], + ]) + crop_box = (12 / 50, 23 / 50, 50 / 50, 50 / 50) + c_img, c_boxes = crop_detection(img, abs_boxes, crop_box) + assert c_img.shape == (3, 26, 37) + assert c_boxes.shape == (1, 4) + assert np.all(c_boxes == np.array([15 - 12, 0, 35 - 12, 30 - 23])[None, ...]) + + rel_boxes = np.array([ + [0.3, 0.4, 0.7, 0.6], + [0.1, 0.2, 0.2, 0.4], + ]) + crop_box = (0.24, 0.46, 1.0, 1.0) + c_img, c_boxes = crop_detection(img, rel_boxes, crop_box) + assert c_img.shape == (3, 26, 37) + assert c_boxes.shape == (1, 4) + assert np.abs(c_boxes - np.array([0.06 / 0.76, 0.0, 0.46 / 0.76, 0.14 / 0.54])[None, ...]).mean() < 1e-7 + + # FP16 + img = torch.ones((3, 50, 50), dtype=torch.float16) + c_img, _ = crop_detection(img, abs_boxes, crop_box) + assert c_img.dtype == torch.float16 + + with pytest.raises(AssertionError): + crop_detection(img, abs_boxes, (2, 6, 24, 56)) + + +@pytest.mark.parametrize( + "target", + [ + np.array([[15, 20, 35, 30]]), # box + np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), # polygon + ], +) +def test_random_crop(target): + cropper = RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33)) + input_t = torch.ones((3, 50, 50), dtype=torch.float32) + img, target = cropper(input_t, target) + # Check the scale + assert img.shape[-1] * 
img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] + # Check aspect ratio + assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.5 + # Check the target + assert np.all(target >= 0) + if target.ndim == 2: + assert np.all(target[:, [0, 2]] <= img.shape[-1]) and np.all(target[:, [1, 3]] <= img.shape[-2]) + else: + assert np.all(target[..., 0] <= img.shape[-1]) and np.all(target[..., 1] <= img.shape[-2]) + + +@pytest.mark.parametrize( + "input_dtype, input_size", + [ + [torch.float32, (3, 32, 32)], + [torch.uint8, (3, 32, 32)], + ], +) +def test_channel_shuffle(input_dtype, input_size): + transfo = ChannelShuffle() + input_t = torch.rand(input_size, dtype=torch.float32) + if input_dtype == torch.uint8: + input_t = (255 * input_t).round() + input_t = input_t.to(dtype=input_dtype) + out = transfo(input_t) + assert isinstance(out, torch.Tensor) + assert out.shape == input_size + assert out.dtype == input_dtype + # Ensure that nothing has changed apart from channel order + if input_dtype == torch.uint8: + assert torch.all(input_t.sum(0) == out.sum(0)) + else: + # Float approximation + assert (input_t.sum(0) - out.sum(0)).abs().mean() < 1e-7 + + +@pytest.mark.parametrize( + "input_dtype,input_shape", + [ + [torch.float32, (3, 32, 32)], + [torch.uint8, (3, 32, 32)], + ], +) +def test_gaussian_noise(input_dtype, input_shape): + transform = GaussianNoise(0.0, 1.0) + input_t = torch.rand(input_shape, dtype=torch.float32) + if input_dtype == torch.uint8: + input_t = (255 * input_t).round() + input_t = input_t.to(dtype=input_dtype) + transformed = transform(input_t) + assert isinstance(transformed, torch.Tensor) + assert transformed.shape == input_shape + assert transformed.dtype == input_dtype + assert torch.any(transformed != input_t) + assert torch.all(transformed >= 0) + if input_dtype == torch.uint8: + assert torch.all(transformed <= 255) + else: + assert torch.all(transformed <= 1.0) + + +@pytest.mark.parametrize( + "p,target", + [ + [1, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)], + [0, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)], + [1, np.array([[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], dtype=np.float32)], + [0, np.array([[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], dtype=np.float32)], + ], +) +def test_randomhorizontalflip(p, target): + # testing for 2 cases, with flip probability 1 and 0. 
+ transform = RandomHorizontalFlip(p) + input_t = torch.ones((3, 32, 32), dtype=torch.float32) + input_t[..., :16] = 0 + + transformed, _target = transform(input_t, target) + assert isinstance(transformed, torch.Tensor) + assert transformed.shape == input_t.shape + assert transformed.dtype == input_t.dtype + # integrity check of targets + assert isinstance(_target, np.ndarray) + assert _target.dtype == np.float32 + if _target.ndim == 2: + if p == 1: + assert np.all(_target == np.array([[0.7, 0.1, 0.9, 0.4]], dtype=np.float32)) + assert torch.all(transformed.mean((0, 1)) == torch.tensor([1] * 16 + [0] * 16, dtype=torch.float32)) + elif p == 0: + assert np.all(_target == np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)) + assert torch.all(transformed.mean((0, 1)) == torch.tensor([0] * 16 + [1] * 16, dtype=torch.float32)) + else: + if p == 1: + assert np.all(_target == np.array([[[0.9, 0.1], [0.7, 0.1], [0.7, 0.4], [0.9, 0.4]]], dtype=np.float32)) + assert torch.all(transformed.mean((0, 1)) == torch.tensor([1] * 16 + [0] * 16, dtype=torch.float32)) + elif p == 0: + assert np.all(_target == np.array([[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], dtype=np.float32)) + assert torch.all(transformed.mean((0, 1)) == torch.tensor([0] * 16 + [1] * 16, dtype=torch.float32)) + + +@pytest.mark.parametrize( + "input_dtype,input_shape", + [ + [torch.float32, (3, 32, 32)], + [torch.uint8, (3, 32, 32)], + [torch.float32, (3, 64, 32)], + [torch.uint8, (3, 64, 32)], + ], +) +def test_random_shadow(input_dtype, input_shape): + transform = RandomShadow((0.2, 0.8)) + input_t = torch.ones(input_shape, dtype=torch.float32) + if input_dtype == torch.uint8: + input_t = (255 * input_t).round() + input_t = input_t.to(dtype=input_dtype) + transformed = transform(input_t) + assert isinstance(transformed, torch.Tensor) + assert transformed.shape == input_shape + assert transformed.dtype == input_dtype + # The shadow will darken the picture + assert input_t.float().mean() >= transformed.float().mean() + assert torch.all(transformed >= 0) + if input_dtype == torch.uint8: + assert torch.all(transformed <= 255) + else: + assert torch.all(transformed <= 1.0) + + +@pytest.mark.parametrize( + "p,target", + [ + [1, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)], + [0, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)], + [1, np.array([[[0.1, 0.8], [0.3, 0.1], [0.3, 0.4], [0.8, 0.4]]], dtype=np.float32)], + [0, np.array([[[0.1, 0.8], [0.3, 0.1], [0.3, 0.4], [0.8, 0.4]]], dtype=np.float32)], + ], +) +def test_random_resize(p, target): + transfo = RandomResize(scale_range=(0.3, 1.3), p=p) + assert repr(transfo) == f"RandomResize(scale_range=(0.3, 1.3), p={p})" + + img = torch.rand((3, 64, 64)) + # Apply the transformation + out_img, out_target = transfo(img, target) + assert isinstance(out_img, torch.Tensor) + assert isinstance(out_target, np.ndarray) + # Resize is already well tested + assert torch.all(out_img == img) if p == 0 else out_img.shape != img.shape + assert out_target.shape == target.shape diff --git a/tests/tensorflow/test_datasets_loader_tf.py b/tests/tensorflow/test_datasets_loader_tf.py new file mode 100644 index 0000000..26d24ae --- /dev/null +++ b/tests/tensorflow/test_datasets_loader_tf.py @@ -0,0 +1,75 @@ +from typing import List, Tuple + +import tensorflow as tf +from doctr.datasets import DataLoader + + +class MockDataset: + def __init__(self, input_size): + self.data: List[Tuple[float, bool]] = [ + (1, True), + (0, False), + (0.5, True), + ] + self.input_size = input_size + + def __len__(self): + 
return len(self.data) + + def __getitem__(self, index): + val, label = self.data[index] + return tf.cast(tf.fill(self.input_size, val), dtype=tf.float32), tf.constant(label, dtype=tf.bool) + + +class MockDatasetBis(MockDataset): + @staticmethod + def collate_fn(samples): + x, y = zip(*samples) + return tf.stack(x, axis=0), list(y) + + +def test_dataloader(): + loader = DataLoader( + MockDataset((32, 32)), + shuffle=True, + batch_size=2, + drop_last=True, + ) + + ds_iter = iter(loader) + num_batches = 0 + for x, y in ds_iter: + num_batches += 1 + assert len(loader) == 1 + assert num_batches == 1 + assert isinstance(x, tf.Tensor) and isinstance(y, tf.Tensor) + assert x.shape == (2, 32, 32) + assert y.shape == (2,) + + # Drop last + loader = DataLoader( + MockDataset((32, 32)), + shuffle=True, + batch_size=2, + drop_last=False, + ) + ds_iter = iter(loader) + num_batches = 0 + for x, y in ds_iter: + num_batches += 1 + assert loader.num_batches == 2 + assert num_batches == 2 + + # Custom collate + loader = DataLoader( + MockDatasetBis((32, 32)), + shuffle=True, + batch_size=2, + drop_last=False, + ) + + ds_iter = iter(loader) + x, y = next(ds_iter) + assert isinstance(x, tf.Tensor) and isinstance(y, list) + assert x.shape == (2, 32, 32) + assert len(y) == 2 diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py new file mode 100644 index 0000000..4bbf946 --- /dev/null +++ b/tests/tensorflow/test_datasets_tf.py @@ -0,0 +1,605 @@ +import os +from shutil import move + +import numpy as np +import pytest +import tensorflow as tf +from doctr import datasets +from doctr.datasets import DataLoader +from doctr.file_utils import CLASS_NAME +from doctr.transforms import Resize + + +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False): + # Fetch one sample + img, target = ds[0] + assert isinstance(img, tf.Tensor) + assert img.shape == (*input_size, 3) + assert img.dtype == tf.float32 + assert isinstance(target, dict) + assert isinstance(target["boxes"], np.ndarray) and target["boxes"].dtype == np.float32 + if is_polygons: + assert target["boxes"].ndim == 3 and target["boxes"].shape[1:] == (4, 2) + else: + assert target["boxes"].ndim == 2 and target["boxes"].shape[1:] == (4,) + assert np.all(np.logical_and(target["boxes"] <= 1, target["boxes"] >= 0)) + if class_indices: + assert isinstance(target["labels"], np.ndarray) and target["labels"].dtype == np.int64 + else: + assert isinstance(target["labels"], list) and all(isinstance(s, str) for s in target["labels"]) + assert len(target["labels"]) == len(target["boxes"]) + + # Check batching + loader = DataLoader(ds, batch_size=batch_size) + + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (batch_size, *input_size, 3) + assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets) + + +def _validate_dataset_recognition_part(ds, input_size, batch_size=2): + # Fetch one sample + img, label = ds[0] + assert isinstance(img, tf.Tensor) + assert img.shape == (*input_size, 3) + assert img.dtype == tf.float32 + assert isinstance(label, str) + + # Check batching + loader = DataLoader(ds, batch_size=batch_size) + + images, labels = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (batch_size, *input_size, 3) + assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) + + +def test_visiondataset(): + url = "https://github.com/mindee/doctr/releases/download/v0.6.0/mnist.zip" + with 
pytest.raises(ValueError): + datasets.datasets.VisionDataset(url, download=False) + + dataset = datasets.datasets.VisionDataset(url, download=True, extract_archive=True) + assert len(dataset) == 0 + assert repr(dataset) == "VisionDataset()" + + +def test_rotation_dataset(mock_image_folder): + input_size = (1024, 1024) + + ds = datasets.OrientationDataset(img_folder=mock_image_folder, img_transforms=Resize(input_size)) + assert len(ds) == 5 + img, target = ds[0] + assert isinstance(img, tf.Tensor) + assert img.dtype == tf.float32 + assert img.shape[:2] == input_size + # Prefilled rotation targets + assert isinstance(target, np.ndarray) and target.dtype == np.int64 + # check that all prefilled targets are 0 (degrees) + assert np.all(target == 0) + + loader = DataLoader(ds, batch_size=2) + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets) + + +def test_detection_dataset(mock_image_folder, mock_detection_label): + input_size = (1024, 1024) + + ds = datasets.DetectionDataset( + img_folder=mock_image_folder, + label_path=mock_detection_label, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 5 + img, target_dict = ds[0] + target = target_dict[CLASS_NAME] + assert isinstance(img, tf.Tensor) + assert img.shape[:2] == input_size + assert img.dtype == tf.float32 + # Bounding boxes + assert isinstance(target_dict, dict) + assert isinstance(target, np.ndarray) and target.dtype == np.float32 + assert np.all(np.logical_and(target[:, :4] >= 0, target[:, :4] <= 1)) + assert target.shape[1] == 4 + + loader = DataLoader(ds, batch_size=2) + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(targets, list) and all( + isinstance(elt, np.ndarray) for target in targets for elt in target.values() + ) + + # Rotated DS + rotated_ds = datasets.DetectionDataset( + img_folder=mock_image_folder, + label_path=mock_detection_label, + img_transforms=Resize(input_size), + use_polygons=True, + ) + _, r_target = rotated_ds[0] + assert r_target[CLASS_NAME].shape[1:] == (4, 2) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.DetectionDataset(mock_image_folder, mock_detection_label) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +def test_recognition_dataset(mock_image_folder, mock_recognition_label): + input_size = (32, 128) + ds = datasets.RecognitionDataset( + img_folder=mock_image_folder, + labels_path=mock_recognition_label, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + assert len(ds) == 5 + image, label = ds[0] + assert isinstance(image, tf.Tensor) + assert image.shape[:2] == input_size + assert image.dtype == tf.float32 + assert isinstance(label, str) + + loader = DataLoader(ds, batch_size=2) + images, labels = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.RecognitionDataset(mock_image_folder, mock_recognition_label) + move(os.path.join(ds.root, "tmp_file"), 
os.path.join(ds.root, img_name)) + + +@pytest.mark.parametrize( + "use_polygons", + [False, True], +) +def test_ocrdataset(mock_ocrdataset, use_polygons): + input_size = (512, 512) + + ds = datasets.OCRDataset( + *mock_ocrdataset, + img_transforms=Resize(input_size), + use_polygons=use_polygons, + ) + assert len(ds) == 3 + _validate_dataset(ds, input_size, is_polygons=use_polygons) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.OCRDataset(*mock_ocrdataset) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +def test_charactergenerator(): + input_size = (32, 32) + vocab = "abcdef" + + ds = datasets.CharacterGenerator( + vocab=vocab, + num_samples=10, + cache_samples=True, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 10 + image, label = ds[0] + assert isinstance(image, tf.Tensor) + assert image.shape[:2] == input_size + assert image.dtype == tf.float32 + assert isinstance(label, int) and label < len(vocab) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(targets, tf.Tensor) and targets.shape == (2,) + assert targets.dtype == tf.int32 + + +def test_wordgenerator(): + input_size = (32, 128) + wordlen_range = (1, 10) + vocab = "abcdef" + + ds = datasets.WordGenerator( + vocab=vocab, + min_chars=wordlen_range[0], + max_chars=wordlen_range[1], + num_samples=10, + cache_samples=True, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 10 + image, target = ds[0] + assert isinstance(image, tf.Tensor) + assert image.shape[:2] == input_size + assert image.dtype == tf.float32 + assert isinstance(target, str) and len(target) >= wordlen_range[0] and len(target) <= wordlen_range[1] + assert all(char in vocab for char in target) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(targets, list) and len(targets) == 2 and all(isinstance(t, str) for t in targets) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples", + [ + [[512, 512], 3], # Actual set has 2700 training samples and 300 test samples + ], +) +def test_artefact_detection(input_size, num_samples, rotate, mock_doc_artefacts): + # monkeypatch the path to temporary dataset + datasets.DocArtefacts.URL = mock_doc_artefacts + datasets.DocArtefacts.SHA256 = None + + ds = datasets.DocArtefacts( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + cache_dir="/".join(mock_doc_artefacts.split("/")[:-2]), + cache_subdir=mock_doc_artefacts.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"DocArtefacts(train={True})" + _validate_dataset(ds, input_size, class_indices=True, is_polygons=rotate) + + +# NOTE: following datasets support also recognition task + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 626 training samples and 360 test samples + [[32, 128], 15, True], # recognition + ], +) +def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset): + # monkeypatch the path to temporary dataset + 
datasets.SROIE.TRAIN = (mock_sroie_dataset, None, "sroie2019_train_task1.zip") + + ds = datasets.SROIE( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), + cache_subdir=mock_sroie_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SROIE(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 5, False], # Actual set has 229 train and 233 test samples + [[32, 128], 25, True], # recognition + ], +) +def test_ic13_dataset(input_size, num_samples, rotate, recognition, mock_ic13): + ds = datasets.IC13( + *mock_ic13, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + ) + + assert len(ds) == num_samples + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 7149 train and 796 test samples + [[32, 128], 5, True], # recognition + ], +) +def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, mock_imgur5k): + ds = datasets.IMGUR5K( + *mock_imgur5k, + train=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + ) + + assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split + assert repr(ds) == f"IMGUR5K(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[32, 128], 3, False], # Actual set has 33402 training samples and 13068 test samples + [[32, 128], 12, True], # recognition + ], +) +def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): + # monkeypatch the path to temporary dataset + datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar") + + ds = datasets.SVHN( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), + cache_subdir=mock_svhn_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SVHN(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 149 training samples and 50 test samples + [[32, 128], 9, True], # recognition + ], +) +def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset): + # monkeypatch the path to temporary dataset + datasets.FUNSD.URL = mock_funsd_dataset + datasets.FUNSD.SHA256 = None + datasets.FUNSD.FILE_NAME = "funsd.zip" + + ds = datasets.FUNSD( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + 
cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), + cache_subdir=mock_funsd_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"FUNSD(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 800 training samples and 100 test samples + [[32, 128], 9, True], # recognition + ], +) +def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): + # monkeypatch the path to temporary dataset + datasets.CORD.TRAIN = (mock_cord_dataset, None, "cord_train.zip") + + ds = datasets.CORD( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), + cache_subdir=mock_cord_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"CORD(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 2, False], # Actual set has 772875 training samples and 85875 test samples + [[32, 128], 10, True], # recognition + ], +) +def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_dataset): + # monkeypatch the path to temporary dataset + datasets.SynthText.URL = mock_synthtext_dataset + datasets.SynthText.SHA256 = None + + ds = datasets.SynthText( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), + cache_subdir=mock_synthtext_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SynthText(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[32, 128], 1, False], # Actual set has 2000 training samples and 3000 test samples + [[32, 128], 1, True], # recognition + ], +) +def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_dataset): + # monkeypatch the path to temporary dataset + datasets.IIIT5K.URL = mock_iiit5k_dataset + datasets.IIIT5K.SHA256 = None + + ds = datasets.IIIT5K( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), + cache_subdir=mock_iiit5k_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"IIIT5K(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size, batch_size=1) + else: + _validate_dataset(ds, input_size, batch_size=1, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 100 training samples and 249 test samples + [[32, 128], 3, True], # recognition + ], +) +def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): + # monkeypatch the 
path to temporary dataset + datasets.SVT.URL = mock_svt_dataset + datasets.SVT.SHA256 = None + + ds = datasets.SVT( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), + cache_subdir=mock_svt_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SVT(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 3, False], # Actual set has 246 training samples and 249 test samples + [[32, 128], 3, True], # recognition + ], +) +def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): + # monkeypatch the path to temporary dataset + datasets.IC03.TRAIN = (mock_ic03_dataset, None, "ic03_train.zip") + + ds = datasets.IC03( + train=True, + download=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), + cache_subdir=mock_ic03_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"IC03(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize("rotate", [True, False]) +@pytest.mark.parametrize( + "input_size, num_samples, recognition", + [ + [[512, 512], 2, False], + [[32, 128], 5, True], + ], +) +def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, mock_wildreceipt_dataset): + ds = datasets.WILDRECEIPT( + *mock_wildreceipt_dataset, + train=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + recognition_task=recognition, + ) + assert len(ds) == num_samples + assert repr(ds) == f"WILDRECEIPT(train={True})" + if recognition: + _validate_dataset_recognition_part(ds, input_size) + else: + _validate_dataset(ds, input_size, is_polygons=rotate) + + +# NOTE: following datasets are only for recognition task + + +def test_mjsynth_dataset(mock_mjsynth_dataset): + input_size = (32, 128) + ds = datasets.MJSynth( + *mock_mjsynth_dataset, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + + assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples + assert repr(ds) == f"MJSynth(train={True})" + _validate_dataset_recognition_part(ds, input_size) + + +def test_iiithws_dataset(mock_iiithws_dataset): + input_size = (32, 128) + ds = datasets.IIITHWS( + *mock_iiithws_dataset, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + + assert len(ds) == 4 # Actual set has 7141797 train and 793533 test samples + assert repr(ds) == f"IIITHWS(train={True})" + _validate_dataset_recognition_part(ds, input_size) diff --git a/tests/tensorflow/test_file_utils_tf.py b/tests/tensorflow/test_file_utils_tf.py new file mode 100644 index 0000000..a28709d --- /dev/null +++ b/tests/tensorflow/test_file_utils_tf.py @@ -0,0 +1,5 @@ +from doctr.file_utils import is_tf_available + + +def test_file_utils(): + assert is_tf_available() diff --git a/tests/tensorflow/test_io_image_tf.py b/tests/tensorflow/test_io_image_tf.py new file mode 100644 index 0000000..1680fd2 --- /dev/null +++ b/tests/tensorflow/test_io_image_tf.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest +import tensorflow as tf +from 
doctr.io import decode_img_as_tensor, read_img_as_tensor, tensor_from_numpy + + +def test_read_img_as_tensor(mock_image_path): + img = read_img_as_tensor(mock_image_path) + + assert isinstance(img, tf.Tensor) + assert img.dtype == tf.float32 + assert img.shape == (900, 1200, 3) + + img = read_img_as_tensor(mock_image_path, dtype=tf.float16) + assert img.dtype == tf.float16 + img = read_img_as_tensor(mock_image_path, dtype=tf.uint8) + assert img.dtype == tf.uint8 + + with pytest.raises(ValueError): + read_img_as_tensor(mock_image_path, dtype=tf.float64) + + +def test_decode_img_as_tensor(mock_image_stream): + img = decode_img_as_tensor(mock_image_stream) + + assert isinstance(img, tf.Tensor) + assert img.dtype == tf.float32 + assert img.shape == (900, 1200, 3) + + img = decode_img_as_tensor(mock_image_stream, dtype=tf.float16) + assert img.dtype == tf.float16 + img = decode_img_as_tensor(mock_image_stream, dtype=tf.uint8) + assert img.dtype == tf.uint8 + + with pytest.raises(ValueError): + decode_img_as_tensor(mock_image_stream, dtype=tf.float64) + + +def test_tensor_from_numpy(mock_image_stream): + with pytest.raises(ValueError): + tensor_from_numpy(np.zeros((256, 256, 3)), tf.int64) + + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8)) + + assert isinstance(out, tf.Tensor) + assert out.dtype == tf.float32 + assert out.shape == (256, 256, 3) + + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=tf.float16) + assert out.dtype == tf.float16 + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=tf.uint8) + assert out.dtype == tf.uint8 diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py new file mode 100644 index 0000000..f6fc767 --- /dev/null +++ b/tests/tensorflow/test_models_classification_tf.py @@ -0,0 +1,227 @@ +import os +import tempfile + +import cv2 +import numpy as np +import onnxruntime +import psutil +import pytest +import tensorflow as tf +from doctr.models import classification +from doctr.models.classification.predictor import OrientationPredictor +from doctr.models.utils import export_model_to_onnx + +system_available_memory = int(psutil.virtual_memory().available / 1024**3) + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size", + [ + ["vgg16_bn_r", (32, 32, 3), (126,)], + ["resnet18", (32, 32, 3), (126,)], + ["resnet31", (32, 32, 3), (126,)], + ["resnet34", (32, 32, 3), (126,)], + ["resnet34_wide", (32, 32, 3), (126,)], + ["resnet50", (32, 32, 3), (126,)], + ["magc_resnet31", (32, 32, 3), (126,)], + ["mobilenet_v3_small", (32, 32, 3), (126,)], + ["mobilenet_v3_large", (32, 32, 3), (126,)], + ["vit_s", (32, 32, 3), (126,)], + ["vit_b", (32, 32, 3), (126,)], + ["textnet_tiny", (32, 32, 3), (126,)], + ["textnet_small", (32, 32, 3), (126,)], + ["textnet_base", (32, 32, 3), (126,)], + ], +) +def test_classification_architectures(arch_name, input_shape, output_size): + # Model + batch_size = 2 + tf.keras.backend.clear_session() + model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape) + # Forward + out = model(tf.random.uniform(shape=[batch_size, *input_shape], maxval=1, dtype=tf.float32)) + # Output checks + assert isinstance(out, tf.Tensor) + assert out.dtype == tf.float32 + assert out.numpy().shape == (batch_size, *output_size) + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["mobilenet_v3_small_crop_orientation", (256, 256, 3)], + ["mobilenet_v3_small_page_orientation", (512, 512, 
3)], + ], +) +def test_classification_models(arch_name, input_shape): + batch_size = 8 + reco_model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) + assert isinstance(reco_model, tf.keras.Model) + input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) + + out = reco_model(input_tensor) + assert isinstance(out, tf.Tensor) + assert out.shape.as_list() == [8, 4] + + +@pytest.mark.parametrize( + "arch_name", + [ + "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", + ], +) +def test_classification_zoo(arch_name): + if "crop" in arch_name: + batch_size = 16 + input_tensor = tf.random.uniform(shape=[batch_size, 256, 256, 3], minval=0, maxval=1) + # Model + predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False) + + with pytest.raises(ValueError): + predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False) + else: + batch_size = 2 + input_tensor = tf.random.uniform(shape=[batch_size, 512, 512, 3], minval=0, maxval=1) + # Model + predictor = classification.zoo.page_orientation_predictor(arch_name, pretrained=False) + + with pytest.raises(ValueError): + predictor = classification.zoo.page_orientation_predictor(arch="wrong_model", pretrained=False) + # object check + assert isinstance(predictor, OrientationPredictor) + out = predictor(input_tensor) + class_idxs, classes, confs = out[0], out[1], out[2] + assert isinstance(class_idxs, list) and len(class_idxs) == batch_size + assert isinstance(classes, list) and len(classes) == batch_size + assert isinstance(confs, list) and len(confs) == batch_size + assert all(isinstance(pred, int) for pred in class_idxs) + assert all(isinstance(pred, int) for pred in classes) and all(pred in [0, 90, 180, -90] for pred in classes) + assert all(isinstance(pred, float) for pred in confs) + + +def test_crop_orientation_model(mock_text_box): + text_box_0 = cv2.imread(mock_text_box) + # rotates counter-clockwise + text_box_270 = np.rot90(text_box_0, 1) + text_box_180 = np.rot90(text_box_0, 2) + text_box_90 = np.rot90(text_box_0, 3) + classifier = classification.crop_orientation_predictor("mobilenet_v3_small_crop_orientation", pretrained=True) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] + # 270 degrees is equivalent to -90 degrees + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] + assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + + +# TODO: uncomment when model is available +""" +def test_page_orientation_model(mock_payslip): + text_box_0 = cv2.imread(mock_payslip) + # rotates counter-clockwise + text_box_270 = np.rot90(text_box_0, 1) + text_box_180 = np.rot90(text_box_0, 2) + text_box_90 = np.rot90(text_box_0, 3) + classifier = classification.crop_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] + # 270 degrees is equivalent to -90 degrees + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] + assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) +""" + + +# temporarily fix to avoid killing the CI (tf2onnx v1.14 memory leak issue) +# ref.: https://github.com/mindee/doctr/pull/1201 +@pytest.mark.parametrize( + 
"arch_name, input_shape, output_size", + [ + ["vgg16_bn_r", (32, 32, 3), (126,)], + ["mobilenet_v3_small", (512, 512, 3), (126,)], + ["mobilenet_v3_large", (512, 512, 3), (126,)], + ["mobilenet_v3_small_crop_orientation", (256, 256, 3), (4,)], + ["mobilenet_v3_small_page_orientation", (512, 512, 3), (4,)], + ["resnet18", (32, 32, 3), (126,)], + ["vit_s", (32, 32, 3), (126,)], + ["textnet_tiny", (32, 32, 3), (126,)], + ["textnet_small", (32, 32, 3), (126,)], + pytest.param( + "resnet31", + (32, 32, 3), + (126,), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "resnet34", + (32, 32, 3), + (126,), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "resnet34_wide", + (32, 32, 3), + (126,), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "resnet50", + (32, 32, 3), + (126,), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "magc_resnet31", + (32, 32, 3), + (126,), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "vit_b", + (32, 32, 3), + (126,), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "textnet_base", + (32, 32, 3), + (126,), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + ], +) +def test_models_onnx_export(arch_name, input_shape, output_size): + # Model + batch_size = 2 + tf.keras.backend.clear_session() + if "orientation" in arch_name: + model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) + else: + model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape) + + if arch_name == "vit_b" or arch_name == "vit_s": + # vit model needs a fixed batch size + dummy_input = [tf.TensorSpec([2, *input_shape], tf.float32, name="input")] + else: + # batch_size = None for dynamic batch size + dummy_input = [tf.TensorSpec([None, *input_shape], tf.float32, name="input")] + + np_dummy_input = np.random.rand(batch_size, *input_shape).astype(np.float32) + tf_logits = model(np_dummy_input, training=False).numpy() + with tempfile.TemporaryDirectory() as tmpdir: + # Export + model_path, output = export_model_to_onnx( + model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input + ) + + assert os.path.exists(model_path) + # Inference + ort_session = onnxruntime.InferenceSession( + os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"] + ) + ort_outs = ort_session.run(output, {"input": np_dummy_input}) + + assert isinstance(ort_outs, list) and len(ort_outs) == 1 + assert ort_outs[0].shape == (batch_size, *output_size) + # Check that the output is close to the TensorFlow output - only warn if not close + try: + assert np.allclose(tf_logits, ort_outs[0], atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(tf_logits - ort_outs[0]))}") diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py new file mode 100644 index 0000000..24fe231 --- /dev/null +++ b/tests/tensorflow/test_models_detection_tf.py @@ -0,0 +1,270 @@ +import math +import os +import tempfile + +import numpy as np +import onnxruntime +import psutil +import pytest +import tensorflow as tf +from doctr.file_utils import CLASS_NAME +from doctr.io import 
DocumentFile +from doctr.models import detection +from doctr.models.detection._utils import dilate, erode +from doctr.models.detection.fast.tensorflow import reparameterize +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.preprocessor import PreProcessor +from doctr.models.utils import export_model_to_onnx + +system_available_memory = int(psutil.virtual_memory().available / 1024**3) + + +@pytest.mark.parametrize("train_mode", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape, output_size, out_prob", + [ + ["db_resnet50", (512, 512, 3), (512, 512, 1), True], + ["db_mobilenet_v3_large", (512, 512, 3), (512, 512, 1), True], + ["linknet_resnet18", (512, 512, 3), (512, 512, 1), True], + ["linknet_resnet34", (512, 512, 3), (512, 512, 1), True], + ["linknet_resnet50", (512, 512, 3), (512, 512, 1), True], + ["fast_tiny", (512, 512, 3), (512, 512, 1), True], + ["fast_tiny_rep", (512, 512, 3), (512, 512, 1), True], # Reparameterized model + ["fast_small", (512, 512, 3), (512, 512, 1), True], + ["fast_base", (512, 512, 3), (512, 512, 1), True], + ], +) +def test_detection_models(arch_name, input_shape, output_size, out_prob, train_mode): + batch_size = 2 + tf.keras.backend.clear_session() + if arch_name == "fast_tiny_rep": + model = reparameterize(detection.fast_tiny(pretrained=True, input_shape=input_shape)) + train_mode = False # Reparameterized model is not trainable + else: + model = detection.__dict__[arch_name](pretrained=True, input_shape=input_shape) + assert isinstance(model, tf.keras.Model) + input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) + target = [ + {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32)}, + {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.9]], dtype=np.float32)}, + ] + # test training model + out = model( + input_tensor, + target, + return_model_output=True, + return_preds=not train_mode, + training=train_mode, + ) + assert isinstance(out, dict) + assert len(out) == 3 if not train_mode else len(out) == 2 + # Check proba map + assert isinstance(out["out_map"], tf.Tensor) + assert out["out_map"].dtype == tf.float32 + seg_map = out["out_map"].numpy() + assert seg_map.shape == (batch_size, *output_size) + if out_prob: + assert np.all(np.logical_and(seg_map >= 0, seg_map <= 1)) + # Check boxes + if not train_mode: + for boxes_dict in out["preds"]: + for boxes in boxes_dict.values(): + assert boxes.shape[1] == 5 + assert np.all(boxes[:, :2] < boxes[:, 2:4]) + assert np.all(boxes[:, :4] >= 0) and np.all(boxes[:, :4] <= 1) + # Check loss + assert isinstance(out["loss"], tf.Tensor) + # Target checks + target = [ + {CLASS_NAME: np.array([[0, 0, 1, 1]], dtype=np.uint8)}, + {CLASS_NAME: np.array([[0, 0, 1, 1]], dtype=np.uint8)}, + ] + with pytest.raises(AssertionError): + out = model(input_tensor, target, training=True) + + target = [ + {CLASS_NAME: np.array([[0, 0, 1.5, 1.5]], dtype=np.float32)}, + {CLASS_NAME: np.array([[-0.2, -0.3, 1, 1]], dtype=np.float32)}, + ] + with pytest.raises(ValueError): + out = model(input_tensor, target, training=True) + + # Check the rotated case + target = [ + {CLASS_NAME: np.array([[0.75, 0.75, 0.5, 0.5, 0], [0.65, 0.65, 0.3, 0.3, 0]], dtype=np.float32)}, + {CLASS_NAME: np.array([[0.75, 0.75, 0.5, 0.5, 0], [0.65, 0.7, 0.3, 0.4, 0]], dtype=np.float32)}, + ] + loss = model(input_tensor, target, training=True)["loss"] + assert isinstance(loss, tf.Tensor) and ((loss - out["loss"]) / loss).numpy() < 1 + + 
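+# NOTE: the two detection predictor checks below are declared as session-scoped fixtures and return the instantiated predictor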
+@pytest.fixture(scope="session") +def test_detectionpredictor(mock_pdf): + batch_size = 4 + predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=batch_size), detection.db_resnet50(input_shape=(512, 512, 3)) + ) + + pages = DocumentFile.from_pdf(mock_pdf).as_images() + out = predictor(pages) + # The input PDF has 2 pages + assert len(out) == 2 + + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + return predictor + + +@pytest.fixture(scope="session") +def test_rotated_detectionpredictor(mock_pdf): + batch_size = 4 + predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=batch_size), + detection.db_resnet50(assume_straight_pages=False, input_shape=(512, 512, 3)), + ) + + pages = DocumentFile.from_pdf(mock_pdf).as_images() + out = predictor(pages) + + # The input PDF has 2 pages + assert len(out) == 2 + + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + return predictor + + +@pytest.mark.parametrize( + "arch_name", + [ + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + "fast_tiny", + ], +) +def test_detection_zoo(arch_name): + # Model + tf.keras.backend.clear_session() + predictor = detection.zoo.detection_predictor(arch_name, pretrained=False) + # object check + assert isinstance(predictor, DetectionPredictor) + input_tensor = tf.random.uniform(shape=[2, 1024, 1024, 3], minval=0, maxval=1) + out, seq_maps = predictor(input_tensor, return_maps=True) + assert all(isinstance(boxes, dict) for boxes in out) + assert all(isinstance(boxes[CLASS_NAME], np.ndarray) and boxes[CLASS_NAME].shape[1] == 5 for boxes in out) + assert all(isinstance(seq_map, np.ndarray) for seq_map in seq_maps) + assert all(seq_map.shape[:2] == (1024, 1024) for seq_map in seq_maps) + # check that all values in the seq_maps are between 0 and 1 + assert all((seq_map >= 0).all() and (seq_map <= 1).all() for seq_map in seq_maps) + + +def test_detection_zoo_error(): + with pytest.raises(ValueError): + _ = detection.zoo.detection_predictor("my_fancy_model", pretrained=False) + + +def test_fast_reparameterization(): + dummy_input = tf.random.uniform(shape=[1, 1024, 1024, 3], minval=0, maxval=1) + base_model = detection.fast_tiny(pretrained=True, exportable=True) + base_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) + assert math.isclose(base_model_params, 13535296) # base model params + base_out = base_model(dummy_input, training=False)["logits"] + tf.keras.backend.clear_session() + rep_model = reparameterize(base_model) + rep_model_params = np.sum([np.prod(v.shape) for v in rep_model.trainable_variables]) + assert math.isclose(rep_model_params, 8520256) # reparameterized model params + rep_out = rep_model(dummy_input, training=False)["logits"] + diff = base_out - rep_out + assert np.mean(diff) < 5e-2 + + +def test_erode(): + x = np.zeros((1, 3, 3, 1), dtype=np.float32) + x[:, 1, 1] = 1 + x = tf.convert_to_tensor(x) + expected = tf.zeros((1, 3, 3, 1)) + out = erode(x, 3) + assert tf.math.reduce_all(out == expected) + + +def test_dilate(): + x = np.zeros((1, 3, 3, 1), dtype=np.float32) + x[:, 1, 1] = 1 + x = tf.convert_to_tensor(x) + expected = tf.ones((1, 3, 3, 1)) + out = dilate(x, 3) + assert tf.math.reduce_all(out == expected) + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size", + [ + 
["db_mobilenet_v3_large", (512, 512, 3), (512, 512, 1)], + ["linknet_resnet18", (1024, 1024, 3), (1024, 1024, 1)], + ["fast_tiny", (1024, 1024, 3), (1024, 1024, 1)], + ["fast_tiny_rep", (1024, 1024, 3), (1024, 1024, 1)], # Reparameterized model + ["fast_small", (1024, 1024, 3), (1024, 1024, 1)], + pytest.param( + "db_resnet50", + (512, 512, 3), + (512, 512, 1), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "linknet_resnet34", + (1024, 1024, 3), + (1024, 1024, 1), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "linknet_resnet50", + (512, 512, 3), + (512, 512, 1), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "fast_base", + (512, 512, 3), + (512, 512, 1), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + ], +) +def test_models_onnx_export(arch_name, input_shape, output_size): + # Model + batch_size = 2 + tf.keras.backend.clear_session() + if arch_name == "fast_tiny_rep": + model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True, input_shape=input_shape)) + else: + model = detection.__dict__[arch_name](pretrained=True, exportable=True, input_shape=input_shape) + # batch_size = None for dynamic batch size + dummy_input = [tf.TensorSpec([None, *input_shape], tf.float32, name="input")] + np_dummy_input = np.random.rand(batch_size, *input_shape).astype(np.float32) + tf_logits = model(np_dummy_input, training=False)["logits"].numpy() + with tempfile.TemporaryDirectory() as tmpdir: + # Export + model_path, output = export_model_to_onnx( + model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input + ) + assert os.path.exists(model_path) + # Inference + ort_session = onnxruntime.InferenceSession( + os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"] + ) + ort_outs = ort_session.run(output, {"input": np_dummy_input}) + + assert isinstance(ort_outs, list) and len(ort_outs) == 1 + assert ort_outs[0].shape == (batch_size, *output_size) + # Check that the output is close to the TensorFlow output - only warn if not close + try: + assert np.allclose(ort_outs[0], tf_logits, atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(tf_logits - ort_outs[0]))}") diff --git a/tests/tensorflow/test_models_factory.py b/tests/tensorflow/test_models_factory.py new file mode 100644 index 0000000..e470a4d --- /dev/null +++ b/tests/tensorflow/test_models_factory.py @@ -0,0 +1,70 @@ +import json +import os +import tempfile + +import pytest +import tensorflow as tf +from doctr import models +from doctr.models.factory import _save_model_and_config_for_hf_hub, from_hub, push_to_hf_hub + + +def test_push_to_hf_hub(): + model = models.classification.resnet18(pretrained=False) + with pytest.raises(ValueError): + # run_config and/or arch must be specified + push_to_hf_hub(model, model_name="test", task="classification") + with pytest.raises(ValueError): + # task must be one of classification, detection, recognition, obj_detection + push_to_hf_hub(model, model_name="test", task="invalid_task", arch="mobilenet_v3_small") + with pytest.raises(ValueError): + # arch not in available architectures for task + push_to_hf_hub(model, model_name="test", task="detection", arch="crnn_mobilenet_v3_large") + + +@pytest.mark.parametrize( + "arch_name, task_name, dummy_model_id", + [ + ["vgg16_bn_r", 
"classification", "Felix92/doctr-dummy-tf-vgg16-bn-r"], + ["resnet18", "classification", "Felix92/doctr-dummy-tf-resnet18"], + ["resnet31", "classification", "Felix92/doctr-dummy-tf-resnet31"], + ["resnet34", "classification", "Felix92/doctr-dummy-tf-resnet34"], + ["resnet34_wide", "classification", "Felix92/doctr-dummy-tf-resnet34-wide"], + ["resnet50", "classification", "Felix92/doctr-dummy-tf-resnet50"], + ["magc_resnet31", "classification", "Felix92/doctr-dummy-tf-magc-resnet31"], + ["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-tf-mobilenet-v3-large"], + ["vit_b", "classification", "Felix92/doctr-dummy-tf-vit-b"], + ["textnet_tiny", "classification", "Felix92/doctr-dummy-tf-textnet-tiny"], + ["db_resnet50", "detection", "Felix92/doctr-dummy-tf-db-resnet50"], + ["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-tf-db-mobilenet-v3-large"], + ["linknet_resnet18", "detection", "Felix92/doctr-dummy-tf-linknet-resnet18"], + ["linknet_resnet34", "detection", "Felix92/doctr-dummy-tf-linknet-resnet34"], + ["linknet_resnet50", "detection", "Felix92/doctr-dummy-tf-linknet-resnet50"], + ["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-tf-crnn-vgg16-bn"], + ["crnn_mobilenet_v3_large", "recognition", "Felix92/doctr-dummy-tf-crnn-mobilenet-v3-large"], + ["sar_resnet31", "recognition", "Felix92/doctr-dummy-tf-sar-resnet31"], + ["master", "recognition", "Felix92/doctr-dummy-tf-master"], + ["vitstr_small", "recognition", "Felix92/doctr-dummy-tf-vitstr-small"], + ["parseq", "recognition", "Felix92/doctr-dummy-tf-parseq"], + ], +) +def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): + with tempfile.TemporaryDirectory() as tmp_dir: + tf.keras.backend.clear_session() + model = models.__dict__[task_name].__dict__[arch_name](pretrained=True) + + _save_model_and_config_for_hf_hub(model, arch=arch_name, task=task_name, save_dir=tmp_dir) + + assert hasattr(model, "cfg") + assert len(os.listdir(tmp_dir)) == 2 + assert os.path.exists(tmp_dir + "/tf_model") + assert len(os.listdir(tmp_dir + "/tf_model")) == 3 + assert os.path.exists(tmp_dir + "/config.json") + tmp_config = json.load(open(tmp_dir + "/config.json")) + assert arch_name == tmp_config["arch"] + assert task_name == tmp_config["task"] + assert all(key in model.cfg.keys() for key in tmp_config.keys()) + + # test from hub + tf.keras.backend.clear_session() + hub_model = from_hub(repo_id=dummy_model_id) + assert isinstance(hub_model, type(model)) diff --git a/tests/tensorflow/test_models_preprocessor_tf.py b/tests/tensorflow/test_models_preprocessor_tf.py new file mode 100644 index 0000000..d1f2151 --- /dev/null +++ b/tests/tensorflow/test_models_preprocessor_tf.py @@ -0,0 +1,43 @@ +import numpy as np +import pytest +import tensorflow as tf +from doctr.models.preprocessor import PreProcessor + + +@pytest.mark.parametrize( + "batch_size, output_size, input_tensor, expected_batches, expected_value", + [ + [2, (128, 128), np.full((3, 256, 128, 3), 255, dtype=np.uint8), 1, 0.5], # numpy uint8 + [2, (128, 128), np.ones((3, 256, 128, 3), dtype=np.float32), 1, 0.5], # numpy fp32 + [2, (128, 128), tf.cast(tf.fill((3, 256, 128, 3), 255), dtype=tf.uint8), 1, 0.5], # tf uint8 + [2, (128, 128), tf.ones((3, 128, 128, 3), dtype=tf.float32), 1, 0.5], # tf fp32 + [2, (128, 128), [np.full((256, 128, 3), 255, dtype=np.uint8)] * 3, 2, 0.5], # list of numpy uint8 + [2, (128, 128), [np.ones((256, 128, 3), dtype=np.float32)] * 3, 2, 0.5], # list of numpy fp32 + [2, (128, 128), [tf.cast(tf.fill((256, 128, 3), 255), dtype=tf.uint8)] 
* 3, 2, 0.5], # list of tf uint8 + [2, (128, 128), [tf.ones((128, 128, 3), dtype=tf.float32)] * 3, 2, 0.5], # list of tf fp32 + ], +) +def test_preprocessor(batch_size, output_size, input_tensor, expected_batches, expected_value): + processor = PreProcessor(output_size, batch_size) + + # Invalid input type + with pytest.raises(TypeError): + processor(42) + # 4D check + with pytest.raises(AssertionError): + processor(np.full((256, 128, 3), 255, dtype=np.uint8)) + with pytest.raises(TypeError): + processor(np.full((1, 256, 128, 3), 255, dtype=np.int32)) + # 3D check + with pytest.raises(AssertionError): + processor([np.full((3, 256, 128, 3), 255, dtype=np.uint8)]) + with pytest.raises(TypeError): + processor([np.full((256, 128, 3), 255, dtype=np.int32)]) + + out = processor(input_tensor) + assert isinstance(out, list) and len(out) == expected_batches + assert all(isinstance(b, tf.Tensor) for b in out) + assert all(b.dtype == tf.float32 for b in out) + assert all(b.shape[1:3] == output_size for b in out) + assert all(tf.math.reduce_all(tf.math.abs(b - expected_value) < 1e-6) for b in out) + assert len(repr(processor).split("\n")) == 4 diff --git a/tests/tensorflow/test_models_recognition_tf.py b/tests/tensorflow/test_models_recognition_tf.py new file mode 100644 index 0000000..d562abf --- /dev/null +++ b/tests/tensorflow/test_models_recognition_tf.py @@ -0,0 +1,233 @@ +import os +import shutil +import tempfile + +import numpy as np +import onnxruntime +import psutil +import pytest +import tensorflow as tf +from doctr.io import DocumentFile +from doctr.models import recognition +from doctr.models.preprocessor import PreProcessor +from doctr.models.recognition.crnn.tensorflow import CTCPostProcessor +from doctr.models.recognition.master.tensorflow import MASTERPostProcessor +from doctr.models.recognition.parseq.tensorflow import PARSeqPostProcessor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.models.recognition.sar.tensorflow import SARPostProcessor +from doctr.models.recognition.vitstr.tensorflow import ViTSTRPostProcessor +from doctr.models.utils import export_model_to_onnx +from doctr.utils.geometry import extract_crops + +system_available_memory = int(psutil.virtual_memory().available / 1024**3) + + +@pytest.mark.parametrize("train_mode", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["crnn_vgg16_bn", (32, 128, 3)], + ["crnn_mobilenet_v3_small", (32, 128, 3)], + ["crnn_mobilenet_v3_large", (32, 128, 3)], + ["sar_resnet31", (32, 128, 3)], + ["master", (32, 128, 3)], + ["vitstr_small", (32, 128, 3)], + ["vitstr_base", (32, 128, 3)], + ["parseq", (32, 128, 3)], + ], +) +def test_recognition_models(arch_name, input_shape, train_mode): + batch_size = 4 + reco_model = recognition.__dict__[arch_name](pretrained=True, input_shape=input_shape) + assert isinstance(reco_model, tf.keras.Model) + input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) + target = ["i", "am", "a", "jedi"] + + out = reco_model( + input_tensor, + target, + return_model_output=True, + return_preds=not train_mode, + training=train_mode, + ) + assert isinstance(out, dict) + assert len(out) == 3 if not train_mode else len(out) == 2 + assert isinstance(out["out_map"], tf.Tensor) + assert out["out_map"].dtype == tf.float32 + if not train_mode: + assert isinstance(out["preds"], list) + assert len(out["preds"]) == batch_size + assert all(isinstance(word, str) and isinstance(conf, float) and 0 <= conf <= 1 for word, conf in out["preds"]) 
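+ # Check loss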
+ assert isinstance(out["loss"], tf.Tensor) + # test model in train mode needs targets + with pytest.raises(ValueError): + reco_model(input_tensor, None, training=True) + + +@pytest.mark.parametrize( + "post_processor, input_shape", + [ + [SARPostProcessor, [2, 30, 119]], + [CTCPostProcessor, [2, 30, 119]], + [MASTERPostProcessor, [2, 30, 119]], + [ViTSTRPostProcessor, [2, 30, 119]], + [PARSeqPostProcessor, [2, 30, 119]], + ], +) +def test_reco_postprocessors(post_processor, input_shape, mock_vocab): + processor = post_processor(mock_vocab) + decoded = processor(tf.random.uniform(shape=input_shape, minval=0, maxval=1, dtype=tf.float32)) + assert isinstance(decoded, list) + assert all(isinstance(word, str) and isinstance(conf, float) and 0 <= conf <= 1 for word, conf in decoded) + assert len(decoded) == input_shape[0] + assert all(char in mock_vocab for word, _ in decoded for char in word) + # Repr + assert repr(processor) == f"{post_processor.__name__}(vocab_size={len(mock_vocab)})" + + +@pytest.fixture(scope="session") +def test_recognitionpredictor(mock_pdf, mock_vocab): + batch_size = 4 + predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=batch_size, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(vocab=mock_vocab, input_shape=(32, 128, 3)), + ) + + pages = DocumentFile.from_pdf(mock_pdf) + # Create bounding boxes + boxes = np.array([[0.5, 0.5, 0.75, 0.75], [0.5, 0.5, 1.0, 1.0]], dtype=np.float32) + crops = extract_crops(pages[0], boxes) + + out = predictor(crops) + + # One prediction per crop + assert len(out) == boxes.shape[0] + assert all(isinstance(val, str) and isinstance(conf, float) for val, conf in out) + + # Dimension check + with pytest.raises(ValueError): + input_crop = (255 * np.random.rand(1, 128, 64, 3)).astype(np.uint8) + _ = predictor([input_crop]) + + return predictor + + +@pytest.mark.parametrize( + "arch_name", + [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + "master", + "vitstr_small", + "vitstr_base", + "parseq", + ], +) +def test_recognition_zoo(arch_name): + batch_size = 2 + # Model + predictor = recognition.zoo.recognition_predictor(arch_name, pretrained=False) + # object check + assert isinstance(predictor, RecognitionPredictor) + input_tensor = tf.random.uniform(shape=[batch_size, 128, 128, 3], minval=0, maxval=1) + out = predictor(input_tensor) + assert isinstance(out, list) and len(out) == batch_size + assert all(isinstance(word, str) and isinstance(conf, float) for word, conf in out) + + +@pytest.mark.parametrize( + "arch_name", + [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + ], +) +def test_crnn_beam_search(arch_name): + batch_size = 2 + # Model + predictor = recognition.zoo.recognition_predictor(arch_name, pretrained=False) + # object check + assert isinstance(predictor, RecognitionPredictor) + input_tensor = tf.random.uniform(shape=[batch_size, 128, 128, 3], minval=0, maxval=1) + out = predictor(input_tensor, beam_width=10, top_paths=10) + assert isinstance(out, list) and len(out) == batch_size + assert all( + isinstance(words, list) + and isinstance(confs, list) + and all(isinstance(word, str) for word in words) + and all(isinstance(conf, float) for conf in confs) + for words, confs in out + ) + + +def test_recognition_zoo_error(): + with pytest.raises(ValueError): + _ = recognition.zoo.recognition_predictor("my_fancy_model", pretrained=False) + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["crnn_vgg16_bn", 
(32, 128, 3)], + ["crnn_mobilenet_v3_small", (32, 128, 3)], + ["crnn_mobilenet_v3_large", (32, 128, 3)], + ["vitstr_small", (32, 128, 3)], # testing one vitstr version is enough + pytest.param( + "sar_resnet31", + (32, 128, 3), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + pytest.param( + "master", (32, 128, 3), marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory") + ), + pytest.param( + "parseq", + (32, 128, 3), + marks=pytest.mark.skipif(system_available_memory < 16, reason="too less memory"), + ), + ], +) +def test_models_onnx_export(arch_name, input_shape): + # Model + batch_size = 2 + tf.keras.backend.clear_session() + model = recognition.__dict__[arch_name](pretrained=True, exportable=True, input_shape=input_shape) + # SAR, MASTER, ViTSTR export currently only available with constant batch size + if arch_name in ["sar_resnet31", "master", "vitstr_small", "parseq"]: + dummy_input = [tf.TensorSpec([batch_size, *input_shape], tf.float32, name="input")] + else: + # batch_size = None for dynamic batch size + dummy_input = [tf.TensorSpec([None, *input_shape], tf.float32, name="input")] + np_dummy_input = np.random.rand(batch_size, *input_shape).astype(np.float32) + tf_logits = model(np_dummy_input, training=False)["logits"].numpy() + with tempfile.TemporaryDirectory() as tmpdir: + # Export + model_path, output = export_model_to_onnx( + model, + model_name=os.path.join(tmpdir, "model"), + dummy_input=dummy_input, + large_model=True if arch_name == "master" else False, + ) + assert os.path.exists(model_path) + + if arch_name == "master": + # large models are exported as zip archive + shutil.unpack_archive(model_path, tmpdir, "zip") + model_path = os.path.join(tmpdir, "__MODEL_PROTO.onnx") + else: + model_path = os.path.join(tmpdir, "model.onnx") + + # Inference + ort_session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + ort_outs = ort_session.run(output, {"input": np_dummy_input}) + + assert isinstance(ort_outs, list) and len(ort_outs) == 1 + assert ort_outs[0].shape == tf_logits.shape + # Check that the output is close to the TensorFlow output - only warn if not close + try: + assert np.allclose(tf_logits, ort_outs[0], atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(tf_logits - ort_outs[0]))}") diff --git a/tests/tensorflow/test_models_utils_tf.py b/tests/tensorflow/test_models_utils_tf.py new file mode 100644 index 0000000..3d35069 --- /dev/null +++ b/tests/tensorflow/test_models_utils_tf.py @@ -0,0 +1,60 @@ +import os + +import pytest +import tensorflow as tf +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_to_float32, + _copy_tensor, + conv_sequence, + load_pretrained_params, +) +from tensorflow.keras import Sequential, layers +from tensorflow.keras.applications import ResNet50 + + +def test_copy_tensor(): + x = tf.random.uniform(shape=[8], minval=0, maxval=1) + m = _copy_tensor(x) + assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and tf.reduce_all(tf.equal(m, x)) + + +def test_bf16_to_float32(): + x = tf.random.uniform(shape=[8], minval=0, maxval=1, dtype=tf.bfloat16) + m = _bf16_to_float32(x) + assert x.dtype == tf.bfloat16 and m.dtype == tf.float32 and tf.reduce_all(tf.equal(m, tf.cast(x, tf.float32))) + + +def test_load_pretrained_params(tmpdir_factory): + model = Sequential([layers.Dense(8, activation="relu", input_shape=(4,)), layers.Dense(4)]) + # Retrieve 
this URL + url = "https://doctr-static.mindee.com/models?id=v0.1-models/tmp_checkpoint-4a98e492.zip&src=0" + # Temp cache dir + cache_dir = tmpdir_factory.mktemp("cache") + # Pass an incorrect hash + with pytest.raises(ValueError): + load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir), internal_name="") + # Let it resolve the hash from the file name + load_pretrained_params(model, url, cache_dir=str(cache_dir), internal_name="") + # Check that the file was downloaded & the archive extracted + assert os.path.exists(cache_dir.join("models").join("tmp_checkpoint-4a98e492")) + # Check that the archive is still present in the cache + assert os.path.exists(cache_dir.join("models").join("tmp_checkpoint-4a98e492.zip")) + + +def test_conv_sequence(): + assert len(conv_sequence(8, kernel_size=3)) == 1 + assert len(conv_sequence(8, "relu", kernel_size=3)) == 1 + assert len(conv_sequence(8, None, True, kernel_size=3)) == 2 + assert len(conv_sequence(8, "relu", True, kernel_size=3)) == 3 + + +def test_intermediate_layer_getter(): + backbone = ResNet50(include_top=False, weights=None, pooling=None) + feat_extractor = IntermediateLayerGetter(backbone, ["conv2_block3_out", "conv3_block4_out"]) + # Check number of output features + input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], minval=0, maxval=1) + assert len(feat_extractor(input_tensor)) == 2 + + # Repr + assert repr(feat_extractor) == "IntermediateLayerGetter()" diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py new file mode 100644 index 0000000..50b5d37 --- /dev/null +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -0,0 +1,325 @@ +import numpy as np +import pytest +from doctr import models +from doctr.file_utils import CLASS_NAME +from doctr.io import Document, DocumentFile +from doctr.io.elements import KIEDocument +from doctr.models import detection, recognition +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.detection.zoo import detection_predictor +from doctr.models.kie_predictor import KIEPredictor +from doctr.models.predictor import OCRPredictor +from doctr.models.preprocessor import PreProcessor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.models.recognition.zoo import recognition_predictor +from doctr.utils.repr import NestedObject + + +# Create a dummy callback +class _DummyCallback: + def __call__(self, loc_preds): + return loc_preds + + +@pytest.mark.parametrize( + "assume_straight_pages, straighten_pages", + [ + [True, False], + [False, False], + [True, True], + ], +) +def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): + det_bsize = 4 + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=det_bsize), + detection.db_mobilenet_v3_large( + pretrained=True, + pretrained_backbone=False, + input_shape=(512, 512, 3), + assume_straight_pages=assume_straight_pages, + ), + ) + + reco_bsize = 16 + reco_predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=reco_bsize, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab), + ) + + doc = DocumentFile.from_pdf(mock_pdf) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + straighten_pages=straighten_pages, + detect_orientation=True, + detect_language=True, + ) + + if assume_straight_pages: + assert predictor.crop_orientation_predictor is None + else: + assert 
isinstance(predictor.crop_orientation_predictor, NestedObject) + + out = predictor(doc) + assert isinstance(out, Document) + assert len(out.pages) == 2 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + language = "unknown" + assert out.pages[0].language["value"] == language + + +def test_trained_ocr_predictor(mock_payslip): + doc = DocumentFile.from_images(mock_payslip) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + symmetric_pad=True, + preserve_aspect_ratio=False, + ) + reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=False, + ) + # test hooks + predictor.add_hook(_DummyCallback()) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." + geometry_mr = np.array([[0.1083984375, 0.0634765625], [0.1494140625, 0.0859375]]) + assert np.allclose(np.array(out.pages[0].blocks[0].lines[0].words[0].geometry), geometry_mr, rtol=0.05) + + assert out.pages[0].blocks[1].lines[0].words[-1].value == "revised" + geometry_revised = np.array([[0.7548828125, 0.126953125], [0.8388671875, 0.1484375]]) + assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised, rtol=0.05) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." 
+ + +@pytest.mark.parametrize( + "assume_straight_pages, straighten_pages", + [ + [True, False], + [False, False], + [True, True], + ], +) +def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): + det_bsize = 4 + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=det_bsize), + detection.db_mobilenet_v3_large( + pretrained=True, + pretrained_backbone=False, + input_shape=(512, 512, 3), + assume_straight_pages=assume_straight_pages, + ), + ) + + reco_bsize = 16 + reco_predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=reco_bsize, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab), + ) + + doc = DocumentFile.from_pdf(mock_pdf) + + predictor = KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + straighten_pages=straighten_pages, + detect_orientation=True, + detect_language=True, + ) + + if assume_straight_pages: + assert predictor.crop_orientation_predictor is None + else: + assert isinstance(predictor.crop_orientation_predictor, NestedObject) + + out = predictor(doc) + assert isinstance(out, KIEDocument) + assert len(out.pages) == 2 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + language = "unknown" + assert out.pages[0].language["value"] == language + + +def test_trained_kie_predictor(mock_payslip): + doc = DocumentFile.from_images(mock_payslip) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + symmetric_pad=True, + preserve_aspect_ratio=False, + ) + reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128) + + predictor = KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=False, + ) + # test hooks + predictor.add_hook(_DummyCallback()) + + out = predictor(doc) + + assert isinstance(out, KIEDocument) + assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr." + geometry_mr = np.array([[0.1083984375, 0.0634765625], [0.1494140625, 0.0859375]]) + assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][0].geometry), geometry_mr, rtol=0.05) + + assert out.pages[0].predictions[CLASS_NAME][3].value == "revised" + geometry_revised = np.array([[0.7548828125, 0.126953125], [0.8388671875, 0.1484375]]) + assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][3].geometry), geometry_revised, rtol=0.05) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + predictor = KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + out = predictor(doc) + + assert isinstance(out, KIEDocument) + assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr." 
+ + +def _test_predictor(predictor): + # Output checks + assert isinstance(predictor, OCRPredictor) + + doc = [np.zeros((512, 512, 3), dtype=np.uint8)] + out = predictor(doc) + # Document + assert isinstance(out, Document) + + # The input doc has 1 page + assert len(out.pages) == 1 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + +def _test_kiepredictor(predictor): + # Output checks + assert isinstance(predictor, KIEPredictor) + + doc = [np.zeros((512, 512, 3), dtype=np.uint8)] + out = predictor(doc) + # Document + assert isinstance(out, KIEDocument) + + # The input doc has 1 page + assert len(out.pages) == 1 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + +@pytest.mark.parametrize( + "det_arch, reco_arch", + [ + ["db_mobilenet_v3_large", "crnn_vgg16_bn"], + ], +) +def test_zoo_models(det_arch, reco_arch): + # Model + predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True) + _test_predictor(predictor) + + # passing model instance directly + det_model = detection.__dict__[det_arch](pretrained=True) + reco_model = recognition.__dict__[reco_arch](pretrained=True) + predictor = models.ocr_predictor(det_model, reco_model) + _test_predictor(predictor) + + # passing recognition model as detection model + with pytest.raises(ValueError): + models.ocr_predictor(det_arch=reco_model, pretrained=True) + + # passing detection model as recognition model + with pytest.raises(ValueError): + models.ocr_predictor(reco_arch=det_model, pretrained=True) + + # KIE predictor + predictor = models.kie_predictor(det_arch, reco_arch, pretrained=True) + _test_kiepredictor(predictor) + + # passing model instance directly + det_model = detection.__dict__[det_arch](pretrained=True) + reco_model = recognition.__dict__[reco_arch](pretrained=True) + predictor = models.kie_predictor(det_model, reco_model) + _test_kiepredictor(predictor) + + # passing recognition model as detection model + with pytest.raises(ValueError): + models.kie_predictor(det_arch=reco_model, pretrained=True) + + # passing detection model as recognition model + with pytest.raises(ValueError): + models.kie_predictor(reco_arch=det_model, pretrained=True) diff --git a/tests/tensorflow/test_transforms_tf.py b/tests/tensorflow/test_transforms_tf.py new file mode 100644 index 0000000..d1db73d --- /dev/null +++ b/tests/tensorflow/test_transforms_tf.py @@ -0,0 +1,492 @@ +import math + +import numpy as np +import pytest +import tensorflow as tf +from doctr import transforms as T +from doctr.transforms.functional import crop_detection, rotate_sample + + +def test_resize(): + output_size = (32, 32) + transfo = T.Resize(output_size) + input_t = tf.cast(tf.fill([64, 64, 3], 1), dtype=tf.float32) + out = transfo(input_t) + + assert tf.math.reduce_all(tf.math.abs(out - 1) < 1e-6) + assert out.shape[:2] == output_size + assert repr(transfo) == f"Resize(output_size={output_size}, method='bilinear')" + + transfo = T.Resize(output_size, preserve_aspect_ratio=True) + input_t = tf.cast(tf.fill([32, 64, 3], 1), dtype=tf.float32) + out = transfo(input_t) + + assert not tf.reduce_all(out == 1) + # Asymmetric padding + assert tf.reduce_all(out[-1] == 0) and tf.math.reduce_all(tf.math.abs(out[0] - 1) < 1e-6) + assert out.shape[:2] == output_size + + # Symmetric padding + transfo = T.Resize(output_size, preserve_aspect_ratio=True, symmetric_pad=True) 
+ assert repr(transfo) == ( + f"Resize(output_size={output_size}, method='bilinear', " f"preserve_aspect_ratio=True, symmetric_pad=True)" + ) + out = transfo(input_t) + # Symmetric padding: both edges are padded + assert tf.reduce_all(out[-1] == 0) and tf.reduce_all(out[0] == 0) + + # Inverse aspect ratio + input_t = tf.cast(tf.fill([64, 32, 3], 1), dtype=tf.float32) + out = transfo(input_t) + + assert not tf.reduce_all(out == 1) + assert out.shape[:2] == output_size + + # FP16 + input_t = tf.cast(tf.fill([64, 64, 3], 1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_compose(): + output_size = (16, 16) + transfo = T.Compose([T.Resize((32, 32)), T.Resize(output_size)]) + input_t = tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1) + out = transfo(input_t) + + assert out.shape[:2] == output_size + assert len(repr(transfo).split("\n")) == 6 + + +@pytest.mark.parametrize( + "input_shape", + [ + [8, 32, 32, 3], + [32, 32, 3], + [32, 3], + ], +) +def test_normalize(input_shape): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + transfo = T.Normalize(mean, std) + input_t = tf.cast(tf.fill(input_shape, 1), dtype=tf.float32) + + out = transfo(input_t) + + assert tf.reduce_all(out == 1) + assert repr(transfo) == f"Normalize(mean={mean}, std={std})" + + # FP16 + input_t = tf.cast(tf.fill(input_shape, 1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_lambdatransformation(): + transfo = T.LambdaTransformation(lambda x: x / 2) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 1), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out == 0.5) + + +def test_togray(): + transfo = T.ToGray() + r = tf.fill([8, 32, 32, 1], 0.2) + g = tf.fill([8, 32, 32, 1], 0.6) + b = tf.fill([8, 32, 32, 1], 0.7) + input_t = tf.cast(tf.concat([r, g, b], axis=-1), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out <= 0.51) + assert tf.reduce_all(out >= 0.49) + + # FP16 + input_t = tf.cast(tf.concat([r, g, b], axis=-1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +@pytest.mark.parametrize( + "rgb_min", + [ + 0.2, + 0.4, + 0.6, + ], +) +def test_invert_colorize(rgb_min): + transfo = T.ColorInversion(min_val=rgb_min) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 1), dtype=tf.float32) + out = transfo(input_t) + assert tf.reduce_all(out <= 1 - rgb_min + 1e-4) + assert tf.reduce_all(out >= 0) + + input_t = tf.cast(tf.fill([8, 32, 32, 3], 255), dtype=tf.uint8) + out = transfo(input_t) + assert tf.reduce_all(out <= int(math.ceil(255 * (1 - rgb_min)))) + assert tf.reduce_all(out >= 0) + + # FP16 + input_t = tf.cast(tf.fill([8, 32, 32, 3], 1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_brightness(): + transfo = T.RandomBrightness(max_delta=0.1) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 0.5), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out >= 0.4) + assert tf.reduce_all(out <= 0.6) + + # FP16 + input_t = tf.cast(tf.fill([8, 32, 32, 3], 0.5), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_contrast(): + transfo = T.RandomContrast(delta=0.2) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 0.5), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out == 0.5) + + # FP16 + if any(tf.config.list_physical_devices("GPU")): + input_t = tf.cast(tf.fill([8, 32, 32, 3], 0.5), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_saturation(): + transfo = 
T.RandomSaturation(delta=0.2) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 0.5), dtype=tf.float32) + input_t = tf.image.hsv_to_rgb(input_t) + out = transfo(input_t) + hsv = tf.image.rgb_to_hsv(out) + + assert tf.reduce_all(hsv[:, :, :, 1] >= 0.4) + assert tf.reduce_all(hsv[:, :, :, 1] <= 0.6) + + # FP16 + if any(tf.config.list_physical_devices("GPU")): + input_t = tf.cast(tf.fill([8, 32, 32, 3], 0.5), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_hue(): + transfo = T.RandomHue(max_delta=0.2) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 0.5), dtype=tf.float32) + input_t = tf.image.hsv_to_rgb(input_t) + out = transfo(input_t) + hsv = tf.image.rgb_to_hsv(out) + + assert tf.reduce_all(hsv[:, :, :, 0] <= 0.7) + assert tf.reduce_all(hsv[:, :, :, 0] >= 0.3) + + # FP16 + if any(tf.config.list_physical_devices("GPU")): + input_t = tf.cast(tf.fill([8, 32, 32, 3], 0.5), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_gamma(): + transfo = T.RandomGamma(min_gamma=1.0, max_gamma=2.0, min_gain=0.8, max_gain=1.0) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 2.0), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out >= 1.6) + assert tf.reduce_all(out <= 4.0) + + # FP16 + input_t = tf.cast(tf.fill([8, 32, 32, 3], 2.0), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_jpegquality(): + transfo = T.RandomJpegQuality(min_quality=50) + input_t = tf.cast(tf.fill([32, 32, 3], 1), dtype=tf.float32) + out = transfo(input_t) + assert out.shape == input_t.shape + + # FP16 + input_t = tf.cast(tf.fill([32, 32, 3], 1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_rotate_sample(): + img = tf.ones((200, 100, 3), dtype=tf.float32) + boxes = np.array([0, 0, 100, 200])[None, ...] + polys = np.stack((boxes[..., [0, 1]], boxes[..., [2, 1]], boxes[..., [2, 3]], boxes[..., [0, 3]]), axis=1) + rel_boxes = np.array([0, 0, 1, 1], dtype=np.float32)[None, ...] + rel_polys = np.stack( + (rel_boxes[..., [0, 1]], rel_boxes[..., [2, 1]], rel_boxes[..., [2, 3]], rel_boxes[..., [0, 3]]), axis=1 + ) + + # No angle + rotated_img, rotated_geoms = rotate_sample(img, boxes, 0, False) + assert tf.math.reduce_all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, boxes, 0, True) + assert tf.math.reduce_all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 0, False) + assert tf.math.reduce_all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 0, True) + assert tf.math.reduce_all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + + # No expansion + expected_img = np.zeros((200, 100, 3), dtype=np.float32) + expected_img[50:150] = 1 + expected_img = tf.convert_to_tensor(expected_img) + expected_polys = np.array([[0, 0.75], [0, 0.25], [1, 0.25], [1, 0.75]])[None, ...] 
+ rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, False) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 90, False) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_boxes, 90, False) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_polys, 90, False) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + + # Expansion + expected_img = tf.ones((100, 200, 3), dtype=tf.float32) + expected_polys = np.array([[0, 1], [0, 0], [1, 0], [1, 1]], dtype=np.float32)[None, ...] + rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, True) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 90, True) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_boxes, 90, True) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_polys, 90, True) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + + with pytest.raises(AssertionError): + rotate_sample(img, boxes[None, ...], 90, False) + + +def test_random_rotate(): + rotator = T.RandomRotate(max_angle=10.0, expand=False) + input_t = tf.ones((50, 50, 3), dtype=tf.float32) + boxes = np.array([[15, 20, 35, 30]]) + r_img, _r_boxes = rotator(input_t, boxes) + assert r_img.shape == input_t.shape + + rotator = T.RandomRotate(max_angle=10.0, expand=True) + r_img, _r_boxes = rotator(input_t, boxes) + assert r_img.shape != input_t.shape + + # FP16 + input_t = tf.ones((50, 50, 3), dtype=tf.float16) + r_img, _ = rotator(input_t, boxes) + assert r_img.dtype == tf.float16 + + +def test_crop_detection(): + img = tf.ones((50, 50, 3), dtype=tf.float32) + abs_boxes = np.array([ + [15, 20, 35, 30], + [5, 10, 10, 20], + ]) + crop_box = (12 / 50, 23 / 50, 1.0, 1.0) + c_img, c_boxes = crop_detection(img, abs_boxes, crop_box) + assert c_img.shape == (26, 37, 3) + assert c_boxes.shape == (1, 4) + assert np.all(c_boxes == np.array([15 - 12, 0, 35 - 12, 30 - 23])[None, ...]) + + rel_boxes = np.array([ + [0.3, 0.4, 0.7, 0.6], + [0.1, 0.2, 0.2, 0.4], + ]) + c_img, c_boxes = crop_detection(img, rel_boxes, crop_box) + assert c_img.shape == (26, 37, 3) + assert c_boxes.shape == (1, 4) + assert np.abs(c_boxes - np.array([0.06 / 0.76, 0.0, 0.46 / 0.76, 0.14 / 0.54])[None, ...]).mean() < 1e-7 + + # FP16 + img = tf.ones((50, 50, 3), dtype=tf.float16) + c_img, _ = crop_detection(img, rel_boxes, crop_box) + assert c_img.dtype == tf.float16 + + with pytest.raises(AssertionError): + crop_detection(img, abs_boxes, (2, 6, 24, 56)) + + +@pytest.mark.parametrize( + "target", + [ + np.array([[15, 20, 35, 30]]), # box + np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), # polygon + ], +) +def test_random_crop(target): + transfo = T.RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33)) + input_t = tf.ones((50, 50, 3), dtype=tf.float32) + img, target = transfo(input_t, target) + # Check the scale (take a margin) + assert img.shape[0] * img.shape[1] >= 0.4 * input_t.shape[0] * 
input_t.shape[1] + # Check aspect ratio (take a margin) + assert 0.65 <= img.shape[0] / img.shape[1] <= 1.5 + # Check the target + assert np.all(target >= 0) + if target.ndim == 2: + assert np.all(target[:, [0, 2]] <= img.shape[-1]) and np.all(target[:, [1, 3]] <= img.shape[-2]) + else: + assert np.all(target[..., 0] <= img.shape[-1]) and np.all(target[..., 1] <= img.shape[-2]) + + +def test_gaussian_blur(): + blur = T.GaussianBlur(3, (0.1, 3)) + input_t = np.ones((31, 31, 3), dtype=np.float32) + input_t[15, 15] = 0 + blur_img = blur(tf.convert_to_tensor(input_t)).numpy() + assert blur_img.shape == input_t.shape + assert np.all(blur_img[15, 15] > 0) + + +@pytest.mark.parametrize( + "input_dtype, input_size", + [ + [tf.float32, (32, 32, 3)], + [tf.uint8, (32, 32, 3)], + ], +) +def test_channel_shuffle(input_dtype, input_size): + transfo = T.ChannelShuffle() + input_t = tf.random.uniform(input_size, dtype=tf.float32) + if input_dtype == tf.uint8: + input_t = tf.math.round(255 * input_t) + input_t = tf.cast(input_t, dtype=input_dtype) + out = transfo(input_t) + assert isinstance(out, tf.Tensor) + assert out.shape == input_size + assert out.dtype == input_dtype + # Ensure that nothing has changed apart from channel order + assert tf.math.reduce_all(tf.math.reduce_sum(input_t, -1) == tf.math.reduce_sum(out, -1)) + + +@pytest.mark.parametrize( + "input_dtype,input_shape", + [ + [tf.float32, (32, 32, 3)], + [tf.uint8, (32, 32, 3)], + ], +) +def test_gaussian_noise(input_dtype, input_shape): + transform = T.GaussianNoise(0.0, 1.0) + input_t = tf.random.uniform(input_shape, dtype=tf.float32) + if input_dtype == tf.uint8: + input_t = tf.math.round((255 * input_t)) + input_t = tf.cast(input_t, dtype=input_dtype) + transformed = transform(input_t) + assert isinstance(transformed, tf.Tensor) + assert transformed.shape == input_shape + assert transformed.dtype == input_dtype + assert tf.math.reduce_any(transformed != input_t) + assert tf.math.reduce_all(transformed >= 0) + if input_dtype == tf.uint8: + assert tf.reduce_all(transformed <= 255) + else: + assert tf.reduce_all(transformed <= 1.0) + + +@pytest.mark.parametrize( + "p,target", + [ + [1, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)], + [0, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)], + [1, np.array([[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], dtype=np.float32)], + [0, np.array([[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], dtype=np.float32)], + ], +) +def test_randomhorizontalflip(p, target): + # testing for 2 cases, with flip probability 1 and 0. 
+ transform = T.RandomHorizontalFlip(p) + input_t = np.ones((32, 32, 3)) + input_t[:, :16, :] = 0 + input_t = tf.convert_to_tensor(input_t) + transformed, _target = transform(input_t, target) + assert isinstance(transformed, tf.Tensor) + assert transformed.shape == input_t.shape + assert transformed.dtype == input_t.dtype + # integrity check of targets + assert isinstance(_target, np.ndarray) + assert _target.dtype == np.float32 + if _target.ndim == 2: + if p == 1: + assert np.all(_target == np.array([[0.7, 0.1, 0.9, 0.4]], dtype=np.float32)) + assert tf.reduce_all( + tf.math.reduce_mean(transformed, (0, 2)) == tf.constant([1] * 16 + [0] * 16, dtype=tf.float64) + ) + elif p == 0: + assert np.all(_target == np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)) + assert tf.reduce_all( + tf.math.reduce_mean(transformed, (0, 2)) == tf.constant([0] * 16 + [1] * 16, dtype=tf.float64) + ) + else: + if p == 1: + assert np.all(_target == np.array([[[0.9, 0.1], [0.7, 0.1], [0.7, 0.4], [0.9, 0.4]]], dtype=np.float32)) + assert tf.reduce_all( + tf.math.reduce_mean(transformed, (0, 2)) == tf.constant([1] * 16 + [0] * 16, dtype=tf.float64) + ) + elif p == 0: + assert np.all(_target == np.array([[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], dtype=np.float32)) + assert tf.reduce_all( + tf.math.reduce_mean(transformed, (0, 2)) == tf.constant([0] * 16 + [1] * 16, dtype=tf.float64) + ) + + +@pytest.mark.parametrize( + "input_dtype,input_shape", + [ + [tf.float32, (32, 32, 3)], + [tf.uint8, (32, 32, 3)], + [tf.float32, (64, 32, 3)], + [tf.uint8, (64, 32, 3)], + ], +) +def test_random_shadow(input_dtype, input_shape): + transform = T.RandomShadow((0.2, 0.8)) + input_t = tf.random.uniform(input_shape, dtype=tf.float32) + if input_dtype == tf.uint8: + input_t = tf.math.round((255 * input_t)) + input_t = tf.cast(input_t, dtype=input_dtype) + transformed = transform(input_t) + assert isinstance(transformed, tf.Tensor) + assert transformed.shape == input_shape + assert transformed.dtype == input_dtype + # The shadow will darken the picture + assert tf.math.reduce_mean(input_t) >= tf.math.reduce_mean(transformed) + assert tf.math.reduce_all(transformed >= 0) + if input_dtype == tf.uint8: + assert tf.reduce_all(transformed <= 255) + else: + assert tf.reduce_all(transformed <= 1.0) + + +@pytest.mark.parametrize( + "p,target", + [ + [1, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)], + [0, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)], + [1, np.array([[[0.1, 0.8], [0.3, 0.1], [0.3, 0.4], [0.8, 0.4]]], dtype=np.float32)], + [0, np.array([[[0.1, 0.8], [0.3, 0.1], [0.3, 0.4], [0.8, 0.4]]], dtype=np.float32)], + ], +) +def test_random_resize(p, target): + transfo = T.RandomResize(scale_range=(0.3, 1.3), p=p) + assert repr(transfo) == f"RandomResize(scale_range=(0.3, 1.3), p={p})" + + img = tf.random.uniform((64, 64, 3)) + # Apply the transformation + out_img, out_target = transfo(img, target) + assert isinstance(out_img, tf.Tensor) + assert isinstance(out_target, np.ndarray) + # Resize is already well-tested + assert tf.reduce_all(tf.equal(out_img, img)) if p == 0 else out_img.shape != img.shape + assert out_target.shape == target.shape
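For reference, a minimal sketch of how the transforms exercised by these tests might be chained outside the test suite. This is illustrative only and not part of the patch; it assumes the same T.Compose, T.Resize, T.RandomBrightness, T.RandomContrast, T.RandomJpegQuality, and T.Normalize signatures used in the tests above, and the parameter values are placeholders.

import tensorflow as tf
from doctr import transforms as T

# Augmentation/preprocessing pipeline built from the transforms covered above;
# values mirror the ones used in the tests and are illustrative, not recommended defaults.
augmentation = T.Compose([
    T.Resize((32, 128), preserve_aspect_ratio=True, symmetric_pad=True),
    T.RandomBrightness(max_delta=0.1),
    T.RandomContrast(delta=0.2),
    T.RandomJpegQuality(min_quality=50),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# A single HWC float tensor in [0, 1], as in the tests above
sample = tf.random.uniform((64, 256, 3), minval=0, maxval=1)
out = augmentation(sample)  # resized to (32, 128, 3), padded symmetrically, then normalized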