Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] OWBaseSql: Base widget for connecting to a database #4083

Merged
merged 1 commit into from
Oct 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions Orange/widgets/utils/owbasesql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from typing import Type
from collections import OrderedDict

from AnyQt.QtWidgets import QLineEdit, QSizePolicy

from Orange.data import Table
from Orange.data.sql.backend import Backend
from Orange.data.sql.backend.base import BackendError
from Orange.widgets import gui, report
from Orange.widgets.credentials import CredentialManager
from Orange.widgets.settings import Setting
from Orange.widgets.utils.signals import Output
from Orange.widgets.widget import OWWidget, Msg


class OWBaseSql(OWWidget, openclass=True):
"""Base widget for connecting to a database.
Override `get_backend` when subclassing to get corresponding backend.
"""
class Outputs:
data = Output("Data", Table)

class Error(OWWidget.Error):
connection = Msg("{}")

want_main_area = False
resizing_enabled = False

host = Setting(None) # type: Optional[str]
port = Setting(None) # type: Optional[str]
database = Setting(None) # type: Optional[str]
schema = Setting(None) # type: Optional[str]
username = ""
password = ""

def __init__(self):
super().__init__()
self.backend = None # type: Optional[Backend]
self.data_desc_table = None # type: Optional[Table]
self.database_desc = None # type: Optional[OrderedDict]
self._setup_gui()
self.connect()

def _setup_gui(self):
self.controlArea.setMinimumWidth(360)

vbox = gui.vBox(self.controlArea, "Server", addSpace=True)
self.serverbox = gui.vBox(vbox)
self.servertext = QLineEdit(self.serverbox)
self.servertext.setPlaceholderText("Server")
self.servertext.setToolTip("Server")
self.servertext.editingFinished.connect(self._load_credentials)
if self.host:
self.servertext.setText(self.host if not self.port else
"{}:{}".format(self.host, self.port))
self.serverbox.layout().addWidget(self.servertext)

self.databasetext = QLineEdit(self.serverbox)
self.databasetext.setPlaceholderText("Database[/Schema]")
self.databasetext.setToolTip("Database or optionally Database/Schema")
if self.database:
self.databasetext.setText(
self.database if not self.schema else
"{}/{}".format(self.database, self.schema))
self.serverbox.layout().addWidget(self.databasetext)
self.usernametext = QLineEdit(self.serverbox)
self.usernametext.setPlaceholderText("Username")
self.usernametext.setToolTip("Username")

self.serverbox.layout().addWidget(self.usernametext)
self.passwordtext = QLineEdit(self.serverbox)
self.passwordtext.setPlaceholderText("Password")
self.passwordtext.setToolTip("Password")
self.passwordtext.setEchoMode(QLineEdit.Password)

self.serverbox.layout().addWidget(self.passwordtext)

self._load_credentials()

self.connectbutton = gui.button(self.serverbox, self, "Connect",
callback=self.connect)
self.connectbutton.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)

def _load_credentials(self):
self._parse_host_port()
cm = self._credential_manager(self.host, self.port)
self.username = cm.username
self.password = cm.password

if self.username:
self.usernametext.setText(self.username)
if self.password:
self.passwordtext.setText(self.password)

def _save_credentials(self):
cm = self._credential_manager(self.host, self.port)
cm.username = self.username or ""
cm.password = self.password or ""

@staticmethod
def _credential_manager(host, port):
return CredentialManager("SQL Table: {}:{}".format(host, port))

def _parse_host_port(self):
hostport = self.servertext.text().split(":")
self.host = hostport[0]
self.port = hostport[1] if len(hostport) == 2 else None

def _check_db_settings(self):
self._parse_host_port()
self.database, _, self.schema = self.databasetext.text().partition("/")
self.username = self.usernametext.text() or None
self.password = self.passwordtext.text() or None

def connect(self):
self.clear()
self._check_db_settings()
if not self.host or not self.database:
return
try:
backend = self.get_backend()
if backend is None:
return
self.backend = backend(dict(
host=self.host,
port=self.port,
database=self.database,
user=self.username,
password=self.password
))
self.on_connection_success()
except BackendError as err:
self.on_connection_error(err)

