Skip to content

Commit

Permalink
Tell whether changes occurred in HBA::merge
Browse files Browse the repository at this point in the history
  • Loading branch information
pgiraud committed Sep 17, 2024
1 parent 6d232d7 commit c0721d4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
9 changes: 8 additions & 1 deletion pgtoolkit/hba.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ def __str__(self) -> str:

return line

def __eq__(self, other: object) -> bool:
return str(self) == str(other)

def as_dict(self, serialized=False) -> dict[str, Any]:
str_fields = self.COMMON_FIELDS[:]
if serialized:
Expand Down Expand Up @@ -411,14 +414,16 @@ def remove(

return lines_before != self.lines

def merge(self, other: "HBA") -> None:
def merge(self, other: "HBA") -> bool:
"""Add new records to HBAFile or replace them if they are matching
(ie. same conntype, database, user and address)
:param other: HBAFile to merge into the current one.
Lines with matching conntype, database, user and database will be
replaced by the new one. Otherwise they will be added at the end.
Comments from the original hba are preserved.
:returns ``True`` if records have changed.
"""
lines = self.lines[:]
new_lines = other.lines[:]
Expand Down Expand Up @@ -447,6 +452,8 @@ def merge(self, other: "HBA") -> None:
# Then add remaining new lines (not merged)
self.lines.extend(new_lines)

return lines != self.lines


def parse(file: Union[str, Iterable[str]]) -> HBA:
"""Parse a `pg_hba.conf` file.
Expand Down
17 changes: 15 additions & 2 deletions tests/test_hba.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def filter(r):
def test_merge():
import os

from pgtoolkit.hba import parse
from pgtoolkit.hba import HBA, HBARecord, parse

sample = """\
# comment
Expand All @@ -300,7 +300,8 @@ def test_merge():
"""
other_lines = other_sample.splitlines(True)
other_hba = parse(other_lines)
hba.merge(other_hba)
result = hba.merge(other_hba)
assert result

expected_sample = """\
# comment
Expand All @@ -324,6 +325,18 @@ def r(hba):

assert r(hba) == r(expected_hba)

other_hba = HBA()
record = HBARecord(
conntype="host",
databases=["replication"],
users=["all"],
address="1.2.3.4",
method="trust",
)
other_hba.lines.append(record)
result = hba.merge(other_hba)
assert not result


def test_as_dict():
from pgtoolkit.hba import HBARecord
Expand Down

0 comments on commit c0721d4

Please sign in to comment.