Skip to content

Commit

Permalink
[Sync] improve reconstitution (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Oct 17, 2024
1 parent 277626c commit d17146c
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 64 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-ast
- id: check-yaml
Expand All @@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.10
rev: v0.6.9
hooks:
- id: ruff
args: [ --fix ]
Expand Down
63 changes: 42 additions & 21 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from PIL import Image

from onnxtr.io import DocumentFile
from onnxtr.models import ocr_predictor
from onnxtr.models import from_hub, ocr_predictor
from onnxtr.models.predictor import OCRPredictor
from onnxtr.utils.visualization import visualize_page

Expand All @@ -35,12 +35,17 @@
"parseq",
]

CUSTOM_RECO_ARCHS: List[str] = [
"Felix92/onnxtr-parseq-multilingual-v1",
]


def load_predictor(
det_arch: str,
reco_arch: str,
assume_straight_pages: bool,
straighten_pages: bool,
export_as_straight_boxes: bool,
detect_language: bool,
load_in_8_bit: bool,
bin_thresh: float,
Expand All @@ -58,6 +63,7 @@ def load_predictor(
disable_crop_orientation: whether to disable crop orientation or not
disable_page_orientation: whether to disable page orientation or not
straighten_pages: whether to straighten rotated pages or not
export_as_straight_boxes: whether to export straight boxes
detect_language: whether to detect the language of the text
load_in_8_bit: whether to load the 8-bit quantized models
bin_thresh: binarization threshold for the segmentation map
Expand All @@ -68,13 +74,13 @@ def load_predictor(
instance of OCRPredictor
"""
predictor = ocr_predictor(
det_arch,
reco_arch,
det_arch=det_arch,
reco_arch=reco_arch if reco_arch not in CUSTOM_RECO_ARCHS else from_hub(reco_arch),
assume_straight_pages=assume_straight_pages,
straighten_pages=straighten_pages,
detect_language=detect_language,
load_in_8_bit=load_in_8_bit,
export_as_straight_boxes=straighten_pages,
export_as_straight_boxes=export_as_straight_boxes,
detect_orientation=not assume_straight_pages,
disable_crop_orientation=disable_crop_orientation,
disable_page_orientation=disable_page_orientation,
Expand Down Expand Up @@ -132,6 +138,7 @@ def analyze_page(
disable_crop_orientation: bool,
disable_page_orientation: bool,
straighten_pages: bool,
export_as_straight_boxes: bool,
detect_language: bool,
load_in_8_bit: bool,
bin_thresh: float,
Expand All @@ -149,14 +156,15 @@ def analyze_page(
disable_crop_orientation: whether to disable crop orientation or not
disable_page_orientation: whether to disable page orientation or not
straighten_pages: whether to straighten rotated pages or not
export_as_straight_boxes: whether to export straight boxes
detect_language: whether to detect the language of the text
load_in_8_bit: whether to load the 8-bit quantized models
bin_thresh: binarization threshold for the segmentation map
box_thresh: minimal objectness score to consider a box
Returns:
-------
input image, segmentation heatmap, output image, OCR output
input image, segmentation heatmap, output image, OCR output, synthesized page
"""
if uploaded_file is None:
return None, "Please upload a document", None, None, None
Expand All @@ -165,19 +173,23 @@ def analyze_page(
doc = DocumentFile.from_pdf(uploaded_file)
else:
doc = DocumentFile.from_images(uploaded_file)
try:
page = doc[page_idx - 1]
except IndexError:
page = doc[-1]

page = doc[page_idx - 1]
img = page

predictor = load_predictor(
det_arch,
reco_arch,
assume_straight_pages,
straighten_pages,
detect_language,
load_in_8_bit,
bin_thresh,
box_thresh,
det_arch=det_arch,
reco_arch=reco_arch,
assume_straight_pages=assume_straight_pages,
straighten_pages=straighten_pages,
export_as_straight_boxes=export_as_straight_boxes,
detect_language=detect_language,
load_in_8_bit=load_in_8_bit,
bin_thresh=bin_thresh,
box_thresh=box_thresh,
disable_crop_orientation=disable_crop_orientation,
disable_page_orientation=disable_page_orientation,
)
Expand All @@ -194,7 +206,12 @@ def analyze_page(

out_img = matplotlib_to_pil(fig)

return img, seg_heatmap, out_img, page_export
if assume_straight_pages or straighten_pages:
synthesized_page = out.synthesize()[0]
else:
synthesized_page = None

return img, seg_heatmap, out_img, page_export, synthesized_page


with gr.Blocks(fill_height=True) as demo:
Expand Down Expand Up @@ -226,11 +243,14 @@ def analyze_page(
upload = gr.File(label="Upload File [JPG | PNG | PDF]", file_types=["pdf", "jpg", "png"])
page_selection = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Page selection")
det_model = gr.Dropdown(choices=DET_ARCHS, value=DET_ARCHS[0], label="Text detection model")
reco_model = gr.Dropdown(choices=RECO_ARCHS, value=RECO_ARCHS[0], label="Text recognition model")
reco_model = gr.Dropdown(
choices=RECO_ARCHS + CUSTOM_RECO_ARCHS, value=RECO_ARCHS[0], label="Text recognition model"
)
assume_straight = gr.Checkbox(value=True, label="Assume straight pages")
disable_crop_orientation = gr.Checkbox(value=False, label="Disable crop orientation")
disable_page_orientation = gr.Checkbox(value=False, label="Disable page orientation")
straighten = gr.Checkbox(value=False, label="Straighten pages")
export_as_straight_boxes = gr.Checkbox(value=False, label="Export as straight boxes")
det_language = gr.Checkbox(value=False, label="Detect language")
load_in_8_bit = gr.Checkbox(value=False, label="Load 8-bit quantized models")
binarization_threshold = gr.Slider(
Expand All @@ -243,11 +263,11 @@ def analyze_page(
input_image = gr.Image(label="Input page", width=600)
segmentation_heatmap = gr.Image(label="Segmentation heatmap", width=600)
output_image = gr.Image(label="Output page", width=600)
with gr.Column(scale=2):
with gr.Row():
gr.Markdown("### OCR output")
with gr.Row():
with gr.Row():
with gr.Column(scale=3):
ocr_output = gr.JSON(label="OCR output", render=True, scale=1)
with gr.Column(scale=3):
synthesized_page = gr.Image(label="Synthesized page", width=600)

analyze_button.click(
analyze_page,
Expand All @@ -260,12 +280,13 @@ def analyze_page(
disable_crop_orientation,
disable_page_orientation,
straighten,
export_as_straight_boxes,
det_language,
load_in_8_bit,
binarization_threshold,
box_threshold,
],
outputs=[input_image, segmentation_heatmap, output_image, ocr_output],
outputs=[input_image, segmentation_heatmap, output_image, ocr_output, synthesized_page],
)

demo.launch(inbrowser=True, allowed_paths=["./data/logo.jpg"])
2 changes: 1 addition & 1 deletion demo/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
-e git+https://github.com/felixdittrich92/OnnxTR.git#egg=onnxtr[cpu-headless,viz]
gradio>=4.37.1,<6.0.0
gradio>=4.37.1,<5.0.0
14 changes: 11 additions & 3 deletions onnxtr/io/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __init__(
if geometry is None:
# Check whether this is a rotated or straight box
box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox
geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator]
geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[misc]

super().__init__(words=words)
self.geometry = geometry
Expand Down Expand Up @@ -216,7 +216,7 @@ def __init__(
box_resolution_fn = (
resolve_enclosing_rbbox if isinstance(lines[0].geometry, np.ndarray) else resolve_enclosing_bbox
)
geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore[operator]
geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore

super().__init__(lines=lines, artefacts=artefacts)
self.geometry = geometry
Expand Down Expand Up @@ -294,6 +294,10 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **
def synthesize(self, **kwargs) -> np.ndarray:
"""Synthesize the page from the predictions
Args:
----
**kwargs: keyword arguments passed to the `synthesize_page` method
Returns
-------
synthesized page
Expand Down Expand Up @@ -442,11 +446,15 @@ def show(self, **kwargs) -> None:
def synthesize(self, **kwargs) -> List[np.ndarray]:
"""Synthesize all pages from their predictions
Args:
----
**kwargs: keyword arguments passed to the `Page.synthesize` method
Returns
-------
list of synthesized pages
"""
return [page.synthesize() for page in self.pages]
return [page.synthesize(**kwargs) for page in self.pages]

def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]:
"""Export the document as XML (hOCR-format)
Expand Down
Loading

0 comments on commit d17146c

Please sign in to comment.