def get_backend(self) -> Type[Backend]:
"""
Derived widgets should override this to get corresponding backend.

Returns
-------
backend: Type[Backend]
"""
raise NotImplementedError

def on_connection_success(self):
self._save_credentials()
self.database_desc = OrderedDict((
("Host", self.host), ("Port", self.port),
("Database", self.database), ("User name", self.username)
))

def on_connection_error(self, err):
error = str(err).split("\n")[0]
self.Error.connection(error)

def open_table(self):
data = self.get_table()
self.data_desc_table = data
self.Outputs.data.send(data)
info = str(len(data)) if data else self.info.NoOutput
self.info.set_output_summary(info)

def get_table(self) -> Table:
"""
Derived widgets should override this to get corresponding table.

Returns
-------
table: Table
"""
raise NotImplementedError

def clear(self):
self.Error.connection.clear()
self.database_desc = None
self.data_desc_table = None
self.Outputs.data.send(None)
self.info.set_output_summary(self.info.NoOutput)

def send_report(self):
if not self.database_desc:
self.report_paragraph("No database connection.")
return
self.report_items("Database", self.database_desc)
if self.data_desc_table:
self.report_items("Data",
report.describe_data(self.data_desc_table))
96 changes: 96 additions & 0 deletions Orange/widgets/utils/tests/test_owbasesql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# pylint: disable=missing-docstring
import unittest
from unittest.mock import Mock
from collections import OrderedDict
from types import SimpleNamespace

from Orange.data import Table
from Orange.data.sql.backend import Backend
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.utils.owbasesql import OWBaseSql
from Orange.data.sql.backend.base import BackendError


USERNAME = "UN"
PASSWORD = "PASS"


class BrokenBackend(Backend): # pylint: disable=abstract-method
def __init__(self, connection_params):
super().__init__(connection_params)
raise BackendError("Error connecting to DB.")


class TestableSqlWidget(OWBaseSql):
name = "SQL"

def __init__(self):
self.mocked_backend = Mock()
super().__init__()

def get_backend(self):
return self.mocked_backend

def get_table(self) -> Table:
return Table("iris")

@staticmethod
def _credential_manager(_, __): # pylint: disable=arguments-differ
return SimpleNamespace(username=USERNAME, password=PASSWORD)


class TestOWBaseSql(WidgetTest):
def setUp(self):
self.host, self.port, self.db = "host", "port", "DB"
settings = {"host": self.host, "port": self.port,
"database": self.db, "schema": ""}
self.widget = self.create_widget(TestableSqlWidget,
stored_settings=settings)

def test_connect(self):
self.widget.mocked_backend.assert_called_once_with(
{"host": "host", "port": "port", "database": self.db,
"user": USERNAME, "password": PASSWORD})
self.assertDictEqual(
self.widget.database_desc,
OrderedDict((("Host", "host"), ("Port", "port"),
("Database", self.db), ("User name", USERNAME))))

def test_connection_error(self):
self.widget.get_backend = Mock(return_value=BrokenBackend)
self.widget.connectbutton.click()
self.assertTrue(self.widget.Error.connection.is_shown())
self.assertIsNone(self.widget.database_desc)

def test_output(self):
self.widget.open_table()
self.assertIsNotNone(self.get_output(self.widget.Outputs.data))
self.assertIsNotNone(self.widget.data_desc_table)

def test_output_error(self):
self.widget.get_table = lambda: None
self.widget.open_table()
self.assertIsNone(self.get_output(self.widget.Outputs.data))
self.assertIsNone(self.widget.data_desc_table)

def test_missing_database_parameter(self):
self.widget.open_table()
self.widget.databasetext.setText("")
self.widget.mocked_backend.reset_mock()
self.widget.connectbutton.click()
self.widget.mocked_backend.assert_not_called()
self.assertIsNone(self.get_output(self.widget.Outputs.data))
self.assertIsNone(self.widget.data_desc_table)
self.assertFalse(self.widget.Error.connection.is_shown())

def test_report(self):
self.widget.report_button.click() # DB connection
self.widget.open_table()
self.widget.report_button.click() # table
self.widget.databasetext.setText("")
self.widget.connectbutton.click()
self.widget.report_button.click() # empty


if __name__ == "__main__":
unittest.main()