From b9ee0f5896bc3692a0b139461936582516d62298 Mon Sep 17 00:00:00 2001 From: Judah Rand <17158624+judahrand@users.noreply.github.com> Date: Sat, 2 Apr 2022 14:55:08 +0100 Subject: [PATCH] Make sure that Postgres always get a consistent snapshot --- pipelinewise/fastsync/commons/tap_postgres.py | 272 +++++++++++------- .../commons/test_fastsync_tap_postgres.py | 32 ++- 2 files changed, 197 insertions(+), 107 deletions(-) diff --git a/pipelinewise/fastsync/commons/tap_postgres.py b/pipelinewise/fastsync/commons/tap_postgres.py index 66b094e72..22cdc247a 100644 --- a/pipelinewise/fastsync/commons/tap_postgres.py +++ b/pipelinewise/fastsync/commons/tap_postgres.py @@ -1,12 +1,15 @@ +from dataclasses import dataclass import datetime import decimal import logging import re import sys +from time import time import psycopg2 +import psycopg2.errors import psycopg2.extras -from typing import Dict +from typing import Dict, Optional from . import utils, split_gzip @@ -15,6 +18,21 @@ LOGGER = logging.getLogger(__name__) +@dataclass +class SlotCreationResult: + """Class to store replication slot creation result""" + slot_name: str + start_point: int + snapshot_name: str + plugin_name: str + + +def parse_lsn(lsn: str) -> int: + """Parse string LSN format to integer""" + file, index = lsn.split('/') + return (int(file, 16) << 32) + int(index, 16) + + class FastSyncTapPostgres: """ Common functions for fastsync from a Postgres database @@ -25,9 +43,8 @@ def __init__(self, connection_config, tap_type_to_target_type, target_quote=None self.tap_type_to_target_type = tap_type_to_target_type self.target_quote = target_quote self.conn = None - self.curr = None self.primary_host_conn = None - self.primary_host_curr = None + self.version = None @staticmethod def generate_replication_slot_name(dbname, tap_id=None, prefix='pipelinewise'): @@ -78,11 +95,12 @@ def __get_slot_name( try: # Backward compatibility: try to locate existing v15 slot first. PPW <= 0.15.0 - with connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - cur.execute( - f"SELECT * FROM pg_replication_slots WHERE slot_name = '{slot_name_v15}';" - ) - v15_slots_count = cur.rowcount + with connection: + with connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute( + f"SELECT * FROM pg_replication_slots WHERE slot_name = '{slot_name_v15}';" + ) + v15_slots_count = cur.rowcount except psycopg2.Error: LOGGER.exception('Error while looking for slots', exc_info=sys.exc_info()) @@ -104,9 +122,11 @@ def drop_slot(cls, connection_config: Dict) -> None: """ LOGGER.info('Attempting to drop slot ...') - LOGGER.debug('Creating a connection to Primary server ..') - connection = cls.get_connection(connection_config, prioritize_primary=True) - LOGGER.debug('Connection to Primary server created.') + LOGGER.debug('Creating a replication connection to Primary server ..') + connection = cls.get_connection( + connection_config, prioritize_primary=True, replication=True + ) + LOGGER.debug('Replication connection to Primary server created.') try: slot_name = cls.__get_slot_name( @@ -115,18 +135,23 @@ def drop_slot(cls, connection_config: Dict) -> None: LOGGER.info('Dropping the slot "%s"', slot_name) # drop the replication host - with connection.cursor() as cur: - cur.execute( - f'SELECT pg_drop_replication_slot(slot_name) ' - f"FROM pg_replication_slots WHERE slot_name = '{slot_name}';" - ) - LOGGER.info('Number of dropped slots: %s', cur.rowcount) - + with connection as conn: + with conn.cursor() as cur: + try: + cur.drop_replication_slot(slot_name) + # No idea why Pylint thinks this isn't a member... + except psycopg2.errors.UndefinedObject: # pylint: disable=no-member + pass finally: connection.close() @classmethod - def get_connection(cls, connection_config: Dict, prioritize_primary: bool = False): + def get_connection( + cls, + connection_config: Dict, + prioritize_primary: bool = False, + replication: bool = False, + ): """ Class method to create a pg connection instance with autocommit enabled Connection is either to the primary or a replica if its credentials are given @@ -168,10 +193,15 @@ def get_connection(cls, connection_config: Dict, prioritize_primary: bool = Fals if 'ssl' in connection_config and connection_config['ssl'] == 'true': conn_string += " sslmode='require'" - conn = psycopg2.connect(conn_string) + kwargs = {} + if replication: + kwargs = {'connection_factory': psycopg2.extras.LogicalReplicationConnection} - # Set connection to autocommit - conn.autocommit = True + conn = psycopg2.connect(conn_string, **kwargs) + + if not replication: + # Cannot set autocommit on repliication connections + conn.autocommit = True LOGGER.info('Connection to PGSQL server established') @@ -182,9 +212,8 @@ def open_connection(self): Open connection """ self.conn = self.get_connection( - self.connection_config, prioritize_primary=False + self.connection_config, prioritize_primary=False, replication=False ) - self.curr = self.conn.cursor() def close_connection(self): """ @@ -197,8 +226,9 @@ def query(self, query, params=None): Run query """ LOGGER.info('Running query: %s', query) - with self.conn as connection: - with connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + + with self.conn: + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(query, params) if cur.rowcount > 0: @@ -211,8 +241,8 @@ def primary_host_query(self, query, params=None): Run query on the primary host """ LOGGER.info('Running query: %s', query) - with self.primary_host_conn as connection: - with connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + with self.primary_host_conn as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(query, params) if cur.rowcount > 0: @@ -220,8 +250,7 @@ def primary_host_query(self, query, params=None): return [] - # pylint: disable=no-member - def create_replication_slot(self): + def create_replication_slot(self) -> SlotCreationResult: """ Create replication slot on the primary host @@ -237,23 +266,75 @@ def create_replication_slot(self): database by multiple taps. If that the case then you need to drop the old replication slot and full-resync the new taps. """ - try: - slot_name = self.__get_slot_name( - self.primary_host_conn, - self.connection_config['dbname'], - self.connection_config['tap_id'], - ) + slot_name = self.__get_slot_name( + self.primary_host_conn, + self.connection_config['dbname'], + self.connection_config['tap_id'], + ) - # Create the replication host - self.primary_host_query( - f"SELECT * FROM pg_create_logical_replication_slot('{slot_name}', 'wal2json')" - ) - except Exception as exc: - # ERROR: replication slot already exists SQL state: 42710 - if hasattr(exc, 'pgcode') and exc.pgcode == '42710': - pass + # Create the replication host + with self.primary_host_conn as conn: + with conn.cursor() as cur: + cur.create_replication_slot(slot_name, output_plugin='wal2json') + res = cur.fetchone() + + return SlotCreationResult( + slot_name=res[0], + start_point=parse_lsn(res[1]), + snapshot_name=res[2], + plugin_name=res[3], + ) + + def get_current_lsn(self) -> int: + """Get the LSN of the most recent WAL availiable to the snapshotting process""" + # is replica_host set ? + if self.connection_config.get('replica_host'): + # Get latest applied lsn from replica_host + if self.version >= 100000: + result = self.query('SELECT pg_last_wal_replay_lsn() AS current_lsn') + elif self.version >= 90400: + result = self.query('SELECT pg_last_xlog_replay_location() AS current_lsn') else: - raise exc + raise Exception( + 'Logical replication not supported before PostgreSQL 9.4' + ) + else: + # Get current lsn from primary host + if self.version >= 100000: + result = self.query('SELECT pg_current_wal_lsn() AS current_lsn') + elif self.version >= 90400: + result = self.query('SELECT pg_current_xlog_location() AS current_lsn') + else: + raise Exception('Logical replication not supported before PostgreSQL 9.4') + + return parse_lsn(result[0].get('current_lsn')) + + def get_confirmed_flush_lsn(self) -> int: + """ + Get this last flushed LSN for the replication slot. + + For Postgres <9.6 with defaults to the restart_lsn and confirmed_flus_lsn + is not availiable. + """ + slot_name = self.__get_slot_name( + self.primary_host_conn, + self.connection_config['dbname'], + self.connection_config['tap_id'], + ) + + res = self.primary_host_query( + 'SELECT * ' + 'FROM pg_replication_slots ' + f"WHERE slot_name = '{slot_name}' " + "AND plugin = 'wal2json' " + "AND slot_type = 'logical'" + )[0] + + # confirmed_flush_lsn was introduced in Postgres 9.6 so fallback to + # restart_lsn if needed. This is an attempted to fallback to a point + # at which there are no pending transactions (as this will still + # result in transactions since confirmed_flush_lsn being delivered). + return parse_lsn(res.get('confirmed_flush_lsn', res['restart_lsn'])) # pylint: disable=too-many-branches,no-member,chained-comparison def fetch_current_log_pos(self): @@ -263,65 +344,51 @@ def fetch_current_log_pos(self): # Create replication slot dedicated connection # Always use Primary server for creating replication_slot self.primary_host_conn = self.get_connection( - self.connection_config, prioritize_primary=True + self.connection_config, prioritize_primary=True, replication=True ) - self.primary_host_curr = self.primary_host_conn.cursor() # Make sure PostgreSQL version is 9.4 or higher result = self.primary_host_query( "SELECT setting::int AS version FROM pg_settings WHERE name='server_version_num'" ) - version = result[0].get('version') + self.version = result[0].get('version') # Do not allow minor versions with PostgreSQL BUG #15114 - if (version >= 110000) and (version < 110002): + if (self.version >= 110000) and (self.version < 110002): raise Exception('PostgreSQL upgrade required to minor version 11.2') - if (version >= 100000) and (version < 100007): + if (self.version >= 100000) and (self.version < 100007): raise Exception('PostgreSQL upgrade required to minor version 10.7') - if (version >= 90600) and (version < 90612): + if (self.version >= 90600) and (self.version < 90612): raise Exception('PostgreSQL upgrade required to minor version 9.6.12') - if (version >= 90500) and (version < 90516): + if (self.version >= 90500) and (self.version < 90516): raise Exception('PostgreSQL upgrade required to minor version 9.5.16') - if (version >= 90400) and (version < 90421): + if (self.version >= 90400) and (self.version < 90421): raise Exception('PostgreSQL upgrade required to minor version 9.4.21') - if version < 90400: + if self.version < 90400: raise Exception('Logical replication not supported before PostgreSQL 9.4') - # Create replication slot - self.create_replication_slot() + try: + # Try to create replication slot. + slot_creation_info = self.create_replication_slot() + start_lsn = slot_creation_info.start_point + except psycopg2.errors.DuplicateObject: + # If we've already created the replication slot we start streaming + # from the last flushed LSN. This point should result in a + # consistent view of the database as PipelineWise will only send + # feedback for LSNs where transactions have been committed. + start_lsn = self.get_confirmed_flush_lsn() # Close replication slot dedicated connection self.primary_host_conn.close() - # is replica_host set ? + # If we are performing the initial snapshot from a replica we must + # let the replica catch up to at least the point where the logical + # replication will pick up from when the snapshot is complete. if self.connection_config.get('replica_host'): - # Get latest applied lsn from replica_host - if version >= 100000: - result = self.query('SELECT pg_last_wal_replay_lsn() AS current_lsn') - elif version >= 90400: - result = self.query( - 'SELECT pg_last_xlog_replay_location() AS current_lsn' - ) - else: - raise Exception( - 'Logical replication not supported before PostgreSQL 9.4' - ) - else: - # Get current lsn from primary host - if version >= 100000: - result = self.query('SELECT pg_current_wal_lsn() AS current_lsn') - elif version >= 90400: - result = self.query('SELECT pg_current_xlog_location() AS current_lsn') - else: - raise Exception( - 'Logical replication not supported before PostgreSQL 9.4' - ) - - current_lsn = result[0].get('current_lsn') - file, index = current_lsn.split('/') - lsn = (int(file, 16) << 32) + int(index, 16) + while start_lsn > self.get_current_lsn(): + time.sleep(1) - return {'lsn': lsn, 'version': 1} + return {'lsn': start_lsn, 'version': 1} # pylint: disable=invalid-name def fetch_current_incremental_key_pos(self, table, replication_key): @@ -461,14 +528,14 @@ def map_column_types_to_target(self, table_name): # pylint: disable=too-many-arguments def copy_table( self, - table_name, - path, - max_num=None, - date_type='date', - split_large_files=False, - split_file_chunk_size_mb=1000, - split_file_max_chunks=20, - compress=True, + table_name: str, + path: str, + max_num: Optional[int] = None, + date_type: str = 'date', + split_large_files: bool = False, + split_file_chunk_size_mb: int = 1000, + split_file_max_chunks: int = 20, + compress: bool = True, ): """ Export data from table to a zipped csv @@ -480,8 +547,9 @@ def copy_table( split_file_chunk_size_mb: File chunk sizes if `split_large_files` enabled. (Default: 1000) split_file_max_chunks: Max number of chunks if `split_large_files` enabled. (Default: 20) """ - table_columns = self.get_table_columns(table_name, max_num, date_type) - column_safe_sql_values = [c.get('safe_sql_value') for c in table_columns] + column_safe_sql_values = [ + c.get('safe_sql_value') for c in self.get_table_columns(table_name, max_num, date_type) + ] # If self.get_table_columns returns zero row then table not exist if len(column_safe_sql_values) == 0: @@ -489,14 +557,14 @@ def copy_table( schema_name, table_name = table_name.split('.') - sql = """COPY (SELECT {} - ,now() AT TIME ZONE 'UTC' - ,now() AT TIME ZONE 'UTC' - ,null - FROM {}."{}") TO STDOUT with CSV DELIMITER ',' - """.format( - ','.join(column_safe_sql_values), schema_name, table_name + sql = ( + f"COPY (SELECT {','.join(column_safe_sql_values)} " + ",now() AT TIME ZONE 'UTC' " + ",now() AT TIME ZONE 'UTC' " + ',null ' + f"FROM {schema_name}.\"{table_name}\") TO STDOUT with CSV DELIMITER ','" ) + LOGGER.info('Exporting data: %s', sql) gzip_splitter = split_gzip.open( @@ -508,4 +576,6 @@ def copy_table( ) with gzip_splitter as split_gzip_files: - self.curr.copy_expert(sql, split_gzip_files, size=131072) + with self.conn: + with self.conn.cursor() as cur: + cur.copy_expert(sql, split_gzip_files, size=131072) diff --git a/tests/units/fastsync/commons/test_fastsync_tap_postgres.py b/tests/units/fastsync/commons/test_fastsync_tap_postgres.py index 269e4e6a0..b28726aef 100644 --- a/tests/units/fastsync/commons/test_fastsync_tap_postgres.py +++ b/tests/units/fastsync/commons/test_fastsync_tap_postgres.py @@ -74,10 +74,15 @@ def execute_mock(query): # mock cursor with execute method cursor_mock = MagicMock().return_value cursor_mock.__enter__.return_value.execute.side_effect = execute_mock + cursor_mock.__enter__.return_value.fetchone.return_value = ( + 'pipelinewise_test_database_test_tap', '242FC/BA84A740', '00000009-0192A151-1', 'wal2json' + ) type(cursor_mock.__enter__.return_value).rowcount = PropertyMock(return_value=0) # mock PG connection instance with ability to open cursor pg_con = Mock() + pg_con.__enter__ = lambda *_: pg_con + pg_con.__exit__ = lambda *_: None pg_con.cursor.return_value = cursor_mock self.postgres.primary_host_conn = pg_con @@ -85,8 +90,10 @@ def execute_mock(query): self.postgres.create_replication_slot() assert self.postgres.executed_queries_primary_host == [ "SELECT * FROM pg_replication_slots WHERE slot_name = 'pipelinewise_test_database';", - "SELECT * FROM pg_create_logical_replication_slot('pipelinewise_test_database_test_tap', 'wal2json')", ] + cursor_mock.__enter__.return_value.create_replication_slot.assert_called_once_with( + 'pipelinewise_test_database_test_tap', output_plugin='wal2json' + ) def test_create_replication_slot_2(self): """ @@ -100,10 +107,15 @@ def execute_mock(query): # mock cursor with execute method cursor_mock = MagicMock().return_value cursor_mock.__enter__.return_value.execute.side_effect = execute_mock + cursor_mock.__enter__.return_value.fetchone.return_value = ( + 'pipelinewise_test_database_test_tap', '242FC/BA84A740', '00000009-0192A151-1', 'wal2json' + ) type(cursor_mock.__enter__.return_value).rowcount = PropertyMock(return_value=1) # mock PG connection instance with ability to open cursor pg_con = Mock() + pg_con.__enter__ = lambda *_: pg_con + pg_con.__exit__ = lambda *_: None pg_con.cursor.return_value = cursor_mock self.postgres.primary_host_conn = pg_con @@ -111,8 +123,10 @@ def execute_mock(query): self.postgres.create_replication_slot() assert self.postgres.executed_queries_primary_host == [ "SELECT * FROM pg_replication_slots WHERE slot_name = 'pipelinewise_test_database';", - "SELECT * FROM pg_create_logical_replication_slot('pipelinewise_test_database', 'wal2json')", ] + cursor_mock.__enter__.return_value.create_replication_slot.assert_called_once_with( + 'pipelinewise_test_database', output_plugin='wal2json' + ) @patch('pipelinewise.fastsync.commons.tap_postgres.psycopg2.connect') def test_get_connection_to_primary(self, connect_mock): @@ -250,6 +264,8 @@ def execute_mock(query): # mock PG connection instance with ability to open cursor pg_con = Mock() + pg_con.__enter__ = lambda *_: pg_con + pg_con.__exit__ = lambda *_: None pg_con.cursor.return_value = cursor_mock connect_mock.return_value = pg_con @@ -258,9 +274,10 @@ def execute_mock(query): assert self.postgres.executed_queries_primary_host == [ "SELECT * FROM pg_replication_slots WHERE slot_name = 'pipelinewise_my_db';", - 'SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE ' - "slot_name = 'pipelinewise_my_db';", ] + cursor_mock.__enter__.return_value.drop_replication_slot.assert_called_once_with( + 'pipelinewise_my_db' + ) @patch('pipelinewise.fastsync.commons.tap_postgres.psycopg2.connect') def test_drop_slot_v16(self, connect_mock): @@ -291,6 +308,8 @@ def execute_mock(query): # mock PG connection instance with ability to open cursor pg_con = Mock() + pg_con.__enter__ = lambda *_: pg_con + pg_con.__exit__ = lambda *_: None pg_con.cursor.return_value = cursor_mock connect_mock.return_value = pg_con @@ -299,6 +318,7 @@ def execute_mock(query): assert self.postgres.executed_queries_primary_host == [ "SELECT * FROM pg_replication_slots WHERE slot_name = 'pipelinewise_my_db';", - 'SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE ' - "slot_name = 'pipelinewise_my_db_tap_test';", ] + cursor_mock.__enter__.return_value.drop_replication_slot.assert_called_once_with( + 'pipelinewise_my_db_tap_test' + )