-
Notifications
You must be signed in to change notification settings - Fork 0
/
mysqldb.py
149 lines (106 loc) · 3.67 KB
/
mysqldb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""MySQL database connection class.
This class connects to a MySQL database, optionally through an SSH tunnel
(a tunnel is opened only when ``db_info`` contains an ``ssh_host`` entry).
It can run queries and return the results as a pandas DataFrame.
Example:
db = MySQLDB(db_info)
db.connect()
sql = "SELECT * FROM table"
df = db.run_query(sql)
db.disconnect()
print(df)
# Output:
#
# id name
# 0 1 John
# 1 2 Jane
Attributes:
db_info (dict): Dictionary containing database connection information.
tunnel (SSHTunnelForwarder): SSH tunnel object.
connection (pymysql.Connection): MySQL database connection object.
verbose (bool): Whether to print verbose output.
cursor (pymysql.Cursor): MySQL cursor object.
Todo:
* Add support for SSH key authentication
"""
import logging
import sshtunnel
import pandas as pd
from pymysql import connect
from sshtunnel import SSHTunnelForwarder
class MySQLDB:
    """MySQL database connection, optionally routed through an SSH tunnel.

    ``db_info`` is a dict of connection settings. The keys read by this
    class are ``db_host``, ``db_username``, ``db_password`` and ``db_name``
    for MySQL, plus ``ssh_host``, ``ssh_username`` and ``ssh_password`` when
    a tunnel is wanted. A tunnel is opened only if ``ssh_host`` is present
    and truthy.

    Instances can be used as context managers::

        with MySQLDB(db_info) as db:
            df = db.run_query("SELECT 1")
    """

    def __init__(self, db_info, verbose=False):
        """Store the settings; nothing is opened until :meth:`connect`.

        :param db_info: Dictionary containing database connection information
        :param verbose: When true, sshtunnel logging is switched to DEBUG
        """
        self.db_info = db_info
        self.tunnel = None
        self.connection = None
        self.verbose = verbose
        self.cursor = None

    def __enter__(self):
        """Connect and return this instance for use in a ``with`` block."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Disconnect on leaving the ``with`` block; never hides exceptions."""
        self.disconnect()
        return False

    def open_ssh_tunnel(self):
        """Open an SSH tunnel and connect using a username and password.

        Does nothing when ``db_info`` has no (truthy) ``ssh_host``. The SSH
        port (22) and the remote bind address (127.0.0.1:3306) are fixed.
        """
        db_info = self.db_info
        ssh_host = db_info.get('ssh_host')
        if not ssh_host:
            return
        if self.verbose:
            # Module-wide sshtunnel setting, not limited to this instance.
            sshtunnel.DEFAULT_LOGLEVEL = logging.DEBUG
        self.tunnel = SSHTunnelForwarder(
            (ssh_host, 22),
            ssh_username=db_info['ssh_username'],
            ssh_password=db_info['ssh_password'],
            remote_bind_address=('127.0.0.1', 3306),
        )
        self.tunnel.start()

    def mysql_connect(self):
        """Connect to the MySQL server, through the SSH tunnel if one is open.

        When a tunnel exists its local bind port is used as the port;
        otherwise ``None`` is passed and the driver picks its default.
        """
        db_info = self.db_info
        local_port = (
            self.tunnel.local_bind_port if self.tunnel is not None else None
        )
        # NOTE(review): with a tunnel this still connects to db_info['db_host']
        # on the tunnel's local port, which presumably only works when db_host
        # refers to the local machine - confirm against real configurations.
        # `connect` here is the module-level pymysql function, not the
        # MySQLDB.connect method (class scope is not searched from a method).
        self.connection = connect(
            host=db_info['db_host'],
            user=db_info['db_username'],
            passwd=db_info['db_password'],
            db=db_info['db_name'],
            port=local_port,
        )
        # A cursor cached from a previous connection must not be reused.
        self.cursor = None

    def get_cursor(self):
        """Get a cursor object to execute SQL queries on.

        The cursor is created on first use and cached on the instance.

        :return: Cursor object
        """
        if self.cursor is None:
            self.cursor = self.connection.cursor()
        return self.cursor

    def run_query(self, sql: str, params=None):
        """Run a SQL query on this instance's connection via pandas.

        :param sql: MySQL query
        :param params: Optional query parameters, forwarded to
            ``pandas.read_sql_query``
        :return: Pandas dataframe containing results
        """
        return pd.read_sql_query(sql, self.connection, params=params)

    def run_query_with_cursor(self, sql: str, params=None):
        """Run a SQL query on this instance's cached cursor.

        :param sql: MySQL query
        :param params: Optional query parameters, forwarded to
            ``cursor.execute``
        :return: Pandas dataframe containing results; an empty dataframe
            when the statement produced no result set
        """
        cursor = self.get_cursor()
        cursor.execute(sql, params)
        description = cursor.description
        if description is None:
            # No result set was produced, so there are no column names
            # to build a frame from.
            return pd.DataFrame()
        cols = [desc[0] for desc in description]
        return pd.DataFrame(list(cursor.fetchall()), columns=cols)

    def mysql_disconnect(self):
        """Close the cached cursor and the MySQL connection, if any.

        The instance state is reset first, so repeated calls are harmless
        and a later reconnect starts with a fresh cursor.
        """
        cursor, self.cursor = self.cursor, None
        connection, self.connection = self.connection, None
        try:
            if cursor is not None:
                cursor.close()
        finally:
            # Close the connection even if closing the cursor failed.
            if connection is not None:
                connection.close()

    def close_ssh_tunnel(self):
        """Close the SSH tunnel, if one is open, and forget it."""
        tunnel, self.tunnel = self.tunnel, None
        if tunnel is not None:
            tunnel.close()

    def connect(self):
        """Connect to the database, opening an SSH tunnel first if needed.

        If the MySQL connection cannot be established, a tunnel opened for
        it is closed again before the exception is re-raised.
        """
        self.open_ssh_tunnel()
        try:
            self.mysql_connect()
        except Exception:
            self.close_ssh_tunnel()
            raise

    def disconnect(self):
        """Close the database connection and then the SSH tunnel, if any.

        The tunnel is closed even when closing the connection raises.
        """
        try:
            self.mysql_disconnect()
        finally:
            self.close_ssh_tunnel()