diff --git a/patroni/api.py b/patroni/api.py index e8bff229c..b44630ab4 100644 --- a/patroni/api.py +++ b/patroni/api.py @@ -671,13 +671,9 @@ def do_GET_metrics(self) -> None: metrics.append("# TYPE patroni_is_paused gauge") metrics.append("patroni_is_paused{0} {1}".format(labels, int(postgres.get('pause', 0)))) - if patroni.multisite.is_active: - metrics.append("# HELP patroni_multisite_switches Number of times multisite leader has been switched") - metrics.append("# TYPE patroni_multisite_switches counter") - metrics.append("patroni_multisite_switches{0} {1}" - .format(labels, patroni.multisite.site_switches)) + patroni.multisite.append_metrics(metrics, labels) - self.write_response(200, '\n'.join(metrics)+'\n', content_type='text/plain') + self.write_response(200, '\n'.join(metrics) + '\n', content_type='text/plain') def do_GET_multisite(self): self._write_json_response(200, self.server.patroni.multisite.status()) @@ -1199,7 +1195,7 @@ def do_POST_multisite_switchover(self): if not request: return if not self.server.patroni.multisite.is_active: - return self._write_response(400, 'Cluster is not in multisite mode') + return self.write_response(400, 'Cluster is not in multisite mode') scheduled_at = request.get('scheduled_at') target_site = request.get('target_site') diff --git a/patroni/ctl.py b/patroni/ctl.py index f0ae72446..51da91c09 100644 --- a/patroni/ctl.py +++ b/patroni/ctl.py @@ -1502,8 +1502,6 @@ def get_cluster_service_info(cluster: Dict[str, Any]) -> List[str]: """ service_info: List[str] = [] - - if 'multisite' in cluster: info = f"Multisite {cluster['multisite']['name'] or ''} is {cluster['multisite']['status'].lower()}" standby_config = cluster['multisite'].get('standby_config', {}) diff --git a/patroni/dcs/__init__.py b/patroni/dcs/__init__.py index b8f9bcd88..ea0050105 100644 --- a/patroni/dcs/__init__.py +++ b/patroni/dcs/__init__.py @@ -464,7 +464,7 @@ class Failover(NamedTuple): leader: Optional[str] candidate: Optional[str] scheduled_at: Optional[datetime.datetime] - target_site: Optional[str] + target_site: Optional[str] = None @staticmethod def from_node(version: _Version, value: Union[str, Dict[str, str]]) -> 'Failover': @@ -496,7 +496,8 @@ def from_node(version: _Version, value: Union[str, Dict[str, str]]) -> 'Failover if data.get('scheduled_at'): data['scheduled_at'] = dateutil.parser.parse(data['scheduled_at']) - return Failover(version, data.get('leader'), data.get('member'), data.get('scheduled_at'), data.get('target_site')) + return Failover(version, data.get('leader'), data.get('member'), data.get('scheduled_at'), + data.get('target_site')) def __len__(self) -> int: """Implement ``len`` function capability. 
diff --git a/patroni/dcs/etcd3.py b/patroni/dcs/etcd3.py index 406af49f2..5fb48bdc2 100644 --- a/patroni/dcs/etcd3.py +++ b/patroni/dcs/etcd3.py @@ -754,7 +754,8 @@ def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster: # get leader leader = nodes.get(self._LEADER) - if not self._ctl and not self._multisite and leader and leader['value'] == self._name and self._lease != leader.get('lease'): + if not self._ctl and not self._multisite and leader and leader['value'] == self._name \ + and self._lease != leader.get('lease'): logger.warning('I am the leader but not owner of the lease') if leader: diff --git a/patroni/ha.py b/patroni/ha.py index 3b212572f..21763bcee 100644 --- a/patroni/ha.py +++ b/patroni/ha.py @@ -360,8 +360,8 @@ def acquire_lock(self) -> bool: self.set_is_leader(ret) multisite_ret = self.patroni.multisite.resolve_leader() if multisite_ret: - logger.error("Releasing leader lock because multi site status is: "+multisite_ret) - self.dcs.delete_leader() + logger.error("Releasing leader lock because multi site status is: %s", multisite_ret) + self.dcs.delete_leader(None, None) return False return ret @@ -1555,7 +1555,7 @@ def demote(self, mode: str) -> Optional[bool]: 'graceful': dict(stop='fast', checkpoint=True, release=True, offline=False, async_req=False), # noqa: E241,E501 'immediate': dict(stop='immediate', checkpoint=False, release=True, offline=False, async_req=True), # noqa: E241,E501 'immediate-nolock': dict(stop='immediate', checkpoint=False, release=False, offline=False, async_req=True), # noqa: E241,E501 - 'multisite': dict(stop='fast', checkpoint=True, release=False, offline=True, async_req=False), # noqa: E241,E501 + 'multisite': dict(stop='fast', checkpoint=True, release=False, offline=True, async_req=False), # noqa: E241,E501 }[mode] logger.info('Demoting self (%s)', mode) @@ -1577,7 +1577,7 @@ def on_shutdown(checkpoint_location: int, prev_location: int) -> None: status['released'] = True if mode == 'multisite': - on_shutdown = self.patroni.multisite.on_shutdown + on_shutdown = self.patroni.multisite.on_shutdown # noqa: F811 def before_shutdown() -> None: if self.state_handler.mpp_handler.is_coordinator(): @@ -1599,6 +1599,7 @@ def before_shutdown() -> None: with self._async_executor: self.release_leader_key_voluntarily(checkpoint_location) time.sleep(2) # Give a time to somebody to take the leader lock + # FIXME: multisite.on_shutdown() was already called above with _state_handler.stop(), do we really need it here? 
if mode == 'multisite': self.patroni.multisite.on_shutdown(self.state_handler.latest_checkpoint_location()) if mode_control['offline']: @@ -1712,7 +1713,8 @@ def process_unhealthy_cluster(self) -> str: if failover: if self.is_paused() and failover.leader and failover.candidate: logger.info('Updating failover key after acquiring leader lock...') - self.dcs.manual_failover('', failover.candidate, failover.scheduled_at, version=failover.version) + self.dcs.manual_failover('', failover.candidate, failover.scheduled_at, + version=failover.version) else: logger.info('Cleaning up failover key after acquiring leader lock...') self.dcs.manual_failover('', '') diff --git a/patroni/multisite.py b/patroni/multisite.py index 523a47dd3..3d64b128c 100644 --- a/patroni/multisite.py +++ b/patroni/multisite.py @@ -1,70 +1,88 @@ import abc import json import logging -from datetime import datetime -from threading import Thread, Event import time + +from collections.abc import Callable +from datetime import datetime, UTC +from threading import Event, Thread +from typing import Any, Dict, List, Optional, TYPE_CHECKING + import six -from .dcs import Member, Cluster -from .dcs.kubernetes import catch_kubernetes_errors, Kubernetes +import kubernetes + +from .dcs import AbstractDCS, Cluster, Member +from .dcs.kubernetes import catch_kubernetes_errors from .exceptions import DCSError -import kubernetes +if TYPE_CHECKING: # pragma: no cover + from .config import Config + from .dcs import Cluster logger = logging.getLogger(__name__) + @six.add_metaclass(abc.ABCMeta) class AbstractSiteController(object): # Set whether we are relying on this controller for providing standby config is_active = False + dcs: AbstractDCS + _has_leader: bool + def start(self): pass def shutdown(self): pass - def get_active_standby_config(self): + def get_active_standby_config(self) -> Dict[str, Any]: """Returns currently active configuration for standby leader""" + return {} - def is_leader_site(self): - return self.get_active_standby_config() is None + def is_leader_site(self) -> bool: + return self.get_active_standby_config() == {} - def resolve_leader(self): + def resolve_leader(self) -> Optional[str]: """Try to become leader, update active config correspondingly. - Return error when unable to resolve""" + Return error when unable to resolve leader status.""" return None def heartbeat(self): """"Notify multisite mechanism that this site has a properly operating cluster mechanism. - Need to send out an async lease update. If that fails to complete within safety margin of ttl running + Needs to send out an async lease update. 
If that fails to complete within safety margin of ttl running out then we need to update site config """ def release(self): pass - def status(self): - pass + def status(self) -> Dict[str, Any]: + return {} - def should_failover(self): + def should_failover(self) -> bool: return False - def on_shutdown(self, checkpoint_location): + def on_shutdown(self, checkpoint_location: int, prev_location: int): pass + def append_metrics(self, metrics: List[str], labels: str) -> None: + pass + + class SingleSiteController(AbstractSiteController): """Do nothing controller for single site operation.""" def status(self): return {"status": "Leader", "active": False} + class MultisiteController(Thread, AbstractSiteController): is_active = True - def __init__(self, config, on_change=None): + def __init__(self, config: 'Config', on_change: Callable[[], None]): super().__init__() self.stop_requested = False self.on_change = on_change @@ -96,7 +114,7 @@ def __init__(self, config, on_change=None): if msconfig.get('update_crd'): self._state_updater = KubernetesStateManagement(msconfig.get('update_crd'), msconfig.get('crd_uid'), - reporter=self.name, # Use pod name? + reporter=self.name, # Use pod name? crd_api=msconfig.get('crd_api', 'acid.zalan.do/v1')) else: self._state_updater = None @@ -104,13 +122,13 @@ def __init__(self, config, on_change=None): self.switchover_timeout = msconfig.get('switchover_timeout', 300) self._heartbeat = Event() - self._standby_config = None + self._standby_config = {} self._leader_resolved = Event() self._has_leader = False self._release = False self._status = None self._failover_target = None - self._failover_timeout = None + self._failover_timeout = 0 self.site_switches = None @@ -118,7 +136,7 @@ def __init__(self, config, on_change=None): def status(self): return { - "status": "Leader" if self._has_leader or self._standby_config is None else "Standby", + "status": "Leader" if self._has_leader or self._standby_config == {} else "Standby", "active": True, "name": self.name, "standby_config": self.get_active_standby_config(), @@ -130,22 +148,22 @@ def get_active_standby_config(self): def resolve_leader(self): """Try to become leader, update active config correspondingly. - Must be called from Patroni main thread. After a successful return get_active_standby_config() will - return a value corresponding to a multisite status that was active after start of the call. + Must be called from Patroni main thread. After a successful return :func:`get_active_standby_config()` will + return a value corresponding to the multisite status that was active after the start of the call. - Returns error message encountered when unable to resolve leader status.""" + Returns the error message encountered when unable to resolve leader status.""" self._leader_resolved.clear() self._heartbeat.set() self._leader_resolved.wait() return self._dcs_error def heartbeat(self): - """Notify multisite mechanism that this site has a properly operating cluster mechanism. + """Notify the multisite mechanism that this site has a leader with a properly operating HA cycle. - Need to send out an async lease update. If that fails to complete within safety margin of ttl running - out then we need to demote. + Needs to send out an async lease update. If that fails to complete within the safety margin of ``ttl``running + out, then we need to demote. 
""" - logger.info("Triggering multisite hearbeat") + logger.info("Triggering multisite heartbeat") self._heartbeat.set() def release(self): @@ -155,7 +173,7 @@ def release(self): def should_failover(self): return self._failover_target is not None and self._failover_target != self.name - def on_shutdown(self, checkpoint_location): + def on_shutdown(self, checkpoint_location: int, prev_location: int): """ Called when shutdown for multisite failover has completed. """ # TODO: check if we replicated everything to standby site @@ -165,7 +183,7 @@ def _disconnected_operation(self): self._standby_config = {'restore_command': 'false'} def _set_standby_config(self, other: Member): - logger.info(f"Multisite replicate from {other}") + logger.info(f"We will replicate from {other} in a multisite setup") # TODO: add support for replication slots try: old_conf, self._standby_config = self._standby_config, { @@ -181,19 +199,17 @@ def _set_standby_config(self, other: Member): logger.info(f"Setting standby configuration to: {self._standby_config}") return old_conf != self._standby_config - def _check_transition(self, leader, note=None): + def _check_transition(self, leader: bool, note: str): if self._has_leader != leader: - logger.info("State transition") + logger.info("Multisite state transition") self._has_leader = leader - if self.on_change: - self.on_change() + self.on_change() if self._state_updater and self._status != leader: self._state_updater.state_transition('Leader' if leader else 'Standby', note) self._status = leader - def _resolve_multisite_leader(self): - logger.info("Running multisite consensus.") + logger.info("Running multisite consensus") try: # Refresh the latest known state cluster = self.dcs.get_cluster() @@ -214,7 +230,7 @@ def _resolve_multisite_leader(self): # Became leader of unlocked cluster if self.dcs.attempt_to_acquire_leader(): logger.info("Became multisite leader") - self._standby_config = None + self._standby_config = {} self._check_transition(leader=True, note="Acquired multisite leader status") if cluster.failover and cluster.failover.target_site and cluster.failover.target_site == self.name: logger.info("Cleaning up multisite failover key after acquiring leader status") @@ -234,18 +250,18 @@ def _resolve_multisite_leader(self): lock_owner = cluster.leader and cluster.leader.name # The leader is us if lock_owner == self.name: - logger.info("Multisite has leader and it is us") + logger.info("Multisite has a leader and it is us") if self._release: logger.info("Releasing multisite leader status") self.dcs.delete_leader(cluster.leader) self._release = False self._disconnected_operation() - self._check_transition(leader=False, note="Released multisite leader status on request") + self._check_transition(leader=False, note="Released multisite leader status upon a request") return if self.dcs.update_leader(cluster, None): logger.info("Updated multisite leader lease") # Make sure we are disabled from standby mode - self._standby_config = None + self._standby_config = {} self._check_transition(leader=True, note="Already have multisite leader status") self._check_for_failover(cluster) else: @@ -254,23 +270,24 @@ def _resolve_multisite_leader(self): self._check_transition(leader=False, note="Failed to update multisite leader status") # Current leader is someone else else: - logger.info(f"Multisite has leader and it is {lock_owner}") + logger.info(f"Multisite has a leader and it is {lock_owner}") self._release = False # Failover successful or someone else took over if 
self._failover_target is not None: self._failover_target = None - self._failover_timeout = None - if self._set_standby_config(cluster.leader.member): + self._failover_timeout = 0 + if cluster.leader and self._set_standby_config(cluster.leader.member): # Wake up anyway to notice that we need to replicate from new leader. For the other case # _check_transition() handles the wake. if not self._has_leader: self.on_change() - note = f"Lost leader lock to {lock_owner}" if self._has_leader else f"Current leader {lock_owner}" + note = f"Lost leader lock to {lock_owner}" if self._has_leader \ + else f"Current leader is {lock_owner}" self._check_transition(leader=False, note=note) except DCSError as e: logger.error(f"Error accessing multisite DCS: {e}") - self._dcs_error = 'Multi site DCS cannot be reached' + self._dcs_error = 'Multisite DCS cannot be reached' if self._has_leader: self._disconnected_operation() self._has_leader = False @@ -281,12 +298,12 @@ def _resolve_multisite_leader(self): try: self._update_history(cluster) self.touch_member() - except DCSError as e: + except DCSError: pass def _observe_leader(self): """ - Observe multisite state and make sure + Observe multisite state and make sure it is reflected correctly by standby config. """ try: @@ -301,15 +318,15 @@ def _observe_leader(self): # The leader is us if lock_owner == self.name: logger.info("Multisite leader is us") - self._standby_config = None + self._standby_config = {} else: logger.info(f"Multisite leader is {lock_owner}") - self._set_standby_config(cluster.leader.member) + self._set_standby_config(cluster.leader.member) # pyright: ignore except DCSError as e: # On replicas we need to know the multisite status only for rewinding. logger.warning(f"Error accessing multisite DCS: {e}") - def _update_history(self, cluster): + def _update_history(self, cluster: 'Cluster'): if cluster.history and cluster.history.lines and isinstance(cluster.history.lines[0], dict): self.site_switches = cluster.history.lines[0].get('switches') @@ -337,7 +354,7 @@ def _check_for_failover(self, cluster: Cluster): self._failover_target = cluster.failover.target_site else: self._failover_target = None - self._failover_timeout = None + self._failover_timeout = 0 def touch_member(self): data = { @@ -350,7 +367,8 @@ def touch_member(self): def run(self): self._observe_leader() while not self._heartbeat.wait(self.config['observe_interval']): - # Keep track of who is the leader even when we are not the primary node to be able to rewind from them + # Keep track of who the leader is, even when we are not the primary node. + # Needed to be able to rewind from the leader. 
self._observe_leader() while not self.stop_requested: self._resolve_multisite_leader() @@ -366,16 +384,21 @@ def shutdown(self): self._heartbeat.set() self.join() + def append_metrics(self, metrics: List[str], labels: str): + metrics.append("# HELP patroni_multisite_switches Number of times multisite leader has been switched") + metrics.append("# TYPE patroni_multisite_switches counter") + metrics.append("patroni_multisite_switches{0} {1}".format(labels, self.site_switches)) + class KubernetesStateManagement: - def __init__(self, crd_name, crd_uid, reporter, crd_api): + def __init__(self, crd_name: str, crd_uid: str, reporter: str, crd_api: str): self.crd_namespace, self.crd_name = (['default'] + crd_name.rsplit('.', 1))[-2:] self.crd_uid = crd_uid self.reporter = reporter self.crd_api_group, self.crd_api_version = crd_api.rsplit('/', 1) # TODO: handle config loading when main DCS is not Kubernetes based - #apiclient = k8s_client.ApiClient(False) + # apiclient = k8s_client.ApiClient(False) kubernetes.config.load_incluster_config() apiclient = kubernetes.client.ApiClient() self._customobj_api = kubernetes.client.CustomObjectsApi(apiclient) @@ -384,13 +407,15 @@ def __init__(self, crd_name, crd_uid, reporter, crd_api): self._status_update = None self._event_obj = None - def state_transition(self, new_state, note): + def state_transition(self, new_state: str, note: str): self._status_update = {"status": {"Multisite": new_state}} - failover_time = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ") + failover_time = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S.%fZ") reason = 'Promote' if new_state == 'Leader' else 'Demote' - if note is None: - note = 'Acquired multisite leader' if new_state == 'Leader' else 'Became a standby cluster' + + # TODO: check if this is needed, no current call comes without note (this is already reflected in the signature) + # if note is None: + # note = 'Acquired multisite leader' if new_state == 'Leader' else 'Became a standby cluster' self._event_obj = kubernetes.client.EventsV1Event( action='Failover', @@ -422,9 +447,10 @@ def store_updates(self): @catch_kubernetes_errors def update_crd_state(self, update): - self._customobj_api.patch_namespaced_custom_object_status(self.crd_api_group, self.crd_api_version, self.crd_namespace, - 'postgresqls', self.crd_name + '/status', update, - field_manager='patroni') + self._customobj_api.patch_namespaced_custom_object_status(self.crd_api_group, self.crd_api_version, + self.crd_namespace, 'postgresqls', + self.crd_name + '/status', update, + field_manager='patroni') return True diff --git a/patroni/postgresql/citus.py b/patroni/postgresql/citus.py deleted file mode 100644 index 3a56d6bdb..000000000 --- a/patroni/postgresql/citus.py +++ /dev/null @@ -1,420 +0,0 @@ -import logging -import re -import time - -from threading import Condition, Event, Thread -from urllib.parse import urlparse -from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING - -from ..dcs import CITUS_COORDINATOR_GROUP_ID, Cluster -from ..psycopg import connect, quote_ident, ProgrammingError - -if TYPE_CHECKING: # pragma: no cover - from . 
import Postgresql - -CITUS_SLOT_NAME_RE = re.compile(r'^citus_shard_(move|split)_slot(_[1-9][0-9]*){2,3}$') -logger = logging.getLogger(__name__) - - -class PgDistNode(object): - """Represents a single row in the `pg_dist_node` table""" - - def __init__(self, group: int, host: str, port: int, event: str, nodeid: Optional[int] = None, - timeout: Optional[float] = None, cooldown: Optional[float] = None) -> None: - self.group = group - # A weird way of pausing client connections by adding the `-demoted` suffix to the hostname - self.host = host + ('-demoted' if event == 'before_demote' else '') - self.port = port - # Event that is trying to change or changed the given row. - # Possible values: before_demote, before_promote, after_promote. - self.event = event - self.nodeid = nodeid - - # If transaction was started, we need to COMMIT/ROLLBACK before the deadline - self.timeout = timeout - self.cooldown = cooldown or 10000 # 10s by default - self.deadline: float = 0 - - # All changes in the pg_dist_node are serialized on the Patroni - # side by performing them from a thread. The thread, that is - # requested a change, sometimes needs to wait for a result. - # For example, we want to pause client connections before demoting - # the worker, and once it is done notify the calling thread. - self._event = Event() - - def wait(self) -> None: - self._event.wait() - - def wakeup(self) -> None: - self._event.set() - - def __eq__(self, other: Any) -> bool: - return isinstance(other, PgDistNode) and self.event == other.event\ - and self.host == other.host and self.port == other.port - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __str__(self) -> str: - return ('PgDistNode(nodeid={0},group={1},host={2},port={3},event={4})' - .format(self.nodeid, self.group, self.host, self.port, self.event)) - - def __repr__(self) -> str: - return str(self) - - -class CitusHandler(Thread): - - def __init__(self, postgresql: 'Postgresql', config: Optional[Dict[str, Union[str, int]]]) -> None: - super(CitusHandler, self).__init__() - self.daemon = True - self._postgresql = postgresql - self._config = config - if config: - self._connection = postgresql.connection_pool.get( - 'citus', {'dbname': config['database'], - 'options': '-c statement_timeout=0 -c idle_in_transaction_session_timeout=0'}) - self._pg_dist_node: Dict[int, PgDistNode] = {} # Cache of pg_dist_node: {groupid: PgDistNode()} - self._tasks: List[PgDistNode] = [] # Requests to change pg_dist_node, every task is a `PgDistNode` - self._in_flight: Optional[PgDistNode] = None # Reference to the `PgDistNode` being changed in a transaction - self._schedule_load_pg_dist_node = True # Flag that "pg_dist_node" should be queried from the database - self._condition = Condition() # protects _pg_dist_node, _tasks, _in_flight, and _schedule_load_pg_dist_node - self.schedule_cache_rebuild() - - def is_enabled(self) -> bool: - return isinstance(self._config, dict) - - def group(self) -> Optional[int]: - return int(self._config['group']) if isinstance(self._config, dict) else None - - def is_coordinator(self) -> bool: - return self.is_enabled() and self.group() == CITUS_COORDINATOR_GROUP_ID - - def is_worker(self) -> bool: - return self.is_enabled() and not self.is_coordinator() - - def schedule_cache_rebuild(self) -> None: - with self._condition: - self._schedule_load_pg_dist_node = True - - def on_demote(self) -> None: - with self._condition: - self._pg_dist_node.clear() - empty_tasks: List[PgDistNode] = [] - self._tasks[:] = empty_tasks - 
self._in_flight = None - - def query(self, sql: str, *params: Any) -> List[Tuple[Any, ...]]: - try: - logger.debug('query(%s, %s)', sql, params) - return self._connection.query(sql, *params) - except Exception as e: - logger.error('Exception when executing query "%s", (%s): %r', sql, params, e) - self._connection.close() - with self._condition: - self._in_flight = None - self.schedule_cache_rebuild() - raise e - - def load_pg_dist_node(self) -> bool: - """Read from the `pg_dist_node` table and put it into the local cache""" - - with self._condition: - if not self._schedule_load_pg_dist_node: - return True - self._schedule_load_pg_dist_node = False - - try: - rows = self.query("SELECT nodeid, groupid, nodename, nodeport, noderole" - " FROM pg_catalog.pg_dist_node WHERE noderole = 'primary'") - except Exception: - return False - - with self._condition: - self._pg_dist_node = {r[1]: PgDistNode(r[1], r[2], r[3], 'after_promote', r[0]) for r in rows} - return True - - def sync_pg_dist_node(self, cluster: Cluster) -> None: - """Maintain the `pg_dist_node` from the coordinator leader every heartbeat loop. - - We can't always rely on REST API calls from worker nodes in order - to maintain `pg_dist_node`, therefore at least once per heartbeat - loop we make sure that workes registered in `self._pg_dist_node` - cache are matching the cluster view from DCS by creating tasks - the same way as it is done from the REST API.""" - - if not self.is_coordinator(): - return - - with self._condition: - if not self.is_alive(): - self.start() - - self.add_task('after_promote', CITUS_COORDINATOR_GROUP_ID, self._postgresql.connection_string) - - for group, worker in cluster.workers.items(): - leader = worker.leader - if leader and leader.conn_url\ - and leader.data.get('role') in ('master', 'primary') and leader.data.get('state') == 'running': - self.add_task('after_promote', group, leader.conn_url) - - def find_task_by_group(self, group: int) -> Optional[int]: - for i, task in enumerate(self._tasks): - if task.group == group: - return i - - def pick_task(self) -> Tuple[Optional[int], Optional[PgDistNode]]: - """Returns the tuple(i, task), where `i` - is the task index in the self._tasks list - - Tasks are picked by following priorities: - - 1. If there is already a transaction in progress, pick a task - that that will change already affected worker primary. - 2. If the coordinator address should be changed - pick a task - with group=0 (coordinators are always in group 0). - 3. Pick a task that is the oldest (first from the self._tasks) - """ - - with self._condition: - if self._in_flight: - i = self.find_task_by_group(self._in_flight.group) - else: - while True: - i = self.find_task_by_group(CITUS_COORDINATOR_GROUP_ID) # set_coordinator - if i is None and self._tasks: - i = 0 - if i is None: - break - task = self._tasks[i] - if task == self._pg_dist_node.get(task.group): - self._tasks.pop(i) # nothing to do because cached version of pg_dist_node already matches - else: - break - task = self._tasks[i] if i is not None else None - - # When tasks are added it could happen that self._pg_dist_node - # wasn't ready (self._schedule_load_pg_dist_node is False) - # and hence the nodeid wasn't filled. 
- if task and task.group in self._pg_dist_node: - task.nodeid = self._pg_dist_node[task.group].nodeid - return i, task - - def update_node(self, task: PgDistNode) -> None: - if task.nodeid is not None: - self.query('SELECT pg_catalog.citus_update_node(%s, %s, %s, true, %s)', - task.nodeid, task.host, task.port, task.cooldown) - elif task.event != 'before_demote': - task.nodeid = self.query("SELECT pg_catalog.citus_add_node(%s, %s, %s, 'primary', 'default')", - task.host, task.port, task.group)[0][0] - - def process_task(self, task: PgDistNode) -> bool: - """Updates a single row in `pg_dist_node` table, optionally in a transaction. - - The transaction is started if we do a demote of the worker node or before promoting the other worker if - there is no transaction in progress. And, the transaction is committed when the switchover/failover completed. - - .. note: - The maximum lifetime of the transaction in progress is controlled outside of this method. - - .. note: - Read access to `self._in_flight` isn't protected because we know it can't be changed outside of our thread. - - :param task: reference to a :class:`PgDistNode` object that represents a row to be updated/created. - :returns: `True` if the row was succesfully created/updated or transaction in progress - was committed as an indicator that the `self._pg_dist_node` cache should be updated, - or, if the new transaction was opened, this method returns `False`. - """ - - if task.event == 'after_promote': - # The after_promote may happen without previous before_demote and/or - # before_promore. In this case we just call self.update_node() method. - # If there is a transaction in progress, it could be that it already did - # required changes and we can simply COMMIT. - if not self._in_flight or self._in_flight.host != task.host or self._in_flight.port != task.port: - self.update_node(task) - if self._in_flight: - self.query('COMMIT') - return True - else: # before_demote, before_promote - if task.timeout: - task.deadline = time.time() + task.timeout - if not self._in_flight: - self.query('BEGIN') - self.update_node(task) - return False - - def process_tasks(self) -> None: - while True: - # Read access to `_in_flight` isn't protected because we know it can't be changed outside of our thread. - if not self._in_flight and not self.load_pg_dist_node(): - break - - i, task = self.pick_task() - if not task or i is None: - break - try: - update_cache = self.process_task(task) - except Exception as e: - logger.error('Exception when working with pg_dist_node: %r', e) - update_cache = None - with self._condition: - if self._tasks: - if update_cache: - self._pg_dist_node[task.group] = task - - if update_cache is False: # an indicator that process_tasks has started a transaction - self._in_flight = task - else: - self._in_flight = None - - if id(self._tasks[i]) == id(task): - self._tasks.pop(i) - task.wakeup() - - def run(self) -> None: - while True: - try: - with self._condition: - if self._schedule_load_pg_dist_node: - timeout = -1 - elif self._in_flight: - timeout = self._in_flight.deadline - time.time() if self._tasks else None - else: - timeout = -1 if self._tasks else None - - if timeout is None or timeout > 0: - self._condition.wait(timeout) - elif self._in_flight: - logger.warning('Rolling back transaction. 
Last known status: %s', self._in_flight) - self.query('ROLLBACK') - self._in_flight = None - self.process_tasks() - except Exception: - logger.exception('run') - - def _add_task(self, task: PgDistNode) -> bool: - with self._condition: - i = self.find_task_by_group(task.group) - - # The `PgDistNode.timeout` == None is an indicator that it was scheduled from the sync_pg_dist_node(). - if task.timeout is None: - # We don't want to override the already existing task created from REST API. - if i is not None and self._tasks[i].timeout is not None: - return False - - # There is a little race condition with tasks created from REST API - the call made "before" the member - # key is updated in DCS. Therefore it is possible that :func:`sync_pg_dist_node` will try to create a - # task based on the outdated values of "state"/"role". To solve it we introduce an artificial timeout. - # Only when the timeout is reached new tasks could be scheduled from sync_pg_dist_node() - if self._in_flight and self._in_flight.group == task.group and self._in_flight.timeout is not None\ - and self._in_flight.deadline > time.time(): - return False - - # Override already existing task for the same worker group - if i is not None: - if task != self._tasks[i]: - logger.debug('Overriding existing task: %s != %s', self._tasks[i], task) - self._tasks[i] = task - self._condition.notify() - return True - # Add the task to the list if Worker node state is different from the cached `pg_dist_node` - elif self._schedule_load_pg_dist_node or task != self._pg_dist_node.get(task.group)\ - or self._in_flight and task.group == self._in_flight.group: - logger.debug('Adding the new task: %s', task) - self._tasks.append(task) - self._condition.notify() - return True - return False - - def add_task(self, event: str, group: int, conn_url: str, - timeout: Optional[float] = None, cooldown: Optional[float] = None) -> Optional[PgDistNode]: - try: - r = urlparse(conn_url) - except Exception as e: - return logger.error('Failed to parse connection url %s: %r', conn_url, e) - host = r.hostname - if host: - port = r.port or 5432 - task = PgDistNode(group, host, port, event, timeout=timeout, cooldown=cooldown) - return task if self._add_task(task) else None - - def handle_event(self, cluster: Cluster, event: Dict[str, Any]) -> None: - if not self.is_alive(): - return - - worker = cluster.workers.get(event['group']) - if not (worker and worker.leader and worker.leader.name == event['leader'] and worker.leader.conn_url): - return - - task = self.add_task(event['type'], event['group'], - worker.leader.conn_url, - event['timeout'], event['cooldown'] * 1000) - if task and event['type'] == 'before_demote': - task.wait() - - def bootstrap(self) -> None: - if not isinstance(self._config, dict): # self.is_enabled() - return - - conn_kwargs = {**self._postgresql.connection_pool.conn_kwargs, - 'options': '-c synchronous_commit=local -c statement_timeout=0'} - if self._config['database'] != self._postgresql.database: - conn = connect(**conn_kwargs) - try: - with conn.cursor() as cur: - cur.execute('CREATE DATABASE {0}'.format( - quote_ident(self._config['database'], conn)).encode('utf-8')) - except ProgrammingError as exc: - if exc.diag.sqlstate == '42P04': # DuplicateDatabase - logger.debug('Exception when creating database: %r', exc) - else: - raise exc - finally: - conn.close() - - conn_kwargs['dbname'] = self._config['database'] - conn = connect(**conn_kwargs) - try: - with conn.cursor() as cur: - cur.execute('CREATE EXTENSION IF NOT EXISTS citus') - - 
superuser = self._postgresql.config.superuser - params = {k: superuser[k] for k in ('password', 'sslcert', 'sslkey') if k in superuser} - if params: - cur.execute("INSERT INTO pg_catalog.pg_dist_authinfo VALUES" - "(0, pg_catalog.current_user(), %s)", - (self._postgresql.config.format_dsn(params),)) - - if self.is_coordinator(): - r = urlparse(self._postgresql.connection_string) - cur.execute("SELECT pg_catalog.citus_set_coordinator_host(%s, %s, 'primary', 'default')", - (r.hostname, r.port or 5432)) - finally: - conn.close() - - def adjust_postgres_gucs(self, parameters: Dict[str, Any]) -> None: - if not self.is_enabled(): - return - - # citus extension must be on the first place in shared_preload_libraries - shared_preload_libraries = list(filter( - lambda el: el and el != 'citus', - [p.strip() for p in parameters.get('shared_preload_libraries', '').split(',')])) - parameters['shared_preload_libraries'] = ','.join(['citus'] + shared_preload_libraries) - - # if not explicitly set Citus overrides max_prepared_transactions to max_connections*2 - if parameters['max_prepared_transactions'] == 0: - parameters['max_prepared_transactions'] = parameters['max_connections'] * 2 - - # Resharding in Citus implemented using logical replication - parameters['wal_level'] = 'logical' - - # Sometimes Citus needs to connect to the local postgres. We will do it the same way as Patroni does. - parameters['citus.local_hostname'] = self._postgresql.connection_pool.conn_kwargs.get('host', 'localhost') - - def ignore_replication_slot(self, slot: Dict[str, str]) -> bool: - if isinstance(self._config, dict) and self._postgresql.is_primary() and\ - slot['type'] == 'logical' and slot['database'] == self._config['database']: - m = CITUS_SLOT_NAME_RE.match(slot['name']) - return bool(m and {'move': 'pgoutput', 'split': 'citus'}.get(m.group(1)) == slot['plugin']) - return False diff --git a/tests/test_api.py b/tests/test_api.py index 55e0016bc..5245a1b13 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -159,6 +159,7 @@ class MockPatroni(object): noloadbalance = PropertyMock(return_value=False) scheduled_restart = {'schedule': future_restart_time, 'postmaster_start_time': postgresql.postmaster_start_time()} + multisite = Mock() @staticmethod def sighup_handler(): diff --git a/tests/test_ha.py b/tests/test_ha.py index 9b6f90b62..18c17d143 100644 --- a/tests/test_ha.py +++ b/tests/test_ha.py @@ -13,6 +13,7 @@ from patroni.dcs.etcd import AbstractEtcdClientWithFailover from patroni.exceptions import DCSError, PatroniFatalException, PostgresConnectionException from patroni.ha import _MemberStatus, Ha +from patroni.multisite import SingleSiteController from patroni.postgresql import Postgresql from patroni.postgresql.bootstrap import Bootstrap from patroni.postgresql.callback_executor import CallbackAction @@ -156,6 +157,7 @@ def __init__(self, p, d): self.watchdog = Watchdog(self.config) self.request = lambda *args, **kwargs: requests_get(args[0].api_url, *args[1:], **kwargs) self.failover_priority = 1 + self.multisite = SingleSiteController() def run_async(self, func, args=()):
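
The api.py hunk above replaces the REST API's inline `is_active` branch with a call to `patroni.multisite.append_metrics(metrics, labels)`, so each site controller now decides which Prometheus lines it contributes (a no-op on the base/single-site controller, the `patroni_multisite_switches` counter on `MultisiteController`). A minimal, self-contained sketch of that dispatch follows; the stub classes and the `render_metrics()` helper are illustrative stand-ins for `patroni.multisite` and `do_GET_metrics()`, not the real Patroni objects, and the label set is made up.

from typing import List


class SingleSiteController:
    """Do-nothing controller: contributes no extra metrics (mirrors the base-class hook)."""

    def append_metrics(self, metrics: List[str], labels: str) -> None:
        pass


class MultisiteController(SingleSiteController):
    def __init__(self, site_switches: int = 0) -> None:
        # In the real controller this value is tracked from the multisite history key.
        self.site_switches = site_switches

    def append_metrics(self, metrics: List[str], labels: str) -> None:
        metrics.append("# HELP patroni_multisite_switches Number of times multisite leader has been switched")
        metrics.append("# TYPE patroni_multisite_switches counter")
        metrics.append("patroni_multisite_switches{0} {1}".format(labels, self.site_switches))


def render_metrics(controller: SingleSiteController, labels: str) -> str:
    metrics: List[str] = []
    metrics.append("# TYPE patroni_is_paused gauge")
    metrics.append("patroni_is_paused{0} {1}".format(labels, 0))
    # The endpoint no longer checks multisite.is_active; the controller decides what to emit.
    controller.append_metrics(metrics, labels)
    return '\n'.join(metrics) + '\n'


if __name__ == '__main__':
    labels = '{scope="demo",name="pg-0"}'  # illustrative label set
    print(render_metrics(SingleSiteController(), labels), end='')
    print(render_metrics(MultisiteController(site_switches=3), labels), end='')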
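
The patched `MultisiteController.resolve_leader()` coordinates with the controller's background thread through two `threading.Event` objects: the caller (Patroni's main HA thread) clears `_leader_resolved`, sets `_heartbeat` to wake the thread, and then blocks until the consensus round completes. The sketch below shows that handshake pattern in isolation; the worker-side `_leader_resolved.set()` call is an assumption about what the remainder of `run()` (not visible in these hunks) does after `_resolve_multisite_leader()`, and the class name and timings are illustrative only.

import time

from threading import Event, Thread
from typing import Optional


class SiteConsensusThread(Thread):
    """Toy stand-in for MultisiteController: one consensus round per heartbeat signal."""

    def __init__(self) -> None:
        super().__init__(daemon=True)
        self.stop_requested = False
        self._heartbeat = Event()
        self._leader_resolved = Event()
        self._dcs_error: Optional[str] = None

    def resolve_leader(self) -> Optional[str]:
        # Mirrors the patched resolve_leader(): wake the worker, wait for the round to finish.
        self._leader_resolved.clear()
        self._heartbeat.set()
        self._leader_resolved.wait()
        return self._dcs_error

    def run(self) -> None:
        while not self.stop_requested:
            self._heartbeat.wait()       # in Patroni this wait doubles as the observe_interval timer
            self._heartbeat.clear()
            time.sleep(0.1)              # pretend to talk to the multisite DCS here
            self._dcs_error = None       # would carry the error message on DCSError
            self._leader_resolved.set()  # assumed: run() releases the waiting caller after each round


if __name__ == '__main__':
    controller = SiteConsensusThread()
    controller.start()
    print('resolve_leader() returned:', controller.resolve_leader())
    controller.stop_requested = True
    controller._heartbeat.set()  # let run() observe stop_requested and exit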