Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hba improvements #123

Merged
merged 5 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions pgtoolkit/hba.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,15 @@ def __str__(self) -> str:

return line

def __eq__(self, other: object) -> bool:
return str(self) == str(other)

pgiraud marked this conversation as resolved.
Show resolved Hide resolved
def as_dict(self, serialized: bool = False) -> dict[str, Any]:
str_fields = self.COMMON_FIELDS[:]
if serialized:
str_fields[1:3] = ["database", "user"]
return {f: getattr(self, f) for f in str_fields if hasattr(self, f)}

@property
def common_values(self) -> list[str]:
str_fields = self.COMMON_FIELDS[:]
Expand Down Expand Up @@ -307,7 +316,7 @@ class HBA:
"""

lines: list[Union[HBAComment, HBARecord]]
path: Optional[str]
path: Optional[Union[str, Path]]

def __init__(
self, entries: Optional[Iterable[Union[HBAComment, HBARecord]]] = None
Expand Down Expand Up @@ -365,7 +374,7 @@ def remove(
self,
filter: Optional[Callable[[HBARecord], bool]] = None,
**attrs: str,
) -> None:
) -> bool:
"""Remove records matching the provided attributes.

One can for example remove all records for which user is 'david'.
Expand All @@ -376,6 +385,8 @@ def remove(
:param attrs: keyword/values pairs correspond to one or more
HBARecord attributes (ie. user, conntype, etc...)

:returns: ``True`` if records have changed.

Usage examples:

.. code:: python
Expand All @@ -393,20 +404,26 @@ def remove(

filter = filter or (lambda line: line.matches(**attrs))

lines_before = self.lines

self.lines = [
line
for line in self.lines
if not (isinstance(line, HBARecord) and filter(line))
]

def merge(self, other: "HBA") -> None:
return lines_before != self.lines

def merge(self, other: "HBA") -> bool:
"""Add new records to HBAFile or replace them if they are matching
(ie. same conntype, database, user and address)

:param other: HBAFile to merge into the current one.
Lines with matching conntype, database, user and database will be
replaced by the new one. Otherwise they will be added at the end.
Comments from the original hba are preserved.

:returns: ``True`` if records have changed.
"""
lines = self.lines[:]
new_lines = other.lines[:]
Expand Down Expand Up @@ -435,15 +452,17 @@ def merge(self, other: "HBA") -> None:
# Then add remaining new lines (not merged)
self.lines.extend(new_lines)

return lines != self.lines


def parse(file: Union[str, Iterable[str]]) -> HBA:
def parse(file: Union[str, Iterable[str], Path]) -> HBA:
"""Parse a `pg_hba.conf` file.

:param file: Either a line iterator such as a file-like object or a string
:param file: Either a line iterator such as a file-like object, a path or a string
corresponding to the path to the file to open and parse.
:rtype: :class:`HBA`.
"""
if isinstance(file, str):
if isinstance(file, (str, Path)):
with open(file) as fo:
hba = parse(fo)
hba.path = file
Expand Down
118 changes: 110 additions & 8 deletions tests/test_hba.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_hba_create():
HBA("blah")


def test_parse_file(mocker):
def test_parse_file(mocker, tmp_path):
from pgtoolkit.hba import HBAComment, parse

m = mocker.mock_open()
Expand All @@ -201,6 +201,13 @@ def test_parse_file(mocker):
pgpass.lines.append(HBAComment("# Something"))
assert m.called

# Also works with path
m.reset_mock()
pgpass = parse(tmp_path / "filename")
pgpass.lines.append(HBAComment("# Something"))

assert m.called


def test_hba_error(mocker):
from pgtoolkit.hba import ParseError, parse
Expand All @@ -224,32 +231,40 @@ def test_remove():
with pytest.raises(ValueError):
hba.remove()

hba.remove(database="replication")
result = hba.remove(database="badname")
assert not result

result = hba.remove(database="replication")
assert result
entries = list(iter(hba))
assert 4 == len(entries)

hba = parse(lines)
hba.remove(filter=lambda r: r.database == "replication")
result = hba.remove(filter=lambda r: r.database == "replication")
assert result
entries = list(iter(hba))
assert 4 == len(entries)

hba = parse(lines)
hba.remove(conntype="host", database="replication")
result = hba.remove(conntype="host", database="replication")
assert result
entries = list(iter(hba))
assert 5 == len(entries)

# Works even for fields that may not be valid for all records
# `address` is not valid for `local` connection type
hba = parse(lines)
hba.remove(address="127.0.0.1/32")
result = hba.remove(address="127.0.0.1/32")
assert result
entries = list(iter(hba))
assert 6 == len(entries)

def filter(r):
return r.conntype == "host" and r.database == "replication"

hba = parse(lines)
hba.remove(filter=filter)
result = hba.remove(filter=filter)
assert result
entries = list(iter(hba))
assert 5 == len(entries)

Expand All @@ -269,7 +284,7 @@ def filter(r):
def test_merge():
import os

from pgtoolkit.hba import parse
from pgtoolkit.hba import HBA, HBARecord, parse

sample = """\
# comment
Expand All @@ -292,7 +307,8 @@ def test_merge():
"""
other_lines = other_sample.splitlines(True)
other_hba = parse(other_lines)
hba.merge(other_hba)
result = hba.merge(other_hba)
assert result

expected_sample = """\
# comment
Expand All @@ -315,3 +331,89 @@ def r(hba):
return os.linesep.join([str(line) for line in hba.lines])

assert r(hba) == r(expected_hba)

other_hba = HBA()
record = HBARecord(
conntype="host",
databases=["replication"],
users=["all"],
address="1.2.3.4",
method="trust",
)
other_hba.lines.append(record)
result = hba.merge(other_hba)
assert not result


def test_as_dict():
from pgtoolkit.hba import HBARecord

r = HBARecord(
conntype="local",
database="all",
user="all",
method="trust",
)
assert r.as_dict() == {
"conntype": "local",
"databases": ["all"],
"users": ["all"],
"method": "trust",
}
assert r.as_dict(serialized=True) == {
"conntype": "local",
"database": "all",
"user": "all",
"method": "trust",
}

r = HBARecord(
conntype="local",
databases=["mydb", "mydb2"],
users=["bob", "alice"],
address="127.0.0.1",
netmask="255.255.255.255",
method="trust",
)
assert r.as_dict() == {
"address": "127.0.0.1",
"conntype": "local",
"databases": ["mydb", "mydb2"],
"users": ["bob", "alice"],
"method": "trust",
"netmask": "255.255.255.255",
}
assert r.as_dict(serialized=True) == {
"address": "127.0.0.1",
"conntype": "local",
"database": "mydb,mydb2",
"user": "bob,alice",
"method": "trust",
"netmask": "255.255.255.255",
}


def test_hbarecord_equality():
from pgtoolkit.hba import HBARecord

r = HBARecord(
conntype="local",
database="all",
user="all",
method="trust",
)
r2 = HBARecord.parse("local all all trust")
assert r == r2

r = HBARecord(
conntype="host",
databases=["all"],
users=["u0", "u1"],
address="127.0.0.1/32",
method="trust",
)
r2 = HBARecord.parse("host all u0,u1 127.0.0.1/32 trust")
assert r == r2

r2 = HBARecord.parse("host mydb u0,u1 127.0.0.1/32 trust")
assert r != r2
Loading