Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Strict Config Parsing #807

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,13 +568,18 @@ where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
info!("Reloading config");
let mut res = BytesMut::new();

reload_config(client_server_map).await?;
// TODO If error print what the error was
match reload_config(client_server_map).await {
Ok(_) => (),
Err(_) => {
res.put(notify("ERROR", "Error found in config, please see pgcat log for details".to_string()));
},
};

get_config().show();

let mut res = BytesMut::new();

res.put(command_complete("RELOAD"));

// ReadyForQuery
Expand Down
11 changes: 11 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ impl Address {

/// PostgreSQL user.
#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct User {
pub username: String,
pub password: Option<String>,
Expand Down Expand Up @@ -256,6 +257,7 @@ impl User {

/// General configuration.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct General {
#[serde(default = "General::default_host")]
pub host: String,
Expand Down Expand Up @@ -506,6 +508,7 @@ impl std::fmt::Display for LoadBalancingMode {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
#[serde(deny_unknown_fields)]
pub struct Pool {
#[serde(default = "Pool::default_pool_mode")]
pub pool_mode: PoolMode,
Expand Down Expand Up @@ -795,6 +798,7 @@ pub struct MirrorServerConfig {

/// Shard configuration.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct Shard {
pub database: String,
pub mirrors: Option<Vec<MirrorServerConfig>>,
Expand Down Expand Up @@ -854,6 +858,7 @@ impl Default for Shard {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct Plugins {
pub intercept: Option<Intercept>,
pub table_access: Option<TableAccess>,
Expand Down Expand Up @@ -886,6 +891,7 @@ impl std::fmt::Display for Plugins {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct Intercept {
pub enabled: bool,
pub queries: BTreeMap<String, Query>,
Expand All @@ -898,6 +904,7 @@ impl Plugin for Intercept {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
Expand All @@ -910,6 +917,7 @@ impl Plugin for TableAccess {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct QueryLogger {
pub enabled: bool,
}
Expand All @@ -921,6 +929,7 @@ impl Plugin for QueryLogger {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct Prewarmer {
pub enabled: bool,
pub queries: Vec<String>,
Expand All @@ -942,6 +951,7 @@ impl Intercept {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[serde(deny_unknown_fields)]
pub struct Query {
pub query: String,
pub schema: Vec<Vec<String>>,
Expand All @@ -961,6 +971,7 @@ impl Query {

/// Configuration wrapper.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct Config {
// Serializer maintains the order of fields in the struct
// so we should always put simple fields before nested fields
Expand Down
Empty file removed tests/python/conftest.py
Empty file.
1 change: 1 addition & 0 deletions tests/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest
jinja2
psycopg2==2.9.3
psutil==5.9.1
120 changes: 120 additions & 0 deletions tests/python/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@

import pytest
import subprocess
import tempfile
import time
import utils
from jinja2 import Environment, BaseLoader

template_config = """
{% if invalid_config_type == 'global' %}
[non_existant_section]
non_existant_parameter = "arbitrary_value"
{% endif %}

{% if invalid_config_type == 'plugins' %}
[plugins.nonexistent]
enabled = true
{% endif %}

[plugins.prewarmer]
enabled = false
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
{% if invalid_config_type == 'plugins_prewarmer' %}
non_existant_parameter = "arbitrary_value"
{% endif %}

[plugins.query_logger]
enabled = false
{% if invalid_config_type == 'plugins_query_logger' %}
non_existant_parameter = "arbitrary_value"
{% endif %}

[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
{% if invalid_config_type == 'plugins_table_access' %}
non_existant_parameter = "arbitrary_value"
{% endif %}

[plugins.intercept]
enabled = true
{% if invalid_config_type == 'plugins_intercept' %}
non_existant_parameter = "arbitrary_value"
{% endif %}

[plugins.intercept.queries.0]

query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
{% if invalid_config_type == 'plugins_intercept_queries' %}
non_existant_parameter = "arbitrary_value"
{% endif %}

[general]
host = "0.0.0.0"
port = 6433
admin_username = "pgcat"
admin_password = "pgcat"
{% if invalid_config_type == 'general' %}
non_existant_parameter = "arbitrary_value"
{% endif %}

[pools.pgml]
{% if invalid_config_type == 'pools' %}
non_existant_parameter = "arbitrary_value"
{% endif %}

[pools.pgml.users.0]
username = "simple_user"
password = "simple_user"
pool_size = 10
min_pool_size = 1
pool_mode = "transaction"
{% if invalid_config_type == 'user' %}
non_existant_parameter = "arbitrary_value"
{% endif %}

[pools.pgml.shards.0]
servers = [
["127.0.0.1", 5432, "primary"]
]
database = "some_db"
{% if invalid_config_type == 'pool' %}
non_existant_parameter = "arbitrary_value"
{% endif %}
"""

parameters = [
'global', 'general', 'user', 'pool', 'plugins', 'plugins_prewarmer', 'plugins_query_logger', 'plugins_table_access', 'plugins_intercept', 'plugins_intercept_queries', 'pools']
@pytest.mark.parametrize("invalid_config_type", parameters)
def test_negative(invalid_config_type: str):
rtemplate = Environment(loader=BaseLoader).from_string(template_config)
data = rtemplate.render(invalid_config_type=invalid_config_type)

print(data)
tmp = tempfile.NamedTemporaryFile()
with open(tmp.name, 'w') as f:
f.write(data)

process = subprocess.Popen(["./target/debug/pgcat", tmp.name], shell=False)
time.sleep(3)
poll = process.poll()
try:
assert poll is not None
except AssertionError as e:
process.kill()
process.wait()
raise