Skip to content

Commit

Permalink
Merge pull request #514 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Update texify
  • Loading branch information
VikParuchuri authored Jan 29, 2025
2 parents 228a7ba + c85d72b commit 9c740b1
Show file tree
Hide file tree
Showing 15 changed files with 429 additions and 174 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ on: [push]

env:
TORCH_DEVICE: "cpu"
OCR_ENGINE: "surya"

jobs:
benchmark:
runs-on: ubuntu-latest
runs-on: [ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ Pages have the keys:

- `id` - unique id for the block.
- `block_type` - the type of block. The possible block types can be seen in `marker/schema/__init__.py`. As of this writing, they are ["Line", "Span", "FigureGroup", "TableGroup", "ListGroup", "PictureGroup", "Page", "Caption", "Code", "Figure", "Footnote", "Form", "Equation", "Handwriting", "TextInlineMath", "ListItem", "PageFooter", "PageHeader", "Picture", "SectionHeader", "Table", "Text", "TableOfContents", "Document"]
- `html` - the HTML for the page. Note that this will have recursive references to children. The `content-ref` tags must be replaced with the child content if you want the full html. You can see an example of this at `marker/renderers/__init__.py:BaseRender.extract_block_html`.
- `html` - the HTML for the page. Note that this will have recursive references to children. The `content-ref` tags must be replaced with the child content if you want the full html. You can see an example of this at `marker/output.py:json_to_html`. That function will take in a single block from the json output, and turn it into HTML.
- `polygon` - the 4-corner polygon of the page, in (x1,y1), (x2,y2), (x3, y3), (x4, y4) format. (x1,y1) is the top left, and coordinates go clockwise.
- `children` - the child blocks.

Expand Down
33 changes: 2 additions & 31 deletions marker/models.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,12 @@
import os

from marker.settings import settings

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS

from typing import List
from PIL import Image
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS

from surya.detection import DetectionPredictor
from surya.layout import LayoutPredictor
from surya.ocr_error import OCRErrorPredictor
from surya.recognition import RecognitionPredictor
from surya.table_rec import TableRecPredictor

from texify.model.model import load_model as load_texify_model
from texify.model.processor import load_processor as load_texify_processor
from texify.inference import batch_inference

class TexifyPredictor:
    """Callable bundle of a Texify model and its processor.

    Construction loads the checkpoint named by ``settings.TEXIFY_MODEL_NAME``
    and the matching processor; calling the instance forwards a batch of
    images to ``batch_inference``.
    """

    def __init__(self, device=None, dtype=None):
        """Load the model and processor.

        Args:
            device: Device handed to the model loader. Any falsy value falls
                back to ``settings.TORCH_DEVICE_MODEL``.
            dtype: Dtype handed to the model loader. Any falsy value falls
                back to ``settings.TEXIFY_DTYPE``.
        """
        self.device = device or settings.TORCH_DEVICE_MODEL
        self.dtype = dtype or settings.TEXIFY_DTYPE

        # Model first, then processor, matching the original load order.
        self.model = load_texify_model(
            checkpoint=settings.TEXIFY_MODEL_NAME,
            device=self.device,
            dtype=self.dtype,
        )
        self.processor = load_texify_processor()

    def __call__(self, batch_images: List[Image.Image], max_tokens: int):
        """Run ``batch_inference`` on ``batch_images`` and return its result.

        Args:
            batch_images: Images forwarded unchanged to ``batch_inference``.
            max_tokens: Forwarded as the ``max_tokens`` keyword argument.
        """
        return batch_inference(batch_images, self.model, self.processor, max_tokens=max_tokens)
from surya.texify import TexifyPredictor


def create_model_dict(device=None, dtype=None) -> dict:
Expand Down
20 changes: 19 additions & 1 deletion marker/output.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
import json
import os

from bs4 import BeautifulSoup
from pydantic import BaseModel

from marker.renderers.html import HTMLOutput
from marker.renderers.json import JSONOutput
from marker.renderers.json import JSONOutput, JSONBlockOutput
from marker.renderers.markdown import MarkdownOutput
from marker.settings import settings

def json_to_html(block: JSONBlockOutput):
    """Return the HTML for a JSON output block with its children inlined.

    A block's own ``html`` contains ``<content-ref src="...">`` placeholders
    where its children belong. This renders every child recursively and
    substitutes each placeholder whose ``src`` equals a direct child's ``id``.

    Args:
        block: A block from the JSON output. Read for ``html``, ``children``
            and each child's ``id``.

    Returns:
        ``block.html`` unchanged when the block has no (or empty/missing)
        ``children``; otherwise the re-serialized HTML with every resolvable
        ``content-ref`` replaced by the child's rendered HTML. Placeholders
        without a ``src`` attribute, or whose ``src`` matches no direct child,
        are left in place.
    """
    children = getattr(block, "children", None)
    if not children:
        return block.html

    # Render each child once and index it by id so each placeholder is an
    # O(1) lookup instead of a list scan. setdefault keeps the first child
    # for a repeated id, the same winner list.index() would have picked.
    html_by_id = {}
    for child in children:
        html_by_id.setdefault(child.id, json_to_html(child))

    soup = BeautifulSoup(block.html, "html.parser")
    for ref in soup.find_all("content-ref"):
        # .get() tolerates a malformed placeholder with no src attribute
        # rather than aborting the whole render with a KeyError.
        src_id = ref.attrs.get("src")
        if src_id in html_by_id:
            ref.replace_with(BeautifulSoup(html_by_id[src_id], "html.parser"))
    return str(soup)


def output_exists(output_dir: str, fname_base: str):
exts = ["md", "html", "json"]
Expand Down
2 changes: 1 addition & 1 deletion marker/processors/blockquote.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ def __call__(self, document: Document):
next_block.blockquote_level += 1
elif len(next_block.structure) >= 2 and (x_indent and y_indent):
next_block.blockquote = True
next_block.blockquote_level = 1
next_block.blockquote_level = 1
85 changes: 8 additions & 77 deletions marker/processors/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class EquationProcessor(BaseProcessor):
model_max_length: Annotated[
int,
"The maximum number of tokens to allow for the Texify model.",
] = 384
] = 768
texify_batch_size: Annotated[
Optional[int],
"The batch size to use for the Texify model.",
Expand Down Expand Up @@ -65,27 +65,7 @@ def __call__(self, document: Document):
continue

block = document.get_block(equation_d["block_id"])
block.html = self.parse_latex_to_html(prediction)

def parse_latex_to_html(self, latex: str):
    """Convert delimiter-marked LaTeX text into ``<math>`` HTML.

    The text is split with ``self.parse_latex``. Block and inline segments
    become ``<math>`` tags with the matching ``display`` attribute; any other
    segment is emitted as plain text padded with a space on each side.

    Args:
        latex: Text containing ``$``/``$$`` delimited math.

    Returns:
        The concatenated HTML with leading/trailing whitespace stripped.
    """
    try:
        segments = self.parse_latex(latex)
    except ValueError:
        # The parser rejected the delimiters: drop every "$" and treat the
        # whole string as a single block equation.
        segments = [{"class": "block", "content": latex.replace("$", "")}]

    pieces = []
    for segment in segments:
        kind = segment["class"]
        if kind == "block":
            pieces.append(f'<math display="block">{segment["content"]}</math>')
        elif kind == "inline":
            pieces.append(f'<math display="inline">{segment["content"]}</math>')
        else:
            pieces.append(f" {segment['content']} ")
    return "".join(pieces).strip()
block.html = prediction

def get_batch_size(self):
if self.texify_batch_size is not None:
Expand All @@ -106,71 +86,22 @@ def get_latex_batched(self, equation_data: List[dict]):
max_idx = min(min_idx + batch_size, len(equation_data))

batch_equations = equation_data[min_idx:max_idx]
max_length = max([eq["token_count"] for eq in batch_equations])
max_length = min(max_length, self.model_max_length)
max_length += self.token_buffer

batch_images = [eq["image"] for eq in batch_equations]

model_output = self.texify_model(
batch_images,
max_tokens=max_length
batch_images
)

for j, output in enumerate(model_output):
token_count = self.get_total_texify_tokens(output)
if token_count >= max_length - 1:
output = ""
token_count = self.get_total_texify_tokens(output.text)
if token_count >= self.model_max_length - 1:
output.text = ""

image_idx = i + j
predictions[image_idx] = output
predictions[image_idx] = output.text
return predictions

def get_total_texify_tokens(self, text):
    """Return the number of input ids the Texify tokenizer yields for ``text``.

    The tokenizer is read from ``self.texify_model.processor.tokenizer`` and
    is expected to return a mapping with an ``"input_ids"`` entry.
    """
    encoded = self.texify_model.processor.tokenizer(text)
    return len(encoded["input_ids"])


@staticmethod
def parse_latex(text: str):
if text.count("$") % 2 != 0:
raise ValueError("Mismatched delimiters in LaTeX")

DELIMITERS = [
("$$", "block"),
("$", "inline")
]

text = text.replace("\n", "<br>") # we can't handle \n's inside <p> properly if we don't do this

i = 0
stack = []
result = []
buffer = ""

while i < len(text):
for delim, class_name in DELIMITERS:
if text[i:].startswith(delim):
if stack and stack[-1] == delim: # Closing
stack.pop()
result.append({"class": class_name, "content": buffer})
buffer = ""
i += len(delim)
break
elif not stack: # Opening
if buffer:
result.append({"class": "text", "content": buffer})
stack.append(delim)
buffer = ""
i += len(delim)
break
else:
raise ValueError(f"Nested {class_name} delimiters not supported")
else: # No delimiter match
buffer += text[i]
i += 1

if buffer:
result.append({"class": "text", "content": buffer})
return result
return len(tokens["input_ids"])
4 changes: 4 additions & 0 deletions marker/processors/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def list_group_indentation(self, document: Document):
for list_item_id in block.structure:
list_item_block: ListItem = page.get_block(list_item_id)

# This can be a line sometimes
if list_item_block.block_type != BlockTypes.ListItem:
continue

while stack and list_item_block.polygon.x_start <= stack[-1].polygon.x_start + (self.min_x_indent * page.polygon.width):
stack.pop()

Expand Down
1 change: 0 additions & 1 deletion marker/processors/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from marker.schema.groups.list import ListGroup
from marker.schema.groups.table import TableGroup
from marker.schema.registry import get_block_class
from marker.schema.groups.picture import PictureGroup
from marker.schema.groups.figure import FigureGroup


Expand Down
37 changes: 27 additions & 10 deletions marker/processors/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from copy import deepcopy
from typing import Annotated, List
from collections import Counter
from PIL import ImageDraw

from ftfy import fix_text
from surya.detection import DetectionPredictor
Expand Down Expand Up @@ -53,6 +52,10 @@ class TableProcessor(BaseProcessor):
int,
"The number of workers to use for pdftext.",
] = 4
row_split_threshold: Annotated[
float,
"The percentage of rows that need to be split across the table before row splitting is active.",
] = 0.5

