Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hba records simplification #157

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 20 additions & 56 deletions pgtoolkit/hba.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
import re
import sys
import warnings
from collections.abc import Callable, Iterable, Iterator, Sequence
from collections.abc import Callable, Iterable, Iterator
from dataclasses import dataclass, field
from pathlib import Path
from typing import IO, Any
Expand All @@ -91,17 +91,13 @@ class HBARecord:
"""Holds a HBA record composed of fields and a comment.

Common fields are accessible through attribute : ``conntype``,
``databases``, ``users``, ``address``, ``netmask``, ``method``.
``database``, ``user``, ``address``, ``netmask``, ``method``.
Auth-options fields are also accessible through attribute like ``map``,
``ldapserver``, etc.

``address`` and ``netmask`` fields are not always defined. If not,
accessing undefined attributes trigger an :exc:`AttributeError`.

``databases`` and ``users`` have a single value variant respectively
:attr:`database` and :attr:`user`, computed after the list representation
of the field.

.. automethod:: parse
.. automethod:: __init__
.. automethod:: __str__
Expand All @@ -113,8 +109,8 @@ class HBARecord:

COMMON_FIELDS = [
"conntype",
"databases",
"users",
"database",
"user",
"address",
"netmask",
"method",
Expand All @@ -138,7 +134,7 @@ def parse(cls, line: str) -> HBARecord:

"""
line = line.strip()
record_fields = ["conntype", "databases", "users"]
record_fields = ["conntype", "database", "user"]

# What the regexp below does is finding all elements separated by spaces
# unless they are enclosed in double-quotes
Expand All @@ -147,9 +143,7 @@ def parse(cls, line: str) -> HBARecord:
# double-quotes (alternative 1)
# \S = any non-whitespace character (alternative 2)
values = [p for p in re.findall(r"(?:\"+.*?\"+|\S)+", line) if p.strip()]
# Split databases and users lists.
values[1] = values[1].split(",")
values[2] = values[2].split(",")
assert len(values) > 2
try:
hash_pos = values.index("#")
except ValueError:
Expand All @@ -171,31 +165,19 @@ def parse(cls, line: str) -> HBARecord:
# Remove extra outer double quotes for auth options values if any
auth_options = [(o[0], re.sub(r"^\"|\"$", "", o[1])) for o in auth_options]
options = base_options + auth_options
return cls(options, comment=comment)
return cls(**{k: v for k, v in options}, comment=comment)

conntype: str | None
databases: list[str]
users: list[str]
database: str
user: str

def __init__(
self,
values: Iterable[tuple[str, str]] | dict[str, Any] | None = None,
comment: str | None = None,
**kw_values: str | Sequence[str],
) -> None:
def __init__(self, **values: Any) -> None:
"""
:param values: A dict of fields.
:param kw_values: Fields passed as keyword.
:param comment: Comment at the end of the line.
:param values: Fields passed as keyword.
"""
dict_values: dict[str, Any] = dict(values or {}, **kw_values)
if "database" in dict_values:
dict_values["databases"] = [dict_values.pop("database")]
if "user" in dict_values:
dict_values["users"] = [dict_values.pop("user")]
self.__dict__.update(dict_values)
self.fields = [k for k, _ in dict_values.items()]
self.comment = comment
self.__dict__.update(values)
self.comment = values.pop("comment", None)
pgiraud marked this conversation as resolved.
Show resolved Hide resolved
self.fields = values.keys()
pgiraud marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self) -> str:
return "<{} {}{}>".format(
Expand Down Expand Up @@ -224,9 +206,7 @@ def __str__(self) -> str:
fmt += "%%(%s)-%ds " % (field_, width - 1)
else:
fmt += f"%({field_})s "
# Serialize database and user list using property.
values = dict(self.__dict__, databases=self.database, users=self.user)
line = fmt.rstrip() % values
line = fmt.rstrip() % self.__dict__

auth_options = ['%s="%s"' % i for i in self.auth_options]
if auth_options:
Expand All @@ -242,17 +222,13 @@ def __str__(self) -> str:
def __eq__(self, other: object) -> bool:
return str(self) == str(other)

def as_dict(self, serialized: bool = False) -> dict[str, Any]:
def as_dict(self) -> dict[str, Any]:
str_fields = self.COMMON_FIELDS[:]
if serialized:
str_fields[1:3] = ["database", "user"]
return {f: getattr(self, f) for f in str_fields if hasattr(self, f)}

@property
def common_values(self) -> list[str]:
str_fields = self.COMMON_FIELDS[:]
# Use serialized variant.
str_fields[1:3] = ["database", "user"]
return [getattr(self, f) for f in str_fields if f in self.fields]

@property
Expand All @@ -262,24 +238,12 @@ def auth_options(self) -> list[tuple[str, str]]:
]

