diff --git a/src/admin.rs b/src/admin.rs index f08ef2e1..c438f66a 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -568,13 +568,18 @@ where T: tokio::io::AsyncWrite + std::marker::Unpin, { info!("Reloading config"); + let mut res = BytesMut::new(); - reload_config(client_server_map).await?; + // TODO If error print what the error was + match reload_config(client_server_map).await { + Ok(_) => (), + Err(_) => { + res.put(notify("ERROR", "Error found in config, please see pgcat log for details".to_string())); + }, + }; get_config().show(); - let mut res = BytesMut::new(); - res.put(command_complete("RELOAD")); // ReadyForQuery diff --git a/src/config.rs b/src/config.rs index c7aaf4c3..15a06f44 100644 --- a/src/config.rs +++ b/src/config.rs @@ -205,6 +205,7 @@ impl Address { /// PostgreSQL user. #[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] pub struct User { pub username: String, pub password: Option, @@ -256,6 +257,7 @@ impl User { /// General configuration. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(deny_unknown_fields)] pub struct General { #[serde(default = "General::default_host")] pub host: String, @@ -506,6 +508,7 @@ impl std::fmt::Display for LoadBalancingMode { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +#[serde(deny_unknown_fields)] pub struct Pool { #[serde(default = "Pool::default_pool_mode")] pub pool_mode: PoolMode, @@ -795,6 +798,7 @@ pub struct MirrorServerConfig { /// Shard configuration. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash, Eq)] +#[serde(deny_unknown_fields)] pub struct Shard { pub database: String, pub mirrors: Option>, @@ -854,6 +858,7 @@ impl Default for Shard { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] +#[serde(deny_unknown_fields)] pub struct Plugins { pub intercept: Option, pub table_access: Option, @@ -886,6 +891,7 @@ impl std::fmt::Display for Plugins { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] +#[serde(deny_unknown_fields)] pub struct Intercept { pub enabled: bool, pub queries: BTreeMap, @@ -898,6 +904,7 @@ impl Plugin for Intercept { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] +#[serde(deny_unknown_fields)] pub struct TableAccess { pub enabled: bool, pub tables: Vec, @@ -910,6 +917,7 @@ impl Plugin for TableAccess { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] +#[serde(deny_unknown_fields)] pub struct QueryLogger { pub enabled: bool, } @@ -921,6 +929,7 @@ impl Plugin for QueryLogger { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] +#[serde(deny_unknown_fields)] pub struct Prewarmer { pub enabled: bool, pub queries: Vec, @@ -942,6 +951,7 @@ impl Intercept { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] +#[serde(deny_unknown_fields)] pub struct Query { pub query: String, pub schema: Vec>, @@ -961,6 +971,7 @@ impl Query { /// Configuration wrapper. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(deny_unknown_fields)] pub struct Config { // Serializer maintains the order of fields in the struct // so we should always put simple fields before nested fields diff --git a/tests/python/conftest.py b/tests/python/conftest.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt index ec8b08f8..f90abf3f 100644 --- a/tests/python/requirements.txt +++ b/tests/python/requirements.txt @@ -1,3 +1,4 @@ pytest +jinja2 psycopg2==2.9.3 psutil==5.9.1 diff --git a/tests/python/test_config.py b/tests/python/test_config.py new file mode 100644 index 00000000..8c736802 --- /dev/null +++ b/tests/python/test_config.py @@ -0,0 +1,120 @@ + +import pytest +import subprocess +import tempfile +import time +import utils +from jinja2 import Environment, BaseLoader + +template_config = """ +{% if invalid_config_type == 'global' %} +[non_existant_section] +non_existant_parameter = "arbitrary_value" +{% endif %} + +{% if invalid_config_type == 'plugins' %} +[plugins.nonexistent] +enabled = true +{% endif %} + +[plugins.prewarmer] +enabled = false +queries = [ + "SELECT pg_prewarm('pgbench_accounts')", +] +{% if invalid_config_type == 'plugins_prewarmer' %} +non_existant_parameter = "arbitrary_value" +{% endif %} + +[plugins.query_logger] +enabled = false +{% if invalid_config_type == 'plugins_query_logger' %} +non_existant_parameter = "arbitrary_value" +{% endif %} + +[plugins.table_access] +enabled = false +tables = [ + "pg_user", + "pg_roles", + "pg_database", +] +{% if invalid_config_type == 'plugins_table_access' %} +non_existant_parameter = "arbitrary_value" +{% endif %} + +[plugins.intercept] +enabled = true +{% if invalid_config_type == 'plugins_intercept' %} +non_existant_parameter = "arbitrary_value" +{% endif %} + +[plugins.intercept.queries.0] + +query = "select current_database() as a, current_schemas(false) as b" +schema = [ + ["a", "text"], + ["b", "text"], +] +result = [ + ["${DATABASE}", "{public}"], +] +{% if invalid_config_type == 'plugins_intercept_queries' %} +non_existant_parameter = "arbitrary_value" +{% endif %} + +[general] +host = "0.0.0.0" +port = 6433 +admin_username = "pgcat" +admin_password = "pgcat" +{% if invalid_config_type == 'general' %} +non_existant_parameter = "arbitrary_value" +{% endif %} + +[pools.pgml] +{% if invalid_config_type == 'pools' %} +non_existant_parameter = "arbitrary_value" +{% endif %} + +[pools.pgml.users.0] +username = "simple_user" +password = "simple_user" +pool_size = 10 +min_pool_size = 1 +pool_mode = "transaction" +{% if invalid_config_type == 'user' %} +non_existant_parameter = "arbitrary_value" +{% endif %} + +[pools.pgml.shards.0] +servers = [ + ["127.0.0.1", 5432, "primary"] +] +database = "some_db" +{% if invalid_config_type == 'pool' %} +non_existant_parameter = "arbitrary_value" +{% endif %} +""" + +parameters = [ + 'global', 'general', 'user', 'pool', 'plugins', 'plugins_prewarmer', 'plugins_query_logger', 'plugins_table_access', 'plugins_intercept', 'plugins_intercept_queries', 'pools'] +@pytest.mark.parametrize("invalid_config_type", parameters) +def test_negative(invalid_config_type: str): + rtemplate = Environment(loader=BaseLoader).from_string(template_config) + data = rtemplate.render(invalid_config_type=invalid_config_type) + + print(data) + tmp = tempfile.NamedTemporaryFile() + with open(tmp.name, 'w') as f: + f.write(data) + + process = subprocess.Popen(["./target/debug/pgcat", tmp.name], shell=False) + time.sleep(3) + poll = process.poll() + try: + assert poll is not None + except AssertionError as e: + process.kill() + process.wait() + raise