def __init__(
self,
Expand Down Expand Up @@ -171,10 +174,7 @@ def split_combined_rows(self, tables: List[TableResult]):
# Skip empty tables
continue
unique_rows = sorted(list(set([c.row_id for c in table.cells])))
new_cells = []
shift_up = 0
max_cell_id = max([c.cell_id for c in table.cells])
new_cell_count = 0
row_info = []
for row in unique_rows:
# Cells in this row
# Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows
Expand All @@ -201,9 +201,25 @@ def split_combined_rows(self, tables: List[TableResult]):
len(line_lens_counter) == 2 and counter_keys[0] <= 1 and counter_keys[1] > 1 and line_lens_counter[counter_keys[0]] == 1, # Allow a single column with a single line - keys are the line lens, values are the counts
])
should_split = should_split_entire_row or should_split_partial_row
if should_split:
for i in range(0, max(line_lens)):
for cell in row_cells:
row_info.append({
"should_split": should_split,
"row_cells": row_cells,
"line_lens": line_lens
})

# Don't split if we're not splitting most of the rows in the table. This avoids splitting stray multiline rows.
if sum([r["should_split"] for r in row_info]) / len(row_info) < self.row_split_threshold:
continue

new_cells = []
shift_up = 0
max_cell_id = max([c.cell_id for c in table.cells])
new_cell_count = 0
for row, item_info in zip(unique_rows, row_info):
max_lines = max(item_info["line_lens"])
if item_info["should_split"]:
for i in range(0, max_lines):
for cell in item_info["row_cells"]:
# Calculate height based on number of splits
split_height = cell.bbox[3] - cell.bbox[1]
current_bbox = [cell.bbox[0], cell.bbox[1] + i * split_height, cell.bbox[2], cell.bbox[1] + (i + 1) * split_height]
Expand All @@ -226,9 +242,10 @@ def split_combined_rows(self, tables: List[TableResult]):
new_cell_count += 1