@property
def database(self) -> str:
pgiraud marked this conversation as resolved.
Show resolved Hide resolved
"""Hold database column as a single value.

Use `databases` attribute to get parsed database list. `database` is
guaranteed to be a string.

"""
return ",".join(self.databases)
def databases(self) -> list[str]:
return self.database.split(",")

@property
def user(self) -> str:
"""Hold user column as a single value.

Use ``users`` property to get parsed user list. ``user`` is guaranteed
to be a string.

"""
return ",".join(self.users)
def users(self) -> list[str]:
return self.user.split(",")

def matches(self, **attrs: str) -> bool:
"""Tells if the current record is matching provided attributes.
Expand Down
40 changes: 14 additions & 26 deletions tests/test_hba.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def test_comment():
def test_parse_host_line():
from pgtoolkit.hba import HBARecord

record = HBARecord.parse("host replication all ::1/128 trust")
record = HBARecord.parse("host replication,mydb all ::1/128 trust")
assert "host" in repr(record)
assert "host" == record.conntype
assert "replication" == record.database
assert ["replication"] == record.databases
assert "replication,mydb" == record.database
assert ["replication", "mydb"] == record.databases
assert "all" == record.user
assert ["all"] == record.users
assert "::1/128" == record.address
Expand All @@ -57,6 +57,8 @@ def test_parse_local_line():
record = HBARecord.parse("local all all trust")
assert "local" == record.conntype
assert "all" == record.database
assert ["all"] == record.databases
assert "all" == record.user
assert ["all"] == record.users
assert "trust" == record.method

Expand All @@ -77,7 +79,7 @@ def test_parse_auth_option():
)
assert "local" == record.conntype
assert "veryverylongdatabasenamethatdonotfit" == record.database
assert ["all"] == record.users
assert "all" == record.user
assert "ident" == record.method
assert "omicron" == record.map

Expand All @@ -97,7 +99,7 @@ def test_parse_record_with_comment():
record = HBARecord.parse("local all all trust # My comment")
assert "local" == record.conntype
assert "all" == record.database
assert ["all"] == record.users
assert "all" == record.user
assert "trust" == record.method
assert "My comment" == record.comment

Expand Down Expand Up @@ -186,7 +188,7 @@ def test_hba_create():
assert 2 == len(hba.lines)

r = hba.lines[1]
assert ["all"] == r.databases
assert "all" == r.database


def test_parse_file(mocker, tmp_path):
Expand Down Expand Up @@ -345,8 +347,8 @@ def r(hba):
other_hba = HBA()
record = HBARecord(
conntype="host",
databases=["replication"],
users=["all"],
database="replication",
user="all",
address="1.2.3.4",
method="trust",
)
Expand All @@ -365,12 +367,6 @@ def test_as_dict():
method="trust",
)
assert r.as_dict() == {
"conntype": "local",
"databases": ["all"],
"users": ["all"],
"method": "trust",
}
assert r.as_dict(serialized=True) == {
"conntype": "local",
"database": "all",
"user": "all",
Expand All @@ -379,21 +375,13 @@ def test_as_dict():

r = HBARecord(
conntype="local",
databases=["mydb", "mydb2"],
users=["bob", "alice"],
database="mydb,mydb2",
user="bob,alice",
address="127.0.0.1",
netmask="255.255.255.255",
method="trust",
)
assert r.as_dict() == {
"address": "127.0.0.1",
"conntype": "local",
"databases": ["mydb", "mydb2"],
"users": ["bob", "alice"],
"method": "trust",
"netmask": "255.255.255.255",
}
assert r.as_dict(serialized=True) == {
"address": "127.0.0.1",
"conntype": "local",
"database": "mydb,mydb2",
Expand All @@ -417,8 +405,8 @@ def test_hbarecord_equality():

r = HBARecord(
conntype="host",
databases=["all"],
users=["u0", "u1"],
database="all",
user="u0,u1",
address="127.0.0.1/32",
method="trust",
)
Expand Down
Loading