forked from huggingface/datatrove
-
Notifications
You must be signed in to change notification settings - Fork 0
/
base.py
66 lines (52 loc) · 2.21 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError

from datatrove.data import DocumentsPipeline
from datatrove.pipeline.base import PipelineStep
from datatrove.utils.logging import logger
from datatrove.utils.typeshelper import StatHints
class BaseExtractor(PipelineStep):
    """Base Extractor module. Extractors extract text from html or other non-plain text formats.

    Subclasses implement `extract`; `run` applies it to every incoming document
    with a per-document timeout and only forwards documents that still have
    text after extraction.
    """

    type = "🛢 - EXTRAC"

    # NOTE(review): `__init__` is declared abstract in the original and is kept
    # that way here; whether this is actually enforced depends on the metaclass
    # of `PipelineStep`, which is not visible from this file.
    @abstractmethod
    def __init__(self, timeout: float = 0.1):
        """
        Args:
            timeout: the timeout for extraction, per document, in seconds
        """
        super().__init__()
        self.timeout = timeout

    @abstractmethod
    def extract(self, text: str) -> str:
        """Abstract method that actually implements the extraction, e.g. trafilatura.

        Args:
            text: str: non-plain text

        Returns: extracted plain text
        """
        pass

    def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
        """Iterates through each document in data and calls `extract` on it, with a timeout.

        Each extraction is submitted to a thread pool and its result is awaited
        for at most `self.timeout` seconds. Documents whose extraction times out
        or raises are logged and skipped (they are counted in `StatHints.total`
        only). Documents whose extracted text is empty are counted as dropped;
        all others are counted as forwarded and yielded with `doc.text`
        replaced by the extracted text.

        Args:
            data: DocumentsPipeline: the documents to process
            rank: int: (Default value = 0) not used by this method
            world_size: int: (Default value = 1) not used by this method

        Returns:
            a generator over the documents that have non-empty extracted text
        """
        # NOTE: a timed-out extraction is not interrupted: the worker thread keeps
        # running in the background, and leaving this `with` block waits for any
        # outstanding workers to finish.
        with ThreadPoolExecutor() as executor:  # more reliable than using signal for timeouts
            for doc in data:
                self.stat_update(StatHints.total)
                with self.track_time():
                    future = executor.submit(self.extract, doc.text)
                    try:
                        doc.text = future.result(timeout=self.timeout)
                    except (FuturesTimeoutError, TimeoutError):
                        # `Future.result` raises `concurrent.futures.TimeoutError`, which
                        # is only an alias of the builtin `TimeoutError` since Python 3.11.
                        # Catching both keeps the timeout branch working on older
                        # interpreters too.
                        logger.warning("⏰ Timeout while cleaning record text. Skipping record.")
                        continue
                    except Exception as e:
                        # Deliberately broad: one bad document must not abort the
                        # whole pipeline, whatever the extractor raises.
                        logger.warning(f'❌ Error "{e}" while cleaning record text. Skipping record.')
                        continue
                # Stats and the yield sit outside `track_time()` so that time spent
                # downstream while this generator is suspended is not attributed
                # to extraction.
                if doc.text:
                    self.stat_update(StatHints.forwarded)
                    self.update_doc_stats(doc)
                    yield doc
                else:
                    self.stat_update(StatHints.dropped)