# For each new row we add, shift up subsequent rows
shift_up += line_lens[0] - 1
# The max is to account for partial rows
shift_up += max_lines - 1
else:
for cell in row_cells:
for cell in item_info["row_cells"]:
cell.row_id += shift_up
new_cells.append(cell)

Expand Down
18 changes: 15 additions & 3 deletions marker/renderers/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
from marker.schema.document import Document


def escape_dollars(text):
    """Return ``text`` with a backslash placed before every ``$``."""
    # Splitting on "$" and re-joining with the escaped form replaces each
    # occurrence exactly once.
    return r"\$".join(text.split("$"))

def cleanup_text(full_text):
    """Collapse excess blank lines and trim surrounding whitespace.

    Runs of three or more newlines, and then runs of three or more
    "newline + one whitespace character" pairs, are each reduced to a single
    blank line. The result is stripped at both ends.
    """
    collapsed = full_text
    # Applied in this order: plain newline runs first, then newline/whitespace pairs.
    for pattern in (r'\n{3,}', r'(\n\s){3,}'):
        collapsed = re.sub(pattern, '\n\n', collapsed)
    return collapsed.strip()

def get_formatted_table_text(element):

text = []
for content in element.contents:
if content is None:
Expand All @@ -26,13 +30,14 @@ def get_formatted_table_text(element):
if isinstance(content, NavigableString):
stripped = content.strip()
if stripped:
text.append(stripped)
text.append(escape_dollars(stripped))
elif content.name == 'br':
text.append('<br>')
elif content.name == "math":
text.append("$" + content.text + "$")
else:
text.append(str(content))
content_str = escape_dollars(str(content))
text.append(content_str)

full_text = ""
for i, t in enumerate(text):
Expand Down Expand Up @@ -120,7 +125,7 @@ def convert_table(self, el, text, convert_as_inline):
if r == 0 and c == 0:
grid[row_idx][col_idx] = value
else:
grid[row_idx + r][col_idx + c] = ''
grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
except IndexError:
# Sometimes the colspan/rowspan predictions can overflow
print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
Expand Down Expand Up @@ -176,6 +181,12 @@ def convert_span(self, el, text, convert_as_inline):
else:
return text

def escape(self, text):
    """Escape ``text`` with the parent converter, then optionally escape ``$``.

    After the parent class's ``escape`` has run, every ``$`` is prefixed with
    a backslash when the ``escape_dollars`` option is truthy.
    """
    escaped = super().escape(text)
    if not self.options['escape_dollars']:
        return escaped
    return escaped.replace('$', r'\$')

class MarkdownOutput(BaseModel):
markdown: str
images: dict
Expand All @@ -198,6 +209,7 @@ def __call__(self, document: Document) -> MarkdownOutput:
escape_misc=False,
escape_underscores=False,
escape_asterisks=False,
escape_dollars=True,
sub_symbol="<sub>",
sup_symbol="<sup>",
inline_math_delimiters=self.inline_math_delimiters,
Expand Down
Loading

0 comments on commit 9c740b1

Please sign in to comment.