diff --git a/pgtoolkit/hba.py b/pgtoolkit/hba.py index 85a69d0..42d7734 100644 --- a/pgtoolkit/hba.py +++ b/pgtoolkit/hba.py @@ -73,7 +73,7 @@ import re import sys import warnings -from collections.abc import Callable, Iterable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator from dataclasses import dataclass, field from pathlib import Path from typing import IO, Any @@ -91,17 +91,13 @@ class HBARecord: """Holds a HBA record composed of fields and a comment. Common fields are accessible through attributeĀ : ``conntype``, - ``databases``, ``users``, ``address``, ``netmask``, ``method``. + ``database``, ``user``, ``address``, ``netmask``, ``method``. Auth-options fields are also accessible through attribute like ``map``, ``ldapserver``, etc. ``address`` and ``netmask`` fields are not always defined. If not, accessing undefined attributes trigger an :exc:`AttributeError`. - ``databases`` and ``users`` have a single value variant respectively - :attr:`database` and :attr:`user`, computed after the list representation - of the field. - .. automethod:: parse .. automethod:: __init__ .. automethod:: __str__ @@ -113,8 +109,8 @@ class HBARecord: COMMON_FIELDS = [ "conntype", - "databases", - "users", + "database", + "user", "address", "netmask", "method", @@ -138,7 +134,7 @@ def parse(cls, line: str) -> HBARecord: """ line = line.strip() - record_fields = ["conntype", "databases", "users"] + record_fields = ["conntype", "database", "user"] # What the regexp below does is finding all elements separated by spaces # unless they are enclosed in double-quotes @@ -147,9 +143,7 @@ def parse(cls, line: str) -> HBARecord: # double-quotes (alternative 1) # \S = any non-whitespace character (alternative 2) values = [p for p in re.findall(r"(?:\"+.*?\"+|\S)+", line) if p.strip()] - # Split databases and users lists. - values[1] = values[1].split(",") - values[2] = values[2].split(",") + assert len(values) > 2 try: hash_pos = values.index("#") except ValueError: @@ -171,31 +165,19 @@ def parse(cls, line: str) -> HBARecord: # Remove extra outer double quotes for auth options values if any auth_options = [(o[0], re.sub(r"^\"|\"$", "", o[1])) for o in auth_options] options = base_options + auth_options - return cls(options, comment=comment) + return cls(**{k: v for k, v in options}, comment=comment) conntype: str | None - databases: list[str] - users: list[str] + database: str + user: str - def __init__( - self, - values: Iterable[tuple[str, str]] | dict[str, Any] | None = None, - comment: str | None = None, - **kw_values: str | Sequence[str], - ) -> None: + def __init__(self, **values: Any) -> None: """ - :param values: A dict of fields. - :param kw_values: Fields passed as keyword. - :param comment: Comment at the end of the line. + :param values: Fields passed as keyword. """ - dict_values: dict[str, Any] = dict(values or {}, **kw_values) - if "database" in dict_values: - dict_values["databases"] = [dict_values.pop("database")] - if "user" in dict_values: - dict_values["users"] = [dict_values.pop("user")] - self.__dict__.update(dict_values) - self.fields = [k for k, _ in dict_values.items()] - self.comment = comment + self.__dict__.update(values) + self.comment = values.pop("comment", None) + self.fields = values.keys() def __repr__(self) -> str: return "<{} {}{}>".format( @@ -224,9 +206,7 @@ def __str__(self) -> str: fmt += "%%(%s)-%ds " % (field_, width - 1) else: fmt += f"%({field_})s " - # Serialize database and user list using property. - values = dict(self.__dict__, databases=self.database, users=self.user) - line = fmt.rstrip() % values + line = fmt.rstrip() % self.__dict__ auth_options = ['%s="%s"' % i for i in self.auth_options] if auth_options: @@ -242,17 +222,13 @@ def __str__(self) -> str: def __eq__(self, other: object) -> bool: return str(self) == str(other) - def as_dict(self, serialized: bool = False) -> dict[str, Any]: + def as_dict(self) -> dict[str, Any]: str_fields = self.COMMON_FIELDS[:] - if serialized: - str_fields[1:3] = ["database", "user"] return {f: getattr(self, f) for f in str_fields if hasattr(self, f)} @property def common_values(self) -> list[str]: str_fields = self.COMMON_FIELDS[:] - # Use serialized variant. - str_fields[1:3] = ["database", "user"] return [getattr(self, f) for f in str_fields if f in self.fields] @property @@ -262,24 +238,12 @@ def auth_options(self) -> list[tuple[str, str]]: ] @property - def database(self) -> str: - """Hold database column as a single value. - - Use `databases` attribute to get parsed database list. `database` is - guaranteed to be a string. - - """ - return ",".join(self.databases) + def databases(self) -> list[str]: + return self.database.split(",") @property - def user(self) -> str: - """Hold user column as a single value. - - Use ``users`` property to get parsed user list. ``user`` is guaranteed - to be a string. - - """ - return ",".join(self.users) + def users(self) -> list[str]: + return self.user.split(",") def matches(self, **attrs: str) -> bool: """Tells if the current record is matching provided attributes. diff --git a/tests/test_hba.py b/tests/test_hba.py index 1f74fb0..83ac204 100644 --- a/tests/test_hba.py +++ b/tests/test_hba.py @@ -36,11 +36,11 @@ def test_comment(): def test_parse_host_line(): from pgtoolkit.hba import HBARecord - record = HBARecord.parse("host replication all ::1/128 trust") + record = HBARecord.parse("host replication,mydb all ::1/128 trust") assert "host" in repr(record) assert "host" == record.conntype - assert "replication" == record.database - assert ["replication"] == record.databases + assert "replication,mydb" == record.database + assert ["replication", "mydb"] == record.databases assert "all" == record.user assert ["all"] == record.users assert "::1/128" == record.address @@ -57,6 +57,8 @@ def test_parse_local_line(): record = HBARecord.parse("local all all trust") assert "local" == record.conntype assert "all" == record.database + assert ["all"] == record.databases + assert "all" == record.user assert ["all"] == record.users assert "trust" == record.method @@ -77,7 +79,7 @@ def test_parse_auth_option(): ) assert "local" == record.conntype assert "veryverylongdatabasenamethatdonotfit" == record.database - assert ["all"] == record.users + assert "all" == record.user assert "ident" == record.method assert "omicron" == record.map @@ -97,7 +99,7 @@ def test_parse_record_with_comment(): record = HBARecord.parse("local all all trust # My comment") assert "local" == record.conntype assert "all" == record.database - assert ["all"] == record.users + assert "all" == record.user assert "trust" == record.method assert "My comment" == record.comment @@ -186,7 +188,7 @@ def test_hba_create(): assert 2 == len(hba.lines) r = hba.lines[1] - assert ["all"] == r.databases + assert "all" == r.database def test_parse_file(mocker, tmp_path): @@ -345,8 +347,8 @@ def r(hba): other_hba = HBA() record = HBARecord( conntype="host", - databases=["replication"], - users=["all"], + database="replication", + user="all", address="1.2.3.4", method="trust", ) @@ -365,12 +367,6 @@ def test_as_dict(): method="trust", ) assert r.as_dict() == { - "conntype": "local", - "databases": ["all"], - "users": ["all"], - "method": "trust", - } - assert r.as_dict(serialized=True) == { "conntype": "local", "database": "all", "user": "all", @@ -379,21 +375,13 @@ def test_as_dict(): r = HBARecord( conntype="local", - databases=["mydb", "mydb2"], - users=["bob", "alice"], + database="mydb,mydb2", + user="bob,alice", address="127.0.0.1", netmask="255.255.255.255", method="trust", ) assert r.as_dict() == { - "address": "127.0.0.1", - "conntype": "local", - "databases": ["mydb", "mydb2"], - "users": ["bob", "alice"], - "method": "trust", - "netmask": "255.255.255.255", - } - assert r.as_dict(serialized=True) == { "address": "127.0.0.1", "conntype": "local", "database": "mydb,mydb2", @@ -417,8 +405,8 @@ def test_hbarecord_equality(): r = HBARecord( conntype="host", - databases=["all"], - users=["u0", "u1"], + database="all", + user="u0,u1", address="127.0.0.1/32", method="trust", )