-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
484 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import re | ||
import subprocess | ||
import tempfile | ||
from pathlib import Path | ||
|
||
import latex2mathml.converter | ||
|
||
from marker.renderers.markdown import MarkdownRenderer | ||
|
||
class MarkdownCleaner:
    """Normalize markdown text so different renderings can be compared.

    Calling an instance round-trips the text through pandoc, standardizes
    math spans, replaces remote image links with a generic tag, strips a few
    stray HTML tags, collapses whitespace and repeated ``.`` / ``#`` runs,
    decodes literal ``\\uXXXX`` escapes and finally lowercases the result.
    """

    # LaTeX command name (without the backslash) -> replacement text used by
    # ``clean_latex``. ``leq``/``geq``/``ne`` are the standard synonyms of
    # ``le``/``ge``/``neq`` and are normalized the same way.
    _LATEX_REPLACEMENTS = {
        'times': '*',
        'cdot': '*',
        'div': '/',
        'le': '<=',
        'leq': '<=',
        'ge': '>=',
        'geq': '>=',
        'ne': '!=',
        'neq': '!=',
        'to': '\\rightarrow',
    }
    # The negative lookahead makes sure a whole command is matched, so that
    # ``\left``, ``\top`` or ``\cdots`` are not mistaken for ``\le``, ``\to``
    # or ``\cdot``.
    _LATEX_COMMAND_RE = re.compile(
        r'\\(' + '|'.join(_LATEX_REPLACEMENTS) + r')(?![A-Za-z])'
    )
    # A literal ``\uXXXX`` escape that is not itself preceded by a backslash.
    _UNICODE_ESCAPE_RE = re.compile(r'(?<!\\)\\u([0-9a-fA-F]{4})')

    def __init__(self):
        pass

    def __call__(self, markdown):
        """Return the cleaned, lowercased form of ``markdown``.

        Raises whatever ``normalize_markdown`` raises when pandoc fails.
        """
        markdown = self.normalize_markdown(markdown)  # Use pandoc to normalize

        # Standardize math expressions ($$...$$ -> MathML, $...$ -> cleaned LaTeX)
        pattern = r'(?<!\\)\$(?:\$([^$]+)\$\$|\s*([^$\n]+?)\s*\$)'
        markdown = re.sub(pattern, self.standardize_math, markdown)

        # Replace image urls with a generic tag
        pattern = r'!\[(.*?)\]\((https?://[^\s\)]+)\)'
        markdown = re.sub(pattern, r'![link]', markdown)

        # Clean up stray html tags
        markdown = markdown.replace("<br>", "\n")
        markdown = re.sub(r"<sub>(.*?)</sub>", r"\1", markdown)
        markdown = re.sub(r"<sup>(.*?)</sup>", r"\1", markdown)
        markdown = re.sub(r"<span.*?>(.*?)</span>", r"\1", markdown)  # Remove span tags and keep content

        # Clean up markdown formatting
        markdown = re.sub(r"\s+", " ", markdown)
        # NOTE: after the substitution above no newline is left, so this one
        # is a no-op; it is kept so the output stays unchanged.
        markdown = re.sub(r"\n+", "\n", markdown)
        # Replace repeated periods with a single period, like in table of contents
        markdown = re.sub(r"\.+", ".", markdown)
        markdown = re.sub("#+", "#", markdown)  # Replace repeated headers with a single header
        markdown = self._decode_unicode_escapes(markdown)  # Decode unicode characters properly
        return markdown.strip().lower()

    @staticmethod
    def _decode_unicode_escapes(text):
        """Turn literal ``\\uXXXX`` escapes into the characters they denote.

        Only ``\\uXXXX`` sequences are touched. Decoding the whole string
        with the ``unicode-escape`` codec would read the UTF-8 bytes as
        Latin-1 (mangling every non-ASCII character) and would interpret
        LaTeX commands such as ``\\frac`` or ``\\neq`` as control-character
        escapes. Escapes in the surrogate range are left as-is because a lone
        surrogate cannot be encoded as UTF-8 later on.
        """
        def _replace(match):
            code_point = int(match.group(1), 16)
            if 0xD800 <= code_point <= 0xDFFF:
                return match.group(0)
            return chr(code_point)

        return MarkdownCleaner._UNICODE_ESCAPE_RE.sub(_replace, text)

    @staticmethod
    def _run_pandoc(source, target, from_format, to_format):
        """Convert ``source`` to ``target`` with pandoc; raise on a non-zero exit."""
        subprocess.run(
            [
                'pandoc',
                str(source),
                '-f', from_format,
                '-t', to_format,
                '-o', str(target),
                '--quiet'
            ],
            check=True
        )

    @staticmethod
    def normalize_markdown(md_text: str) -> str:
        """Round-trip ``md_text`` through pandoc (markdown -> HTML -> markdown).

        All intermediate files live in a temporary directory that is removed
        on exit.

        Raises:
            subprocess.CalledProcessError: if a pandoc invocation fails.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            dirpath = Path(tmp_dir)
            input_file = dirpath / 'input.md'
            html_file = dirpath / 'temp.html'
            output_file = dirpath / 'output.md'

            input_file.write_text(md_text, encoding='utf-8')

            # Markdown to HTML
            MarkdownCleaner._run_pandoc(input_file, html_file, 'markdown+tex_math_dollars', 'html')
            # HTML to Markdown
            MarkdownCleaner._run_pandoc(html_file, output_file, 'html', 'markdown+tex_math_dollars')

            # Read back the normalized Markdown
            return output_file.read_text(encoding='utf-8')

    def standardize_math(self, match):
        """``re.sub`` callback that rewrites one matched math span.

        Block math (``$$...$$``) is converted with ``latex2mathml``; inline
        math (``$...$``) is simplified with ``clean_latex``. The original
        delimiters are kept around the result. Conversion is best-effort: on
        any failure the error is printed and the span is returned unchanged.
        """
        try:
            delim = "$$" if match.group(0).startswith('$$') else "$"
            math_content = match.group(1) or match.group(2)
            if delim == "$$":
                math_content = latex2mathml.converter.convert(math_content)
            else:
                math_content = self.clean_latex(math_content)
            return f'{delim}{math_content}{delim}'
        except Exception as e:
            print(f"Failed to standardize math expression: {match.group(0)} with error: {e}")
            return match.group(0)

    @staticmethod
    def clean_latex(latex_str):
        """Simplify an inline LaTeX snippet.

        Collapses whitespace, unwraps ``\\text{}``, ``\\mathrm{}``,
        ``\\mathbf{}`` and ``\\textbf{}``, and rewrites the commands listed in
        ``_LATEX_REPLACEMENTS``. Only whole commands are rewritten, so
        ``\\left`` or ``\\cdots`` are left untouched.
        """
        latex_str = re.sub(r'\s+', ' ', latex_str.strip())
        for tag in [r'\\text', r'\\mathrm', r'\\mathbf', r'\\textbf']:
            latex_str = re.sub(tag + r'\{([^}]+)\}', r'\1', latex_str)

        # A function replacement keeps the backslash in ``\rightarrow`` literal.
        return MarkdownCleaner._LATEX_COMMAND_RE.sub(
            lambda m: MarkdownCleaner._LATEX_REPLACEMENTS[m.group(1)],
            latex_str,
        )
|
||
|
||
def convert_to_md(html):
    """Convert ``html`` using the ``md_cls`` converter of a fresh ``MarkdownRenderer``."""
    return MarkdownRenderer().md_cls.convert(html)
|
||
def clean_input(markdown):
    """Run ``markdown`` through a newly created ``MarkdownCleaner`` and return the result."""
    return MarkdownCleaner()(markdown)
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import subprocess | ||
import tempfile | ||
import pypdfium2 as pdfium | ||
from typing import Dict | ||
from collections import defaultdict | ||
import re | ||
import io | ||
import json | ||
|
||
from PIL import Image | ||
import datasets | ||
import markdown2 | ||
from playwright.sync_api import sync_playwright | ||
|
||
from benchmarks.overall.schema import FullResult | ||
|
||
def convert_to_html(md: str):
    """Render markdown to HTML while leaving ``$...$`` / ``$$...$$`` math untouched.

    Math spans are swapped for placeholders before the text is handed to
    ``markdown2`` and swapped back afterwards, so their content is not
    processed as markdown.

    Args:
        md: Markdown source, possibly containing dollar-delimited math.

    Returns:
        The HTML produced by ``markdown2`` with the original math spans
        (delimiters included) restored.
    """
    block_math = []
    inline_math = []

    # Placeholders end in a non-digit ("X") so that the index is
    # unambiguous: with a purely numeric tail, the placeholder for index 1
    # would be a prefix of the one for index 11.
    def block_sub(match):
        block_math.append(f"$${match.group(1)}$$")
        return f"1BLOCKMATH{len(block_math) - 1}X1"

    def inline_sub(match):
        inline_math.append(f"${match.group(1)}$")
        return f"1INLINEMATH{len(inline_math) - 1}X1"

    # Block math first, so "$$" is not consumed as two inline delimiters.
    md = re.sub(r'\${2}(.*?)\${2}', block_sub, md, flags=re.DOTALL)
    md = re.sub(r'\$(.*?)\$', inline_sub, md)

    html = markdown2.markdown(md, extras=['tables'])

    def restorer(stored):
        def restore(match):
            index = int(match.group(1))
            # Text that merely looks like a placeholder is left alone.
            return stored[index] if index < len(stored) else match.group(0)
        return restore

    # Inline spans are restored first: an inline span may have captured a
    # block placeholder, which the second pass then resolves. Function
    # replacements keep backslashes in the math literal.
    html = re.sub(r'1INLINEMATH(\d+)X1', restorer(inline_math), html)
    html = re.sub(r'1BLOCKMATH(\d+)X1', restorer(block_math), html)

    return html
|
||
|
||
def markdown_to_image(md: str) -> Image.Image:
    """Render markdown (with KaTeX math) in headless Chromium and screenshot it.

    The markdown is converted with ``convert_to_html``, placed in a page that
    loads KaTeX and its auto-render extension from a CDN, and captured as a
    full-page screenshot.

    Args:
        md: Markdown source to render.

    Returns:
        The screenshot as a PIL image.
    """
    html = convert_to_html(md)
    # NOTE(review): the "[email protected]" segments in the CDN URLs below look
    # like an email-obfuscation artifact of a "katex@<version>" specifier.
    # They are left untouched here; confirm the version against the pinned
    # integrity hashes before relying on this.
    with sync_playwright() as p:
        browser = p.chromium.launch()
        try:
            page = browser.new_page()
            # The KaTeX scripts are loaded with `defer`, so they run only
            # after the document has been parsed. Rendering is therefore
            # triggered from DOMContentLoaded instead of a plain inline
            # script, which would run before renderMathInElement exists.
            page.set_content(f"""
            <head>
            <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css" integrity="sha384-zh0CIslj+VczCZtlzBcjt5ppRcsAmDnRem7ESsYwWwg3m/OaJ2l4x7YBZl9Kxxib" crossorigin="anonymous">
            <!-- The loading of KaTeX is deferred to speed up page rendering -->
            <script defer src="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.js" integrity="sha384-Rma6DA2IPUwhNxmrB/7S3Tno0YY7sFu9WSYMCuulLhIqYSGZ2gKCJWIqhBWqMQfh" crossorigin="anonymous"></script>
            <!-- To automatically render math in text elements, include the auto-render extension: -->
            <script defer src="https://cdn.jsdelivr.net/npm/[email protected]/dist/contrib/auto-render.min.js" integrity="sha384-hCXGrW6PitJEwbkoStFjeJxv+fSOOQKOPbJxSfM6G5sWZjAyWhXiTIIAmQqnlLlh" crossorigin="anonymous"></script>
            </head>
            <body>
            {html}
            <script>
            document.addEventListener("DOMContentLoaded", function() {{
                renderMathInElement(document.body, {{
                    delimiters: [
                        {{left: '$$', right: '$$', display: true}},
                        {{left: '$', right: '$', display: false}}
                    ]
                }});
            }});
            </script>
            </body>
            """)
            page.set_viewport_size({"width": 1200, "height": 800})
            page.wait_for_timeout(500)  # Wait for KaTeX to render
            screenshot_bytes = page.screenshot(full_page=True)
        finally:
            # Close the browser even when rendering fails.
            browser.close()

    return Image.open(io.BytesIO(screenshot_bytes))
|
||
|
||
def build_dataset(ds: datasets.Dataset, all_scores: Dict[str, FullResult]) -> datasets.Dataset:
    """Build a dataset holding every method's output next to the ground truth.

    Only rows whose index appears in the ``"raw_scores"`` of every method are
    kept. Each output row contains, per method, ``{method}_score``,
    ``{method}_markdown``, ``{method}_image`` and ``{method}_time``, plus
    ``gt_markdown`` and ``gt_image`` built from the row's ``gt_blocks``.

    Args:
        ds: Source dataset; rows are looked up by the indices found in
            ``all_scores`` and must provide a JSON-encoded ``gt_blocks`` field.
        all_scores: Mapping of method name to its result; each result's
            ``"raw_scores"`` maps a dataset index to that row's scores.

    Returns:
        A new dataset with one row per shared index, in sorted index order.
        It is empty when ``all_scores`` is empty or no index is shared.
    """
    # NOTE(review): ``clean_input`` and ``convert_to_md`` are not among this
    # file's visible imports - confirm they are in scope here.

    # Get all the dataset indices that went through inference for every method
    full_idxs = None
    for method_result in all_scores.values():
        raw_scores = method_result["raw_scores"]
        if full_idxs is None:
            full_idxs = sorted(raw_scores)
        else:
            # Dict membership keeps the intersection linear.
            full_idxs = [idx for idx in full_idxs if idx in raw_scores]
    if full_idxs is None:
        # No methods at all: nothing to iterate over.
        full_idxs = []

    rows = []
    for idx in full_idxs:
        row = ds[idx]  # img, gt_blocks, classification, language, uuid
        out_row = {}
        for method, method_result in all_scores.items():
            method_row = method_result["raw_scores"][idx]
            out_row.update({
                f"{method}_score": method_row["overall_score"],
                f"{method}_markdown": method_row["markdown"],
                f"{method}_image": markdown_to_image(method_row["markdown"]),
                f"{method}_time": method_row["time"]
            })
        gt_md = "\n\n".join([clean_input(convert_to_md(block)) for block in json.loads(row["gt_blocks"])])
        out_row.update({
            "gt_markdown": gt_md,
            "gt_image": markdown_to_image(gt_md)
        })
        rows.append(out_row)

    return datasets.Dataset.from_list(rows)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.