Skip to content

Commit

Permalink
Support reading settings from pyproject.toml
Browse files Browse the repository at this point in the history
Also a lot more tests.
  • Loading branch information
tedivm committed Mar 25, 2024
1 parent 731fc7f commit 5b3afcd
Show file tree
Hide file tree
Showing 14 changed files with 266 additions and 62 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache/
cover/

# Translations
Expand Down
41 changes: 33 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,24 @@

Paracelsus generates Entity Relationship Diagrams by reading your SQLAlchemy models.

* ERDs can be injected into documentation as [Mermaid Diagrams](https://mermaid.js.org/).
* Paracelsus can be run in CICD to check that databases are up to date.
* ERDs can be created as files in either [Dot](https://graphviz.org/doc/info/lang.html) or Mermaid format.
* DOT files can be used to generate SVG or PNG files, or edited in [GraphViz](https://graphviz.org/) or other editors.

- [Paracelsus](#paracelsus)
- [Features](#features)
- [Usage](#usage)
- [Installation](#installation)
- [Basic CLI Usage](#basic-cli-usage)
- [Importing Models](#importing-models)
- [Generate Mermaid Diagrams](#generate-mermaid-diagrams)
- [Inject Mermaid Diagrams](#inject-mermaid-diagrams)
- [Creating Images](#creating-images)
- [pyproject.toml](#pyprojecttoml)
- [Sponsorship](#sponsorship)

## Features

- ERDs can be injected into documentation as [Mermaid Diagrams](https://mermaid.js.org/).
- Paracelsus can be run in CICD to check that databases are up to date.
- ERDs can be created as files in either [Dot](https://graphviz.org/doc/info/lang.html) or Mermaid format.
- DOT files can be used to generate SVG or PNG files, or edited in [GraphViz](https://graphviz.org/) or other editors.

## Usage

Expand All @@ -29,9 +42,9 @@ paracelsus --help

It has three commands:

* `version` outputs the version of the currently installed `paracelsus` cli.
* `graph` generates a graph and outputs it to `stdout`.
* `inject` inserts the graph into a markdown file.
- `version` outputs the version of the currently installed `paracelsus` cli.
- `graph` generates a graph and outputs it to `stdout`.
- `inject` inserts the graph into a markdown file.

### Importing Models

Expand Down Expand Up @@ -161,6 +174,18 @@ To create a PNG file:
![Alt text](./docs/example.png "a title")


### pyproject.toml

The settings for your project can be saved directly in the `pyprojects.toml` file of your project.

```toml
[tool.paracelsus]
base = "example.base:Base"
imports = [
"example.models"
]
```

## Sponsorship

This project is developed by [Robert Hafner](https://blog.tedivm.com) If you find this project useful please consider sponsoring me using Github!
Expand Down
70 changes: 20 additions & 50 deletions paracelsus/cli.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,17 @@
import importlib
import os
import re
import sys
from enum import Enum
from pathlib import Path
from typing import List
from typing import Any, Dict, List, Optional

import typer
from typing_extensions import Annotated

from .transformers.dot import Dot
from .transformers.mermaid import Mermaid
from .graph import get_graph_string, transformers
from .pyproject import get_pyproject_settings

app = typer.Typer()

transformers = {
"mmd": Mermaid,
"mermaid": Mermaid,
"dot": Dot,
"gv": Dot,
}


class Formats(str, Enum):
mermaid = "mermaid"
Expand All @@ -29,49 +20,22 @@ class Formats(str, Enum):
gv = "gv"


def get_graph_string(
base_class_path: str,
import_module: List[str],
python_dir: List[Path],
format: str,
) -> str:
# Update the PYTHON_PATH to allow more module imports.
sys.path.append(str(os.getcwd()))
for dir in python_dir:
sys.path.append(str(dir))

# Import the base class so the metadata class can be extracted from it.
# The metadata class is passed to the transformer.
module_path, class_name = base_class_path.split(":", 2)
base_module = importlib.import_module(module_path)
base_class = getattr(base_module, class_name)
metadata = base_class.metadata

# The modules holding the model classes have to be imported to get put in the metaclass model registry.
# These modules aren't actually used in any way, so they are discarded.
# They are also imported in scope of this function to prevent namespace pollution.
for module in import_module:
if ":*" in module:
# Sure, execs are gross, but this is the only way to dynamically import wildcards.
exec(f"from {module[:-2]} import *")
else:
importlib.import_module(module)

# Grab a transformer.
if format not in transformers:
raise ValueError(f"Unknown Format: {format}")
transformer = transformers[format]

# Save the graph structure to string.
return str(transformer(metadata))
def get_base_class(base_class_path: str | None, settings=Dict[str, Any] | None) -> str:
if base_class_path:
return base_class_path
if not settings:
raise ValueError("`base_class_path` argument must be passed if no pyproject.toml file is present.")
if "base" not in settings:
raise ValueError("`base_class_path` argument must be passed if not defined in pyproject.toml.")
return settings["base"]


@app.command(help="Create the graph structure and print it to stdout.")
def graph(
base_class_path: Annotated[
str,
Optional[str],
typer.Argument(help="The SQLAlchemy base class used by the database to graph."),
],
] = None,
import_module: Annotated[
List[str],
typer.Option(
Expand All @@ -92,9 +56,15 @@ def graph(
Formats, typer.Option(help="The file format to output the generated graph to.")
] = Formats.mermaid,
):
settings = get_pyproject_settings()
base_class = get_base_class(base_class_path, settings)

if settings and "imports" in settings:
import_module.extend(settings["imports"])

typer.echo(
get_graph_string(
base_class_path=base_class_path,
base_class_path=base_class,
import_module=import_module,
python_dir=python_dir,
format=format.value,
Expand Down
52 changes: 52 additions & 0 deletions paracelsus/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import importlib
import os
import sys
from pathlib import Path
from typing import List

from .transformers.dot import Dot
from .transformers.mermaid import Mermaid

transformers = {
"mmd": Mermaid,
"mermaid": Mermaid,
"dot": Dot,
"gv": Dot,
}


def get_graph_string(
base_class_path: str,
import_module: List[str],
python_dir: List[Path],
format: str,
) -> str:
# Update the PYTHON_PATH to allow more module imports.
sys.path.append(str(os.getcwd()))
for dir in python_dir:
sys.path.append(str(dir))

# Import the base class so the metadata class can be extracted from it.
# The metadata class is passed to the transformer.
module_path, class_name = base_class_path.split(":", 2)
base_module = importlib.import_module(module_path)
base_class = getattr(base_module, class_name)
metadata = base_class.metadata

# The modules holding the model classes have to be imported to get put in the metaclass model registry.
# These modules aren't actually used in any way, so they are discarded.
# They are also imported in scope of this function to prevent namespace pollution.
for module in import_module:
if ":*" in module:
# Sure, execs are gross, but this is the only way to dynamically import wildcards.
exec(f"from {module[:-2]} import *")
else:
importlib.import_module(module)

# Grab a transformer.
if format not in transformers:
raise ValueError(f"Unknown Format: {format}")
transformer = transformers[format]

# Save the graph structure to string.
return str(transformer(metadata))
16 changes: 16 additions & 0 deletions paracelsus/pyproject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
import tomllib
from pathlib import Path
from typing import Any, Dict


def get_pyproject_settings(dir: Path = Path(os.getcwd())) -> Dict[str, Any] | None:
pyproject = dir / "pyproject.toml"

if not pyproject.exists():
return None

with open(pyproject, "rb") as f:
data = tomllib.load(f)

return data.get("tool", {}).get("paracelsus", None)
8 changes: 8 additions & 0 deletions tests/assets/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Test Directory

Please ignore.

## Schema

<!-- BEGIN_SQLALCHEMY_DOCS -->
<!-- END_SQLALCHEMY_DOCS -->
Empty file.
3 changes: 3 additions & 0 deletions tests/assets/example/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from sqlalchemy.orm import declarative_base

Base = declarative_base()
37 changes: 37 additions & 0 deletions tests/assets/example/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from datetime import UTC, datetime
from uuid import uuid4

from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, Uuid
from sqlalchemy.orm import mapped_column

from .base import Base


class User(Base):
__tablename__ = "users"

id = mapped_column(Uuid, primary_key=True, default=uuid4())
display_name = mapped_column(String(100))
created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC))


class Post(Base):
__tablename__ = "posts"

id = mapped_column(Uuid, primary_key=True, default=uuid4())
author = mapped_column(ForeignKey(User.id), nullable=False)
created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC))
live = mapped_column(Boolean, default=False)
content = mapped_column(Text, default="")


class Comment(Base):
__tablename__ = "comments"

id = mapped_column(Uuid, primary_key=True, default=uuid4())
post = mapped_column(Uuid, ForeignKey(Post.id), default=uuid4())
author = mapped_column(ForeignKey(User.id), nullable=False)
created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC))
live = mapped_column(Boolean, default=False)
content = mapped_column(Text, default="")
content = mapped_column(Text, default="")
5 changes: 5 additions & 0 deletions tests/assets/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[tool.paracelsus]
base = "example.base:Base"
imports = [
"example.models"
]
15 changes: 11 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import datetime
import os
from datetime import UTC, datetime
from pathlib import Path
from uuid import uuid4

import pytest
Expand All @@ -15,14 +17,14 @@ class User(Base):

id = mapped_column(Uuid, primary_key=True, default=uuid4())
display_name = mapped_column(String(100))
created = mapped_column(DateTime, nullable=False, default=datetime.utcnow())
created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC))

class Post(Base):
__tablename__ = "posts"

id = mapped_column(Uuid, primary_key=True, default=uuid4())
author = mapped_column(ForeignKey(User.id), nullable=False)
created = mapped_column(DateTime, nullable=False, default=datetime.utcnow())
created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC))
live = mapped_column(Boolean, default=False, comment="True if post is published")
content = mapped_column(Text, default="")

Expand All @@ -32,8 +34,13 @@ class Comment(Base):
id = mapped_column(Uuid, primary_key=True, default=uuid4())
post = mapped_column(Uuid, ForeignKey(Post.id), default=uuid4())
author = mapped_column(ForeignKey(User.id), nullable=False)
created = mapped_column(DateTime, nullable=False, default=datetime.utcnow())
created = mapped_column(DateTime, nullable=False, default=datetime.now(UTC))
live = mapped_column(Boolean, default=False)
content = mapped_column(Text, default="")

return Base.metadata


@pytest.fixture
def package_path():
return Path(os.path.dirname(os.path.realpath(__file__))) / "assets"
49 changes: 49 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typer.testing import CliRunner

from paracelsus.cli import app

runner = CliRunner()


def test_graph(package_path):
result = runner.invoke(
app, ["graph", "example.base:Base", "--import-module", "example.models", "--python-dir", str(package_path)]
)

assert result.exit_code == 0

assert "users {" in result.stdout
assert "posts {" in result.stdout
assert "comments {" in result.stdout

assert "users ||--o{ posts : author" in result.stdout
assert "posts ||--o{ comments : post" in result.stdout
assert "users ||--o{ comments : author" in result.stdout

assert "CHAR(32) author FK" in result.stdout
assert 'CHAR(32) post FK "nullable"' in result.stdout
assert "DATETIME created" in result.stdout


def test_inject(package_path):
result = runner.invoke(
app,
[
"inject",
str(package_path / "README.md"),
"example.base:Base",
"--import-module",
"example.models",
"--python-dir",
str(package_path),
"--check",
],
)

assert result.exit_code == 1


def test_version():
result = runner.invoke(app, ["version"])

assert result.exit_code == 0
Loading

0 comments on commit 5b3afcd

Please sign in to comment.