From 26ba3d6289f2d1bc63639d28112dda09b25e2fd9 Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Wed, 25 Sep 2019 12:22:24 +0200 Subject: [PATCH] OWBaseSql: Base widget for connecting to DB --- Orange/widgets/utils/owbasesql.py | 187 +++++++++++++++++++ Orange/widgets/utils/tests/test_owbasesql.py | 96 ++++++++++ 2 files changed, 283 insertions(+) create mode 100644 Orange/widgets/utils/owbasesql.py create mode 100644 Orange/widgets/utils/tests/test_owbasesql.py diff --git a/Orange/widgets/utils/owbasesql.py b/Orange/widgets/utils/owbasesql.py new file mode 100644 index 00000000000..e0c3300e766 --- /dev/null +++ b/Orange/widgets/utils/owbasesql.py @@ -0,0 +1,187 @@ +from typing import Type +from collections import OrderedDict + +from AnyQt.QtWidgets import QLineEdit, QSizePolicy + +from Orange.data import Table +from Orange.data.sql.backend import Backend +from Orange.data.sql.backend.base import BackendError +from Orange.widgets import gui, report +from Orange.widgets.credentials import CredentialManager +from Orange.widgets.settings import Setting +from Orange.widgets.utils.signals import Output +from Orange.widgets.widget import OWWidget, Msg + + +class OWBaseSql(OWWidget, openclass=True): + """Base widget for connecting to a database. + Override `get_backend` when subclassing to get corresponding backend. + """ + class Outputs: + data = Output("Data", Table) + + class Error(OWWidget.Error): + connection = Msg("{}") + + want_main_area = False + resizing_enabled = False + + host = Setting(None) # type: Optional[str] + port = Setting(None) # type: Optional[str] + database = Setting(None) # type: Optional[str] + schema = Setting(None) # type: Optional[str] + username = "" + password = "" + + def __init__(self): + super().__init__() + self.backend = None # type: Optional[Backend] + self.data_desc_table = None # type: Optional[Table] + self.database_desc = None # type: Optional[OrderedDict] + self._setup_gui() + self.connect() + + def _setup_gui(self): + self.controlArea.setMinimumWidth(360) + + vbox = gui.vBox(self.controlArea, "Server", addSpace=True) + self.serverbox = gui.vBox(vbox) + self.servertext = QLineEdit(self.serverbox) + self.servertext.setPlaceholderText("Server") + self.servertext.setToolTip("Server") + self.servertext.editingFinished.connect(self._load_credentials) + if self.host: + self.servertext.setText(self.host if not self.port else + "{}:{}".format(self.host, self.port)) + self.serverbox.layout().addWidget(self.servertext) + + self.databasetext = QLineEdit(self.serverbox) + self.databasetext.setPlaceholderText("Database[/Schema]") + self.databasetext.setToolTip("Database or optionally Database/Schema") + if self.database: + self.databasetext.setText( + self.database if not self.schema else + "{}/{}".format(self.database, self.schema)) + self.serverbox.layout().addWidget(self.databasetext) + self.usernametext = QLineEdit(self.serverbox) + self.usernametext.setPlaceholderText("Username") + self.usernametext.setToolTip("Username") + + self.serverbox.layout().addWidget(self.usernametext) + self.passwordtext = QLineEdit(self.serverbox) + self.passwordtext.setPlaceholderText("Password") + self.passwordtext.setToolTip("Password") + self.passwordtext.setEchoMode(QLineEdit.Password) + + self.serverbox.layout().addWidget(self.passwordtext) + + self._load_credentials() + + self.connectbutton = gui.button(self.serverbox, self, "Connect", + callback=self.connect) + self.connectbutton.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + + def _load_credentials(self): + self._parse_host_port() + cm = self._credential_manager(self.host, self.port) + self.username = cm.username + self.password = cm.password + + if self.username: + self.usernametext.setText(self.username) + if self.password: + self.passwordtext.setText(self.password) + + def _save_credentials(self): + cm = self._credential_manager(self.host, self.port) + cm.username = self.username or "" + cm.password = self.password or "" + + @staticmethod + def _credential_manager(host, port): + return CredentialManager("SQL Table: {}:{}".format(host, port)) + + def _parse_host_port(self): + hostport = self.servertext.text().split(":") + self.host = hostport[0] + self.port = hostport[1] if len(hostport) == 2 else None + + def _check_db_settings(self): + self._parse_host_port() + self.database, _, self.schema = self.databasetext.text().partition("/") + self.username = self.usernametext.text() or None + self.password = self.passwordtext.text() or None + + def connect(self): + self.clear() + self._check_db_settings() + if not self.host or not self.database: + return + try: + backend = self.get_backend() + if backend is None: + return + self.backend = backend(dict( + host=self.host, + port=self.port, + database=self.database, + user=self.username, + password=self.password + )) + self.on_connection_success() + except BackendError as err: + self.on_connection_error(err) + + def get_backend(self) -> Type[Backend]: + """ + Derived widgets should override this to get corresponding backend. + + Returns + ------- + backend: Type[Backend] + """ + raise NotImplementedError + + def on_connection_success(self): + self._save_credentials() + self.database_desc = OrderedDict(( + ("Host", self.host), ("Port", self.port), + ("Database", self.database), ("User name", self.username) + )) + + def on_connection_error(self, err): + error = str(err).split("\n")[0] + self.Error.connection(error) + + def open_table(self): + data = self.get_table() + self.data_desc_table = data + self.Outputs.data.send(data) + info = str(len(data)) if data else self.info.NoOutput + self.info.set_output_summary(info) + + def get_table(self) -> Table: + """ + Derived widgets should override this to get corresponding table. + + Returns + ------- + table: Table + """ + raise NotImplementedError + + def clear(self): + self.Error.connection.clear() + self.database_desc = None + self.data_desc_table = None + self.Outputs.data.send(None) + self.info.set_output_summary(self.info.NoOutput) + + def send_report(self): + if not self.database_desc: + self.report_paragraph("No database connection.") + return + self.report_items("Database", self.database_desc) + if self.data_desc_table: + self.report_items("Data", + report.describe_data(self.data_desc_table)) diff --git a/Orange/widgets/utils/tests/test_owbasesql.py b/Orange/widgets/utils/tests/test_owbasesql.py new file mode 100644 index 00000000000..99bd0371d6d --- /dev/null +++ b/Orange/widgets/utils/tests/test_owbasesql.py @@ -0,0 +1,96 @@ +# pylint: disable=missing-docstring +import unittest +from unittest.mock import Mock +from collections import OrderedDict +from types import SimpleNamespace + +from Orange.data import Table +from Orange.data.sql.backend import Backend +from Orange.widgets.tests.base import WidgetTest +from Orange.widgets.utils.owbasesql import OWBaseSql +from Orange.data.sql.backend.base import BackendError + + +USERNAME = "UN" +PASSWORD = "PASS" + + +class BrokenBackend(Backend): # pylint: disable=abstract-method + def __init__(self, connection_params): + super().__init__(connection_params) + raise BackendError("Error connecting to DB.") + + +class TestableSqlWidget(OWBaseSql): + name = "SQL" + + def __init__(self): + self.mocked_backend = Mock() + super().__init__() + + def get_backend(self): + return self.mocked_backend + + def get_table(self) -> Table: + return Table("iris") + + @staticmethod + def _credential_manager(_, __): # pylint: disable=arguments-differ + return SimpleNamespace(username=USERNAME, password=PASSWORD) + + +class TestOWBaseSql(WidgetTest): + def setUp(self): + self.host, self.port, self.db = "host", "port", "DB" + settings = {"host": self.host, "port": self.port, + "database": self.db, "schema": ""} + self.widget = self.create_widget(TestableSqlWidget, + stored_settings=settings) + + def test_connect(self): + self.widget.mocked_backend.assert_called_once_with( + {"host": "host", "port": "port", "database": self.db, + "user": USERNAME, "password": PASSWORD}) + self.assertDictEqual( + self.widget.database_desc, + OrderedDict((("Host", "host"), ("Port", "port"), + ("Database", self.db), ("User name", USERNAME)))) + + def test_connection_error(self): + self.widget.get_backend = Mock(return_value=BrokenBackend) + self.widget.connectbutton.click() + self.assertTrue(self.widget.Error.connection.is_shown()) + self.assertIsNone(self.widget.database_desc) + + def test_output(self): + self.widget.open_table() + self.assertIsNotNone(self.get_output(self.widget.Outputs.data)) + self.assertIsNotNone(self.widget.data_desc_table) + + def test_output_error(self): + self.widget.get_table = lambda: None + self.widget.open_table() + self.assertIsNone(self.get_output(self.widget.Outputs.data)) + self.assertIsNone(self.widget.data_desc_table) + + def test_missing_database_parameter(self): + self.widget.open_table() + self.widget.databasetext.setText("") + self.widget.mocked_backend.reset_mock() + self.widget.connectbutton.click() + self.widget.mocked_backend.assert_not_called() + self.assertIsNone(self.get_output(self.widget.Outputs.data)) + self.assertIsNone(self.widget.data_desc_table) + self.assertFalse(self.widget.Error.connection.is_shown()) + + def test_report(self): + self.widget.report_button.click() # DB connection + self.widget.open_table() + self.widget.report_button.click() # table + self.widget.databasetext.setText("") + self.widget.connectbutton.click() + self.widget.report_button.click() # empty + + +if __name__ == "__main__": + unittest.main()