Commit fba3616
Add save-to-HF support for the async web crawler
Also adds save support for the sync web crawler.
AndreaFrancis committed Dec 13, 2024
1 parent 7524aa7 commit fba3616
Showing 9 changed files with 367 additions and 10 deletions.
19 changes: 13 additions & 6 deletions crawl4ai/async_configs.py
@@ -8,6 +8,7 @@
from .extraction_strategy import ExtractionStrategy
from .chunking_strategy import ChunkingStrategy
from .markdown_generation_strategy import MarkdownGenerationStrategy
from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy

class BrowserConfig:
"""
@@ -188,6 +189,7 @@ class CrawlerRunConfig:
Default: None (NoExtractionStrategy is used if None).
chunking_strategy (ChunkingStrategy): Strategy to chunk content before extraction.
Default: RegexChunking().
data_persistence_strategy (DataPersistenceStrategy): Strategy for persisting crawl results. Defaults to SkipDataPersistenceStrategy, which skips persistence.
content_filter (RelevantContentFilter or None): Optional filter to prune irrelevant content.
Default: None.
cache_mode (CacheMode or None): Defines how caching is handled.
@@ -268,11 +270,12 @@ class CrawlerRunConfig:
def __init__(
self,
word_count_threshold: int = MIN_WORD_THRESHOLD ,
extraction_strategy : ExtractionStrategy=None, # Will default to NoExtractionStrategy if None
chunking_strategy : ChunkingStrategy= None, # Will default to RegexChunking if None
extraction_strategy : ExtractionStrategy = None, # Will default to NoExtractionStrategy if None
chunking_strategy : ChunkingStrategy = None, # Will default to RegexChunking if None
data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(),
markdown_generator : MarkdownGenerationStrategy = None,
content_filter=None,
cache_mode=None,
content_filter = None,
cache_mode = None,
session_id: str = None,
bypass_cache: bool = False,
disable_cache: bool = False,
@@ -285,7 +288,7 @@ def __init__(
only_text: bool = False,
image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
prettiify: bool = False,
js_code=None,
js_code = None,
wait_for: str = None,
js_only: bool = False,
wait_until: str = "domcontentloaded",
@@ -311,6 +314,7 @@ def __init__(
self.word_count_threshold = word_count_threshold
self.extraction_strategy = extraction_strategy
self.chunking_strategy = chunking_strategy
self.data_persistence_strategy = data_persistence_strategy
self.markdown_generator = markdown_generator
self.content_filter = content_filter
self.cache_mode = cache_mode
@@ -354,7 +358,9 @@ def __init__(
raise ValueError("extraction_strategy must be an instance of ExtractionStrategy")
if self.chunking_strategy is not None and not isinstance(self.chunking_strategy, ChunkingStrategy):
raise ValueError("chunking_strategy must be an instance of ChunkingStrategy")

if self.data_persistence_strategy is not None and not isinstance(self.data_persistence_strategy, DataPersistenceStrategy):
raise ValueError("data_persistence_strategy must be an instance of DataPersistenceStrategy")

# Set default chunking strategy if None
if self.chunking_strategy is None:
from .chunking_strategy import RegexChunking
@@ -367,6 +373,7 @@ def from_kwargs(kwargs: dict) -> "CrawlerRunConfig":
word_count_threshold=kwargs.get("word_count_threshold", 200),
extraction_strategy=kwargs.get("extraction_strategy"),
chunking_strategy=kwargs.get("chunking_strategy"),
data_persistence_strategy=kwargs.get("data_persistence_strategy"),
markdown_generator=kwargs.get("markdown_generator"),
content_filter=kwargs.get("content_filter"),
cache_mode=kwargs.get("cache_mode"),
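A minimal configuration sketch (illustrative, with a placeholder repository ID): the new data_persistence_strategy option is wired into CrawlerRunConfig like the other strategies, and the resulting config can be handed to AsyncWebCrawler.arun via its crawler_config parameter, which takes precedence over the legacy keyword arguments shown in the next file.

# Illustrative sketch only; "username/crawl4ai-demo" is a placeholder repo_id.
from crawl4ai.async_configs import CrawlerRunConfig
from crawl4ai.data_persistence_strategy import HFDataPersistenceStrategy

config = CrawlerRunConfig(
    data_persistence_strategy=HFDataPersistenceStrategy(
        repo_id="username/crawl4ai-demo",  # placeholder repository on the Hugging Face Hub
        private=True,
        verbose=True,
    ),
)
# Passing this config as crawler_config to AsyncWebCrawler.arun overrides the
# legacy data_persistence_strategy keyword argument.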
10 changes: 8 additions & 2 deletions crawl4ai/async_webcrawler.py
@@ -25,6 +25,7 @@
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
URL_LOG_SHORTEN_LENGTH
)
from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy
from .utils import (
sanitize_input_encode,
InvalidCSSSelectorError,
@@ -153,6 +154,7 @@ async def arun(
word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(),
data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(),
content_filter: RelevantContentFilter = None,
cache_mode: Optional[CacheMode] = None,
# Deprecated cache parameters
@@ -206,7 +208,7 @@ async def arun(
if crawler_config is not None:
if any(param is not None for param in [
word_count_threshold, extraction_strategy, chunking_strategy,
content_filter, cache_mode, css_selector, screenshot, pdf
data_persistence_strategy, content_filter, cache_mode, css_selector, screenshot, pdf
]):
self.logger.warning(
message="Both crawler_config and legacy parameters provided. crawler_config will take precedence.",
@@ -219,6 +221,7 @@ async def arun(
"word_count_threshold": word_count_threshold,
"extraction_strategy": extraction_strategy,
"chunking_strategy": chunking_strategy,
"data_persistence_strategy": data_persistence_strategy,
"content_filter": content_filter,
"cache_mode": cache_mode,
"bypass_cache": bypass_cache,
@@ -350,6 +353,9 @@ async def arun(
}
)

if config.data_persistence_strategy:
crawl_result.storage_metadata = config.data_persistence_strategy.save(crawl_result)

# Update cache if appropriate
if cache_context.should_write() and not bool(cached_result):
await async_db_manager.acache_url(crawl_result)
@@ -530,6 +536,7 @@ async def arun_many(
word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(),
data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(),
content_filter: RelevantContentFilter = None,
cache_mode: Optional[CacheMode] = None,
bypass_cache: bool = False,
@@ -683,4 +690,3 @@ async def aget_cache_size(self):
"""Get the total number of cached items."""
return await async_db_manager.aget_total_count()


152 changes: 152 additions & 0 deletions crawl4ai/data_persistence_strategy.py
@@ -0,0 +1,152 @@
from abc import ABC, abstractmethod
from .models import CrawlResult
import json
import re
from datasets import Dataset
from huggingface_hub import DatasetCard
from typing import Any


class DataPersistenceStrategy(ABC):
"""
Abstract base class for implementing data persistence strategies.
"""

@abstractmethod
def save(self, result: CrawlResult) -> dict[str, Any]:
"""
Save the given crawl result using a specific persistence strategy.
Args:
result (CrawlResult): The crawl result containing data to persist.
Returns:
dict[str, Any]: A dictionary representing the outcome details of the persistence operation.
"""
pass


class SkipDataPersistenceStrategy(DataPersistenceStrategy):
def save(self, result: CrawlResult) -> dict[str, Any]:
return None


DATASET_CARD_TEMPLATE = """
---
tags:
- crawl4ai
- crawl
---
**Source of the data:**
The dataset was generated with the [Crawl4ai](https://crawl4ai.com/mkdocs/) library from {url}.
"""


class HFDataPersistenceStrategy(DataPersistenceStrategy):
"""
A persistence strategy for uploading extracted content or markdown from crawl results to the Hugging Face Hub.
This strategy converts the extracted content or markdown into a Hugging Face Dataset
and uploads it to a specified repository on the Hub.
Args:
repo_id (str): The repository ID on the Hugging Face Hub.
private (bool): Whether the repository should be private.
card (str, optional): Custom dataset card content. Defaults to None, in which case a card is generated from DATASET_CARD_TEMPLATE.
token (str, optional): The authentication token for the Hugging Face Hub. Defaults to None, in which case the locally saved token is used.
**kwargs: Additional keyword arguments; currently only verbose (bool, defaults to False) is used.
"""

def __init__(
self, repo_id: str, private: bool, card: str = None, token=None, **kwargs
):
self.repo_id = repo_id
self.private = private
self.card = card
self.verbose = kwargs.get("verbose", False)
self.token = token

def save(self, result: CrawlResult) -> dict[str, Any]:
"""
Uploads extracted content or markdown from the given crawl result to the Hugging Face Hub.
Args:
result (CrawlResult): The crawl result containing extracted content or markdown to upload.
Returns:
dict[str, Any]: A dictionary with the repository ID and dataset split name.
Raises:
ValueError: If neither extracted content nor markdown is present in the result.
TypeError: If extracted content or markdown is not a string.
Notes:
- Extracted content should be a JSON string containing a list of dictionaries.
- If extracted content is invalid, raw markdown will be used as a fallback.
- The repository ID and dataset split name are returned upon successful upload.
"""
if not (result.extracted_content or result.markdown):
raise ValueError("No extracted content or markdown present.")

if result.extracted_content and not isinstance(result.extracted_content, str):
raise TypeError("Extracted content must be a string.")

if result.markdown and not isinstance(result.markdown, str):
raise TypeError("Markdown must be a string.")

records = self._prepare_records(result)

if self.verbose:
print(
f"[LOG] 🔄 Successfully converted extracted content to JSON records: {len(records)} records found"
)

ds = Dataset.from_list(records)
sanitized_split_name = re.sub(r"[^a-zA-Z0-9_]", "_", result.url)
commit_info = ds.push_to_hub(
repo_id=self.repo_id,
private=self.private,
token=self.token,
split=sanitized_split_name,
)

repo_id = commit_info.repo_url.repo_id
self._push_dataset_card(repo_id, result.url)

if self.verbose:
print(
f"[LOG] ✅ Data has been successfully pushed to the Hugging Face Hub. Repository ID: {repo_id}"
)

return {"repo_id": repo_id, "split": sanitized_split_name}

def _prepare_records(self, result: CrawlResult) -> list[dict[str, Any]]:
if result.extracted_content:
try:
records = json.loads(result.extracted_content)
if not isinstance(records, list) or not all(
isinstance(rec, dict) for rec in records
):
raise ValueError(
"Extracted content must be a JSON list of dictionaries."
)
except json.JSONDecodeError as e:
if self.verbose:
print(f"[LOG] ⚠️ Failed to parse extracted content as JSON: {e}")
records = [{"extracted_content": result.extracted_content}]
else:
records = [{"markdown": result.markdown}]

return records

def _push_dataset_card(self, repo_id: str, url: str) -> None:
card_content = self.card or DATASET_CARD_TEMPLATE.format(url=url)
DatasetCard(content=card_content).push_to_hub(
repo_id=repo_id, repo_type="dataset", token=self.token
)
if self.verbose:
print(f"[LOG] 🔄 Dataset card successfully pushed to repository: {repo_id}")
1 change: 1 addition & 0 deletions crawl4ai/models.py
@@ -34,6 +34,7 @@ class CrawlResult(BaseModel):
session_id: Optional[str] = None
response_headers: Optional[dict] = None
status_code: Optional[int] = None
storage_metadata: Optional[dict] = None

class AsyncCrawlResponse(BaseModel):
html: str
10 changes: 9 additions & 1 deletion crawl4ai/web_crawler.py
@@ -11,6 +11,7 @@
from typing import List
from concurrent.futures import ThreadPoolExecutor
from .content_scraping_strategy import WebScrapingStrategy
from .data_persistence_strategy import DataPersistenceStrategy, SkipDataPersistenceStrategy
from .config import *
import warnings
import json
@@ -109,6 +110,7 @@ def run(
word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(),
data_persistence_strategy: DataPersistenceStrategy = SkipDataPersistenceStrategy(),
bypass_cache: bool = False,
css_selector: str = None,
screenshot: bool = False,
@@ -123,7 +125,9 @@
raise ValueError("Unsupported extraction strategy")
if not isinstance(chunking_strategy, ChunkingStrategy):
raise ValueError("Unsupported chunking strategy")

if not isinstance(data_persistence_strategy, DataPersistenceStrategy):
raise ValueError("Unsupported data persistence strategy")

word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD)

cached = None
@@ -157,6 +161,10 @@

crawl_result = self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot_data, verbose, bool(cached), **kwargs)
crawl_result.success = bool(html)

if data_persistence_strategy:
crawl_result.storage_metadata = data_persistence_strategy.save(crawl_result)

return crawl_result
except Exception as e:
if not hasattr(e, "msg"):
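For the synchronous crawler, usage mirrors the async example further below. A rough sketch, assuming the class defined in crawl4ai/web_crawler.py is WebCrawler with its usual warmup() initialization step and a url keyword on run(); the repository ID is the same placeholder used in the async example.

from crawl4ai.web_crawler import WebCrawler  # assumed class name for the sync crawler
from crawl4ai.data_persistence_strategy import HFDataPersistenceStrategy

crawler = WebCrawler()
crawler.warmup()  # assumed: the sync crawler's usual browser/strategy warm-up step

result = crawler.run(
    url="https://huggingface.co/",
    data_persistence_strategy=HFDataPersistenceStrategy(
        repo_id="crawl4ai_hf_page_md",  # same placeholder repo as the async example
        private=False,
        verbose=True,
    ),
)
print(f"Persistence details: {result.storage_metadata}")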
22 changes: 22 additions & 0 deletions docs/examples/async_webcrawler_md_to_hf.py
@@ -0,0 +1,22 @@
import asyncio
from crawl4ai import AsyncWebCrawler
from crawl4ai.data_persistence_strategy import HFDataPersistenceStrategy


async def main():
async with AsyncWebCrawler(verbose=True) as crawler:
persistence_strategy = HFDataPersistenceStrategy(
repo_id="crawl4ai_hf_page_md", private=False, verbose=True
)

result = await crawler.arun(
url="https://huggingface.co/",
data_persistence_strategy=persistence_strategy,
)

print(f"Successfully crawled markdown: {result.markdown}")
print(f"Persistence details: {result.storage_metadata}")


# Run the async main function
asyncio.run(main())
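Since no token is supplied in this example, Dataset.push_to_hub and DatasetCard.push_to_hub fall back to the credentials saved locally by huggingface-cli login, so the script assumes an already-authenticated environment; alternatively, a token can be passed to HFDataPersistenceStrategy explicitly.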