-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
OWBaseSql: Base widget for connecting to DB
- Loading branch information
Showing
2 changed files
with
276 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
from typing import Type | ||
from collections import OrderedDict | ||
|
||
from AnyQt.QtWidgets import QLineEdit, QSizePolicy | ||
|
||
from Orange.data import Table | ||
from Orange.data.sql.backend import Backend | ||
from Orange.data.sql.backend.base import BackendError | ||
from Orange.widgets import gui, report | ||
from Orange.widgets.credentials import CredentialManager | ||
from Orange.widgets.settings import Setting | ||
from Orange.widgets.utils.signals import Output | ||
from Orange.widgets.widget import OWWidget, Msg | ||
|
||
|
||
class OWBaseSql(OWWidget, openclass=True): | ||
"""Base widget for connecting to a database. | ||
Override `get_backend` when subclassing to get corresponding backend. | ||
""" | ||
class Outputs: | ||
data = Output("Data", Table) | ||
|
||
class Error(OWWidget.Error): | ||
connection = Msg("{}") | ||
|
||
want_main_area = False | ||
resizing_enabled = False | ||
|
||
host = Setting(None) # type: Optional[str] | ||
port = Setting(None) # type: Optional[str] | ||
database = Setting(None) # type: Optional[str] | ||
schema = Setting(None) # type: Optional[str] | ||
username = "" | ||
password = "" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.backend = None # type: Optional[Backend] | ||
self.data_desc_table = None # type: Optional[Table] | ||
self.database_desc = None # type: Optional[OrderedDict] | ||
self._setup_gui() | ||
self.connect() | ||
|
||
def _setup_gui(self): | ||
self.controlArea.setMinimumWidth(360) | ||
|
||
vbox = gui.vBox(self.controlArea, "Server", addSpace=True) | ||
self.serverbox = gui.vBox(vbox) | ||
self.servertext = QLineEdit(self.serverbox) | ||
self.servertext.setPlaceholderText("Server") | ||
self.servertext.setToolTip("Server") | ||
self.servertext.editingFinished.connect(self._load_credentials) | ||
if self.host: | ||
self.servertext.setText(self.host if not self.port else | ||
"{}:{}".format(self.host, self.port)) | ||
self.serverbox.layout().addWidget(self.servertext) | ||
|
||
self.databasetext = QLineEdit(self.serverbox) | ||
self.databasetext.setPlaceholderText("Database[/Schema]") | ||
self.databasetext.setToolTip("Database or optionally Database/Schema") | ||
if self.database: | ||
self.databasetext.setText( | ||
self.database if not self.schema else | ||
"{}/{}".format(self.database, self.schema)) | ||
self.serverbox.layout().addWidget(self.databasetext) | ||
self.usernametext = QLineEdit(self.serverbox) | ||
self.usernametext.setPlaceholderText("Username") | ||
self.usernametext.setToolTip("Username") | ||
|
||
self.serverbox.layout().addWidget(self.usernametext) | ||
self.passwordtext = QLineEdit(self.serverbox) | ||
self.passwordtext.setPlaceholderText("Password") | ||
self.passwordtext.setToolTip("Password") | ||
self.passwordtext.setEchoMode(QLineEdit.Password) | ||
|
||
self.serverbox.layout().addWidget(self.passwordtext) | ||
|
||
self._load_credentials() | ||
|
||
self.connectbutton = gui.button(self.serverbox, self, "Connect", | ||
callback=self.connect) | ||
self.connectbutton.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) | ||
|
||
def _load_credentials(self): | ||
self._parse_host_port() | ||
cm = self._credential_manager(self.host, self.port) | ||
self.username = cm.username | ||
self.password = cm.password | ||
|
||
if self.username: | ||
self.usernametext.setText(self.username) | ||
if self.password: | ||
self.passwordtext.setText(self.password) | ||
|
||
def _save_credentials(self): | ||
cm = self._credential_manager(self.host, self.port) | ||
cm.username = self.username or "" | ||
cm.password = self.password or "" | ||
|
||
@staticmethod | ||
def _credential_manager(host, port): | ||
return CredentialManager("SQL Table: {}:{}".format(host, port)) | ||
|
||
def _parse_host_port(self): | ||
hostport = self.servertext.text().split(":") | ||
self.host = hostport[0] | ||
self.port = hostport[1] if len(hostport) == 2 else None | ||
|
||
def _check_db_settings(self): | ||
self._parse_host_port() | ||
self.database, _, self.schema = self.databasetext.text().partition("/") | ||
self.username = self.usernametext.text() or None | ||
self.password = self.passwordtext.text() or None | ||
|
||
def connect(self): | ||
self.clear() | ||
self._check_db_settings() | ||
if not self.host or not self.database: | ||
return | ||
try: | ||
backend = self.get_backend() | ||
if backend is None: | ||
return | ||
self.backend = backend(dict( | ||
host=self.host, | ||
port=self.port, | ||
database=self.database, | ||
user=self.username, | ||
password=self.password | ||
)) | ||
self.on_connection_success() | ||
except BackendError as err: | ||
self.on_connection_error(err) | ||
|
||
def get_backend(self) -> Type[Backend]: | ||
""" | ||
Derived widgets should override this to get corresponding backend. | ||
Returns | ||
------- | ||
backend: Type[Backend] | ||
""" | ||
raise NotImplementedError | ||
|
||
def on_connection_success(self): | ||
self._save_credentials() | ||
self.database_desc = OrderedDict(( | ||
("Host", self.host), ("Port", self.port), | ||
("Database", self.database), ("User name", self.username) | ||
)) | ||
|
||
def on_connection_error(self, err): | ||
error = str(err).split("\n")[0] | ||
self.Error.connection(error) | ||
|
||
def open_table(self): | ||
data = self.get_table() | ||
self.data_desc_table = data | ||
self.Outputs.data.send(data) | ||
self.info.set_output_summary(str(len(data))) | ||
|
||
def get_table(self) -> Table: | ||
""" | ||
Derived widgets should override this to get corresponding table. | ||
Returns | ||
------- | ||
table: Table | ||
""" | ||
raise NotImplementedError | ||
|
||
def clear(self): | ||
self.Error.connection.clear() | ||
self.database_desc = None | ||
self.data_desc_table = None | ||
self.Outputs.data.send(None) | ||
self.info.set_output_summary(self.info.NoOutput) | ||
|
||
def send_report(self): | ||
if not self.database_desc: | ||
self.report_paragraph("No database connection.") | ||
return | ||
self.report_items("Database", self.database_desc) | ||
if self.data_desc_table: | ||
self.report_items("Data", | ||
report.describe_data(self.data_desc_table)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# pylint: disable=missing-docstring | ||
import unittest | ||
from unittest.mock import Mock | ||
from collections import OrderedDict | ||
from types import SimpleNamespace | ||
|
||
from Orange.data import Table | ||
from Orange.data.sql.backend import Backend | ||
from Orange.widgets.tests.base import WidgetTest | ||
from Orange.widgets.utils.owbasesql import OWBaseSql | ||
from Orange.data.sql.backend.base import BackendError | ||
|
||
|
||
USERNAME = "UN" | ||
PASSWORD = "PASS" | ||
|
||
|
||
class BrokenBackend(Backend): # pylint: disable=abstract-method | ||
def __init__(self, connection_params): | ||
super().__init__(connection_params) | ||
raise BackendError("Error connecting to DB.") | ||
|
||
|
||
class TestableSqlWidget(OWBaseSql): | ||
name = "SQL" | ||
|
||
def __init__(self): | ||
self.mocked_backend = Mock() | ||
super().__init__() | ||
|
||
def get_backend(self): | ||
return self.mocked_backend | ||
|
||
def get_table(self) -> Table: | ||
return Table("iris") | ||
|
||
@staticmethod | ||
def _credential_manager(_, __): | ||
return SimpleNamespace(username=USERNAME, password=PASSWORD) | ||
|
||
|
||
class TestOWBaseSql(WidgetTest): | ||
def setUp(self): | ||
self.host, self.port, self.db = "host", "port", "DB" | ||
settings = {"host": self.host, "port": self.port, | ||
"database": self.db, "schema": ""} | ||
self.widget = self.create_widget(TestableSqlWidget, | ||
stored_settings=settings) | ||
|
||
def test_connect(self): | ||
self.widget.mocked_backend.assert_called_once_with( | ||
{"host": "host", "port": "port", "database": self.db, | ||
"user": USERNAME, "password": PASSWORD}) | ||
self.assertDictEqual( | ||
self.widget.database_desc, | ||
OrderedDict((("Host", "host"), ("Port", "port"), | ||
("Database", self.db), ("User name", USERNAME)))) | ||
|
||
def test_connection_error(self): | ||
self.widget.get_backend = Mock(return_value=BrokenBackend) | ||
self.widget.connectbutton.click() | ||
self.assertTrue(self.widget.Error.connection.is_shown()) | ||
self.assertIsNone(self.widget.database_desc) | ||
|
||
def test_output(self): | ||
self.widget.open_table() | ||
self.assertIsNotNone(self.get_output(self.widget.Outputs.data)) | ||
self.assertIsNotNone(self.widget.data_desc_table) | ||
|
||
def test_missing_database_parameter(self): | ||
self.widget.open_table() | ||
self.widget.databasetext.setText("") | ||
self.widget.mocked_backend.reset_mock() | ||
self.widget.connectbutton.click() | ||
self.widget.mocked_backend.assert_not_called() | ||
self.assertIsNone(self.get_output(self.widget.Outputs.data)) | ||
self.assertIsNone(self.widget.data_desc_table) | ||
self.assertFalse(self.widget.Error.connection.is_shown()) | ||
|
||
def test_report(self): | ||
self.widget.report_button.click() # DB connection | ||
self.widget.open_table() | ||
self.widget.report_button.click() # table | ||
self.widget.databasetext.setText("") | ||
self.widget.connectbutton.click() | ||
self.widget.report_button.click() # empty | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |