diff --git a/agent/agents/mongodb/internal/profiler/profiler_test.go b/agent/agents/mongodb/internal/profiler/profiler_test.go index 304d069ada..3d36d276dd 100644 --- a/agent/agents/mongodb/internal/profiler/profiler_test.go +++ b/agent/agents/mongodb/internal/profiler/profiler_test.go @@ -222,7 +222,9 @@ func testProfiler(t *testing.T, url string) { Query: findBucket.Common.Example, } - ex := actions.NewMongoDBExplainAction(id, 5*time.Second, params, os.TempDir()) + ex, err := actions.NewMongoDBExplainAction(id, 5*time.Second, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), ex.Timeout()) defer cancel() res, err := ex.Run(ctx) diff --git a/agent/client/client.go b/agent/client/client.go index b79eb8e261..8abc6b25af 100644 --- a/agent/client/client.go +++ b/agent/client/client.go @@ -454,9 +454,10 @@ func (c *Client) handleStartActionRequest(p *agentpb.StartActionRequest) error { cfg := c.cfg.Get() var action actions.Action + var err error switch params := p.Params.(type) { case *agentpb.StartActionRequest_MysqlExplainParams: - action = actions.NewMySQLExplainAction(p.ActionId, timeout, params.MysqlExplainParams) + action, err = actions.NewMySQLExplainAction(p.ActionId, timeout, params.MysqlExplainParams) case *agentpb.StartActionRequest_MysqlShowCreateTableParams: action = actions.NewMySQLShowCreateTableAction(p.ActionId, timeout, params.MysqlShowCreateTableParams) @@ -468,13 +469,13 @@ func (c *Client) handleStartActionRequest(p *agentpb.StartActionRequest) error { action = actions.NewMySQLShowIndexAction(p.ActionId, timeout, params.MysqlShowIndexParams) case *agentpb.StartActionRequest_PostgresqlShowCreateTableParams: - action = actions.NewPostgreSQLShowCreateTableAction(p.ActionId, timeout, params.PostgresqlShowCreateTableParams, cfg.Paths.TempDir) + action, err = actions.NewPostgreSQLShowCreateTableAction(p.ActionId, timeout, params.PostgresqlShowCreateTableParams, cfg.Paths.TempDir) case *agentpb.StartActionRequest_PostgresqlShowIndexParams: - action = actions.NewPostgreSQLShowIndexAction(p.ActionId, timeout, params.PostgresqlShowIndexParams, cfg.Paths.TempDir) + action, err = actions.NewPostgreSQLShowIndexAction(p.ActionId, timeout, params.PostgresqlShowIndexParams, cfg.Paths.TempDir) case *agentpb.StartActionRequest_MongodbExplainParams: - action = actions.NewMongoDBExplainAction(p.ActionId, timeout, params.MongodbExplainParams, cfg.Paths.TempDir) + action, err = actions.NewMongoDBExplainAction(p.ActionId, timeout, params.MongodbExplainParams, cfg.Paths.TempDir) case *agentpb.StartActionRequest_MysqlQueryShowParams: action = actions.NewMySQLQueryShowAction(p.ActionId, timeout, params.MysqlQueryShowParams) @@ -483,13 +484,13 @@ func (c *Client) handleStartActionRequest(p *agentpb.StartActionRequest) error { action = actions.NewMySQLQuerySelectAction(p.ActionId, timeout, params.MysqlQuerySelectParams) case *agentpb.StartActionRequest_PostgresqlQueryShowParams: - action = actions.NewPostgreSQLQueryShowAction(p.ActionId, timeout, params.PostgresqlQueryShowParams, cfg.Paths.TempDir) + action, err = actions.NewPostgreSQLQueryShowAction(p.ActionId, timeout, params.PostgresqlQueryShowParams, cfg.Paths.TempDir) case *agentpb.StartActionRequest_PostgresqlQuerySelectParams: - action = actions.NewPostgreSQLQuerySelectAction(p.ActionId, timeout, params.PostgresqlQuerySelectParams, cfg.Paths.TempDir) + action, err = actions.NewPostgreSQLQuerySelectAction(p.ActionId, timeout, params.PostgresqlQuerySelectParams, cfg.Paths.TempDir) case *agentpb.StartActionRequest_MongodbQueryGetparameterParams: - action = actions.NewMongoDBQueryAdmincommandAction( + action, err = actions.NewMongoDBQueryAdmincommandAction( p.ActionId, timeout, params.MongodbQueryGetparameterParams.Dsn, @@ -499,7 +500,7 @@ func (c *Client) handleStartActionRequest(p *agentpb.StartActionRequest) error { cfg.Paths.TempDir) case *agentpb.StartActionRequest_MongodbQueryBuildinfoParams: - action = actions.NewMongoDBQueryAdmincommandAction( + action, err = actions.NewMongoDBQueryAdmincommandAction( p.ActionId, timeout, params.MongodbQueryBuildinfoParams.Dsn, @@ -509,7 +510,7 @@ func (c *Client) handleStartActionRequest(p *agentpb.StartActionRequest) error { cfg.Paths.TempDir) case *agentpb.StartActionRequest_MongodbQueryGetcmdlineoptsParams: - action = actions.NewMongoDBQueryAdmincommandAction( + action, err = actions.NewMongoDBQueryAdmincommandAction( p.ActionId, timeout, params.MongodbQueryGetcmdlineoptsParams.Dsn, @@ -519,7 +520,7 @@ func (c *Client) handleStartActionRequest(p *agentpb.StartActionRequest) error { cfg.Paths.TempDir) case *agentpb.StartActionRequest_MongodbQueryReplsetgetstatusParams: - action = actions.NewMongoDBQueryAdmincommandAction( + action, err = actions.NewMongoDBQueryAdmincommandAction( p.ActionId, timeout, params.MongodbQueryReplsetgetstatusParams.Dsn, @@ -529,7 +530,7 @@ func (c *Client) handleStartActionRequest(p *agentpb.StartActionRequest) error { cfg.Paths.TempDir) case *agentpb.StartActionRequest_MongodbQueryGetdiagnosticdataParams: - action = actions.NewMongoDBQueryAdmincommandAction( + action, err = actions.NewMongoDBQueryAdmincommandAction( p.ActionId, timeout, params.MongodbQueryGetdiagnosticdataParams.Dsn, @@ -565,6 +566,10 @@ func (c *Client) handleStartActionRequest(p *agentpb.StartActionRequest) error { return errors.Wrapf(agenterrors.ErrInvalidArgument, "invalid action type request: %T", params) } + if err != nil { + return errors.Wrap(err, "failed to create action") + } + return c.runner.StartAction(action) } @@ -645,7 +650,7 @@ func (c *Client) handleStartJobRequest(p *agentpb.StartJobRequest) error { return errors.WithStack(err) } - job, err = jobs.NewMongoDBBackupJob(p.JobId, timeout, j.MongodbBackup.Name, &dsn, locationConfig, + job, err = jobs.NewMongoDBBackupJob(p.JobId, timeout, j.MongodbBackup.Name, dsn, locationConfig, j.MongodbBackup.EnablePitr, j.MongodbBackup.DataModel, j.MongodbBackup.Folder) if err != nil { return err @@ -678,7 +683,7 @@ func (c *Client) handleStartJobRequest(p *agentpb.StartJobRequest) error { } job = jobs.NewMongoDBRestoreJob(p.JobId, timeout, j.MongodbRestoreBackup.Name, - j.MongodbRestoreBackup.PitrTimestamp.AsTime(), &dsn, locationConfig, + j.MongodbRestoreBackup.PitrTimestamp.AsTime(), dsn, locationConfig, c.supervisor, j.MongodbRestoreBackup.Folder, j.MongodbRestoreBackup.PbmMetadata.Name) default: return errors.Errorf("unknown job type: %T", j) diff --git a/agent/client/client_test.go b/agent/client/client_test.go index ce987348f8..64fdfa8abf 100644 --- a/agent/client/client_test.go +++ b/agent/client/client_test.go @@ -163,7 +163,7 @@ func TestClient(t *testing.T) { s.On("AgentsList").Return([]*agentlocalpb.AgentInfo{}) s.On("ClearChangesChannel").Return() - r := runner.New(cfgStorage.Get().RunnerCapacity) + r := runner.New(cfgStorage.Get().RunnerCapacity, cfgStorage.Get().RunnerMaxConnectionsPerService) client := New(cfgStorage, &s, r, nil, nil, nil, connectionuptime.NewService(time.Hour), nil) err := client.Run(context.Background()) assert.NoError(t, err) @@ -281,7 +281,7 @@ func TestUnexpectedActionType(t *testing.T) { s.On("AgentsList").Return([]*agentlocalpb.AgentInfo{}) s.On("ClearChangesChannel").Return() - r := runner.New(cfgStorage.Get().RunnerCapacity) + r := runner.New(cfgStorage.Get().RunnerCapacity, cfgStorage.Get().RunnerMaxConnectionsPerService) client := New(cfgStorage, s, r, nil, nil, nil, connectionuptime.NewService(time.Hour), nil) err := client.Run(context.Background()) assert.NoError(t, err) diff --git a/agent/commands/run.go b/agent/commands/run.go index c62fefc664..a2c11d5517 100644 --- a/agent/commands/run.go +++ b/agent/commands/run.go @@ -71,7 +71,7 @@ func Run() { supervisor := supervisor.NewSupervisor(ctx, v, configStorage) connectionChecker := connectionchecker.New(configStorage) serviceInfoBroker := serviceinfobroker.New(configStorage) - r := runner.New(cfg.RunnerCapacity) + r := runner.New(cfg.RunnerCapacity, cfg.RunnerMaxConnectionsPerService) client := client.New(configStorage, supervisor, r, connectionChecker, v, serviceInfoBroker, prepareConnectionService(ctx, cfg), logStore) localServer := agentlocal.NewServer(configStorage, supervisor, client, configFilepath, logStore) diff --git a/agent/config/config.go b/agent/config/config.go index c2ef25e624..8809800733 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -146,10 +146,11 @@ type Setup struct { type Config struct { //nolint:musttag // no config file there - ID string `yaml:"id"` - ListenAddress string `yaml:"listen-address"` - ListenPort uint16 `yaml:"listen-port"` - RunnerCapacity uint16 `yaml:"runner-capacity,omitempty"` + ID string `yaml:"id"` + ListenAddress string `yaml:"listen-address"` + ListenPort uint16 `yaml:"listen-port"` + RunnerCapacity uint16 `yaml:"runner-capacity,omitempty"` + RunnerMaxConnectionsPerService uint16 `yaml:"runner-max-connections-per-service,omitempty"` Server Server `yaml:"server"` Paths Paths `yaml:"paths"` @@ -352,6 +353,8 @@ func Application(cfg *Config) (*kingpin.Application, *string) { Envar("PMM_AGENT_LISTEN_PORT").Uint16Var(&cfg.ListenPort) app.Flag("runner-capacity", "Agent internal actions/jobs runner capacity [PMM_AGENT_RUNNER_CAPACITY]"). Envar("PMM_AGENT_RUNNER_CAPACITY").Uint16Var(&cfg.RunnerCapacity) + app.Flag("runner-max-connections-per-service", "Agent internal action/job runner connection limit per DB instance"). + Envar("PMM_AGENT_RUNNER_MAX_CONNECTIONS_PER_SERVICE").Uint16Var(&cfg.RunnerMaxConnectionsPerService) app.Flag("server-address", "PMM Server address [PMM_AGENT_SERVER_ADDRESS]"). Envar("PMM_AGENT_SERVER_ADDRESS").PlaceHolder("").StringVar(&cfg.Server.Address) diff --git a/agent/runner/actions/action.go b/agent/runner/actions/action.go index 5a9625f9c8..e8dd57d5d3 100644 --- a/agent/runner/actions/action.go +++ b/agent/runner/actions/action.go @@ -29,6 +29,8 @@ type Action interface { Type() string // Timeout returns Job timeout. Timeout() time.Duration + // DSN returns Data Source Name required for the Action. + DSN() string // Run runs an Action and returns output and error. Run(ctx context.Context) ([]byte, error) diff --git a/agent/runner/actions/mongodb_explain_action.go b/agent/runner/actions/mongodb_explain_action.go index 944572ee59..373b04df9f 100644 --- a/agent/runner/actions/mongodb_explain_action.go +++ b/agent/runner/actions/mongodb_explain_action.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "path/filepath" - "strings" "time" "github.com/percona/percona-toolkit/src/go/mongolib/proto" @@ -31,23 +30,30 @@ import ( "github.com/percona/pmm/api/agentpb" ) +const mongoDBExplainActionType = "mongodb-explain" + type mongodbExplainAction struct { id string timeout time.Duration params *agentpb.StartActionRequest_MongoDBExplainParams - tempDir string + dsn string } var errCannotExplain = fmt.Errorf("cannot explain this type of query") // NewMongoDBExplainAction creates a MongoDB EXPLAIN query Action. -func NewMongoDBExplainAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_MongoDBExplainParams, tempDir string) Action { +func NewMongoDBExplainAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_MongoDBExplainParams, tempDir string) (Action, error) { + dsn, err := templates.RenderDSN(params.Dsn, params.TextFiles, filepath.Join(tempDir, mongoDBExplainActionType, id)) + if err != nil { + return nil, errors.WithStack(err) + } + return &mongodbExplainAction{ id: id, timeout: timeout, params: params, - tempDir: tempDir, - } + dsn: dsn, + }, nil } // ID returns an Action ID. @@ -62,17 +68,17 @@ func (a *mongodbExplainAction) Timeout() time.Duration { // Type returns an Action type. func (a *mongodbExplainAction) Type() string { - return "mongodb-explain" + return mongoDBExplainActionType +} + +// DSN returns the DSN for the Action. +func (a *mongodbExplainAction) DSN() string { + return a.dsn } // Run runs an action and returns output and error. func (a *mongodbExplainAction) Run(ctx context.Context) ([]byte, error) { - dsn, err := templates.RenderDSN(a.params.Dsn, a.params.TextFiles, filepath.Join(a.tempDir, strings.ToLower(a.Type()), a.id)) - if err != nil { - return nil, errors.WithStack(err) - } - - opts, err := mongo_fix.ClientOptionsForDSN(dsn) + opts, err := mongo_fix.ClientOptionsForDSN(a.dsn) if err != nil { return nil, errors.WithStack(err) } diff --git a/agent/runner/actions/mongodb_explain_action_test.go b/agent/runner/actions/mongodb_explain_action_test.go index 1ac5fe2764..8e927b71ec 100644 --- a/agent/runner/actions/mongodb_explain_action_test.go +++ b/agent/runner/actions/mongodb_explain_action_test.go @@ -52,7 +52,9 @@ func TestMongoDBExplain(t *testing.T) { Query: `{"ns":"test.coll","op":"query","query":{"k":{"$lte":{"$numberInt":"1"}}}}`, } - ex := NewMongoDBExplainAction(id, 0, params, os.TempDir()) + ex, err := NewMongoDBExplainAction(id, 0, params, os.TempDir()) + require.NoError(t, err) + res, err := ex.Run(ctx) assert.Nil(t, err) @@ -130,7 +132,9 @@ func TestNewMongoDBExplain(t *testing.T) { Query: string(query), } - ex := NewMongoDBExplainAction(id, 0, params, os.TempDir()) + ex, err := NewMongoDBExplainAction(id, 0, params, os.TempDir()) + require.NoError(t, err) + res, err := ex.Run(ctx) assert.NoError(t, err) diff --git a/agent/runner/actions/mongodb_query_admincommand_action.go b/agent/runner/actions/mongodb_query_admincommand_action.go index 2ded20cf5e..b490069542 100644 --- a/agent/runner/actions/mongodb_query_admincommand_action.go +++ b/agent/runner/actions/mongodb_query_admincommand_action.go @@ -17,7 +17,6 @@ package actions import ( "context" "path/filepath" - "strings" "time" "github.com/pkg/errors" @@ -29,27 +28,38 @@ import ( "github.com/percona/pmm/api/agentpb" ) +const mongoDBQueryAdminCommandActionType = "mongodb-query-admincommand" + type mongodbQueryAdmincommandAction struct { id string timeout time.Duration dsn string - files *agentpb.TextFiles command string arg interface{} - tempDir string } // NewMongoDBQueryAdmincommandAction creates a MongoDB adminCommand query action. -func NewMongoDBQueryAdmincommandAction(id string, timeout time.Duration, dsn string, files *agentpb.TextFiles, command string, arg interface{}, tempDir string) Action { +func NewMongoDBQueryAdmincommandAction( + id string, + timeout time.Duration, + dsn string, + files *agentpb.TextFiles, + command string, + arg interface{}, + tempDir string, +) (Action, error) { + dsn, err := templates.RenderDSN(dsn, files, filepath.Join(tempDir, mongoDBQueryAdminCommandActionType, id)) + if err != nil { + return nil, errors.WithStack(err) + } + return &mongodbQueryAdmincommandAction{ id: id, timeout: timeout, dsn: dsn, - files: files, command: command, arg: arg, - tempDir: tempDir, - } + }, nil } // ID returns an action ID. @@ -64,17 +74,17 @@ func (a *mongodbQueryAdmincommandAction) Timeout() time.Duration { // Type returns an action type. func (a *mongodbQueryAdmincommandAction) Type() string { - return "mongodb-query-admincommand" + return mongoDBQueryAdminCommandActionType +} + +// DSN returns a DSN for the Action. +func (a *mongodbQueryAdmincommandAction) DSN() string { + return a.dsn } // Run runs an action and returns output and error. func (a *mongodbQueryAdmincommandAction) Run(ctx context.Context) ([]byte, error) { - dsn, err := templates.RenderDSN(a.dsn, a.files, filepath.Join(a.tempDir, strings.ToLower(a.Type()), a.id)) - if err != nil { - return nil, errors.WithStack(err) - } - - opts, err := mongo_fix.ClientOptionsForDSN(dsn) + opts, err := mongo_fix.ClientOptionsForDSN(a.dsn) if err != nil { return nil, errors.WithStack(err) } diff --git a/agent/runner/actions/mongodb_query_admincommand_action_test.go b/agent/runner/actions/mongodb_query_admincommand_action_test.go index 6318594b5a..ed03fa1cce 100644 --- a/agent/runner/actions/mongodb_query_admincommand_action_test.go +++ b/agent/runner/actions/mongodb_query_admincommand_action_test.go @@ -175,7 +175,9 @@ func TestMongoDBActionsReplWithSSL(t *testing.T) { func runAction(t *testing.T, id string, timeout time.Duration, dsn string, files *agentpb.TextFiles, command string, arg interface{}, tempDir string) []byte { //nolint:unparam t.Helper() - a := NewMongoDBQueryAdmincommandAction(id, timeout, dsn, files, command, arg, tempDir) + a, err := NewMongoDBQueryAdmincommandAction(id, timeout, dsn, files, command, arg, tempDir) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() b, err := a.Run(ctx) @@ -227,7 +229,9 @@ func replSetGetStatusAssertionsReplicated(t *testing.T, b []byte) { //nolint:the } func replSetGetStatusAssertionsStandalone(t *testing.T, id string, timeout time.Duration, dsn string, files *agentpb.TextFiles, command string, arg interface{}, tempDir string) { //nolint:thelper - a := NewMongoDBQueryAdmincommandAction(id, timeout, dsn, files, command, arg, tempDir) + a, err := NewMongoDBQueryAdmincommandAction(id, timeout, dsn, files, command, arg, tempDir) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() b, err := a.Run(ctx) diff --git a/agent/runner/actions/mysql_explain_action.go b/agent/runner/actions/mysql_explain_action.go index 6eb29cb795..7fda097e5f 100644 --- a/agent/runner/actions/mysql_explain_action.go +++ b/agent/runner/actions/mysql_explain_action.go @@ -54,12 +54,27 @@ var errCannotEncodeExplainResponse = errors.New("cannot JSON encode the explain // NewMySQLExplainAction creates MySQL Explain Action. // This is an Action that can run `EXPLAIN` command on MySQL service with given DSN. -func NewMySQLExplainAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_MySQLExplainParams) Action { +func NewMySQLExplainAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_MySQLExplainParams) (Action, error) { + if params.Query == "" { + return nil, errors.New("Query to EXPLAIN is empty") + } + + // You cant run Explain on trimmed queries. + if strings.HasSuffix(params.Query, "...") { + return nil, errors.New("EXPLAIN failed because the query exceeded max length and got trimmed. Set max-query-length to a larger value.") //nolint:revive + } + + // Explain is supported only for DML queries. + // https://dev.mysql.com/doc/refman/8.0/en/using-explain.html + if !isDMLQuery(params.Query) { + return nil, errors.New("EXPLAIN functionality is supported only for DML queries - SELECT, INSERT, UPDATE, DELETE and REPLACE.") //nolint:revive + } + return &mysqlExplainAction{ id: id, timeout: timeout, params: params, - } + }, nil } // ID returns an Action ID. @@ -77,23 +92,13 @@ func (a *mysqlExplainAction) Type() string { return "mysql-explain" } +// DSN returns a DSN for the Action. +func (a *mysqlExplainAction) DSN() string { + return a.params.Dsn +} + // Run runs an Action and returns output and error. func (a *mysqlExplainAction) Run(ctx context.Context) ([]byte, error) { - if a.params.Query == "" { - return nil, errors.New("Query to EXPLAIN is empty") - } - - // You cant run Explain on trimmed queries. - if strings.HasSuffix(a.params.Query, "...") { - return nil, errors.New("EXPLAIN failed because the query was too long and trimmed. Set max-query-length to a larger value.") //nolint:revive - } - - // Explain is supported only for DML queries. - // https://dev.mysql.com/doc/refman/8.0/en/using-explain.html - if !isDMLQuery(a.params.Query) { - return nil, errors.New("Functionality EXPLAIN is supported only for DML queries (SELECT, INSERT, UPDATE, DELETE, REPLACE)") - } - a.params.Query = queryparser.GetMySQLFingerprintFromExplainFingerprint(a.params.Query) // query has a copy of the original params.Query field if the query is a SELECT or the equivalent diff --git a/agent/runner/actions/mysql_explain_action_test.go b/agent/runner/actions/mysql_explain_action_test.go index b8e58f5b9a..ecc628fbe4 100644 --- a/agent/runner/actions/mysql_explain_action_test.go +++ b/agent/runner/actions/mysql_explain_action_test.go @@ -52,7 +52,9 @@ func TestMySQLExplain(t *testing.T) { Query: query, OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_DEFAULT, } - a := NewMySQLExplainAction("", time.Second, params) + a, err := NewMySQLExplainAction("", time.Second, params) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) defer cancel() @@ -79,7 +81,9 @@ func TestMySQLExplain(t *testing.T) { Query: query, OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_JSON, } - a := NewMySQLExplainAction("", time.Second, params) + a, err := NewMySQLExplainAction("", time.Second, params) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) defer cancel() @@ -125,7 +129,9 @@ func TestMySQLExplain(t *testing.T) { Query: query, OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_TRADITIONAL_JSON, } - a := NewMySQLExplainAction("", time.Second, params) + a, err := NewMySQLExplainAction("", time.Second, params) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) defer cancel() @@ -167,13 +173,9 @@ func TestMySQLExplain(t *testing.T) { Dsn: "pmm-agent:pmm-agent-wrong-password@tcp(127.0.0.1:3306)/world", OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_DEFAULT, } - a := NewMySQLExplainAction("", time.Second, params) - ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) - defer cancel() - - _, err := a.Run(ctx) - require.Error(t, err) - assert.Regexp(t, `Query to EXPLAIN is empty`, err.Error()) + a, err := NewMySQLExplainAction("", time.Second, params) + assert.ErrorContains(t, err, `Query to EXPLAIN is empty`) + assert.Nil(t, a) }) t.Run("DML Query Insert", func(t *testing.T) { @@ -184,7 +186,9 @@ func TestMySQLExplain(t *testing.T) { Query: `INSERT INTO city (Name) VALUES ('Rosario')`, OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_DEFAULT, } - a := NewMySQLExplainAction("", time.Second, params) + a, err := NewMySQLExplainAction("", time.Second, params) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) defer cancel() @@ -205,12 +209,9 @@ func TestMySQLExplain(t *testing.T) { Query: `INSERT INTO city (Name)...`, OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_DEFAULT, } - a := NewMySQLExplainAction("", time.Second, params) - ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) - defer cancel() - - _, err := a.Run(ctx) - require.Error(t, err, "EXPLAIN failed because the query was too long and trimmed. Set max-query-length to a larger value.") + a, err := NewMySQLExplainAction("", time.Second, params) + assert.ErrorContains(t, err, "EXPLAIN failed because the query exceeded max length and got trimmed. Set max-query-length to a larger value.") + assert.Nil(t, a) }) t.Run("LittleBobbyTables", func(t *testing.T) { @@ -233,11 +234,13 @@ func TestMySQLExplain(t *testing.T) { Query: `SELECT 1; DROP TABLE city; --`, OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_DEFAULT, } - a := NewMySQLExplainAction("", time.Second, params) + a, err := NewMySQLExplainAction("", time.Second, params) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) defer cancel() - _, err := a.Run(ctx) + _, err = a.Run(ctx) expected := "Error 1064 \\(42000\\): You have an error in your SQL syntax; check the manual that corresponds " + "to your (MySQL|MariaDB) server version for the right syntax to use near 'DROP TABLE city; --' at line 1" require.Error(t, err) @@ -253,11 +256,13 @@ func TestMySQLExplain(t *testing.T) { Query: `DELETE FROM city`, OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_DEFAULT, } - a := NewMySQLExplainAction("", time.Second, params) + a, err := NewMySQLExplainAction("", time.Second, params) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) defer cancel() - _, err := a.Run(ctx) + _, err = a.Run(ctx) require.NoError(t, err) checkCity(t) }) @@ -305,11 +310,13 @@ func TestMySQLExplain(t *testing.T) { Query: `select * from (select cleanup()) as testclean;`, OutputFormat: agentpb.MysqlExplainOutputFormat_MYSQL_EXPLAIN_OUTPUT_FORMAT_DEFAULT, } - a := NewMySQLExplainAction("", time.Second, params) + a, err := NewMySQLExplainAction("", time.Second, params) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), a.Timeout()) defer cancel() - _, err := a.Run(ctx) + _, err = a.Run(ctx) require.NoError(t, err) check(t) }) diff --git a/agent/runner/actions/mysql_query_select_action.go b/agent/runner/actions/mysql_query_select_action.go index 39833e5b11..9db5feed2d 100644 --- a/agent/runner/actions/mysql_query_select_action.go +++ b/agent/runner/actions/mysql_query_select_action.go @@ -55,6 +55,11 @@ func (a *mysqlQuerySelectAction) Type() string { return "mysql-query-select" } +// DSN returns a DSN for the Action. +func (a *mysqlQuerySelectAction) DSN() string { + return a.params.Dsn +} + // Run runs an Action and returns output and error. func (a *mysqlQuerySelectAction) Run(ctx context.Context) ([]byte, error) { db, err := mysqlOpen(a.params.Dsn, a.params.TlsFiles) diff --git a/agent/runner/actions/mysql_query_show_action.go b/agent/runner/actions/mysql_query_show_action.go index 4081e1898e..cb14cbf309 100644 --- a/agent/runner/actions/mysql_query_show_action.go +++ b/agent/runner/actions/mysql_query_show_action.go @@ -55,6 +55,11 @@ func (a *mysqlQueryShowAction) Type() string { return "mysql-query-show" } +// DSN returns a DSN for the Action. +func (a *mysqlQueryShowAction) DSN() string { + return a.params.Dsn +} + // Run runs an Action and returns output and error. func (a *mysqlQueryShowAction) Run(ctx context.Context) ([]byte, error) { db, err := mysqlOpen(a.params.Dsn, a.params.TlsFiles) diff --git a/agent/runner/actions/mysql_show_create_table_action.go b/agent/runner/actions/mysql_show_create_table_action.go index 7ca727f2c6..1d98f24c0f 100644 --- a/agent/runner/actions/mysql_show_create_table_action.go +++ b/agent/runner/actions/mysql_show_create_table_action.go @@ -53,6 +53,11 @@ func (a *mysqlShowCreateTableAction) Type() string { return "mysql-show-create-table" } +// DSN returns a DSN for the Action. +func (a *mysqlShowCreateTableAction) DSN() string { + return a.params.Dsn +} + // Run runs an Action and returns output and error. func (a *mysqlShowCreateTableAction) Run(ctx context.Context) ([]byte, error) { db, err := mysqlOpen(a.params.Dsn, a.params.TlsFiles) diff --git a/agent/runner/actions/mysql_show_index_action.go b/agent/runner/actions/mysql_show_index_action.go index 1a2bc070fe..112c6814b2 100644 --- a/agent/runner/actions/mysql_show_index_action.go +++ b/agent/runner/actions/mysql_show_index_action.go @@ -54,6 +54,11 @@ func (a *mysqlShowIndexAction) Type() string { return "mysql-show-index" } +// DSN returns a DSN for the Action. +func (a *mysqlShowIndexAction) DSN() string { + return a.params.Dsn +} + // Run runs an Action and returns output and error. func (a *mysqlShowIndexAction) Run(ctx context.Context) ([]byte, error) { db, err := mysqlOpen(a.params.Dsn, a.params.TlsFiles) diff --git a/agent/runner/actions/mysql_show_table_status_action.go b/agent/runner/actions/mysql_show_table_status_action.go index 1d7c5331ae..9056a03600 100644 --- a/agent/runner/actions/mysql_show_table_status_action.go +++ b/agent/runner/actions/mysql_show_table_status_action.go @@ -58,6 +58,11 @@ func (a *mysqlShowTableStatusAction) Type() string { return "mysql-show-table-status" } +// DSN returns a DSN for the Action. +func (a *mysqlShowTableStatusAction) DSN() string { + return a.params.Dsn +} + // Run runs an Action and returns output and error. func (a *mysqlShowTableStatusAction) Run(ctx context.Context) ([]byte, error) { db, err := mysqlOpen(a.params.Dsn, a.params.TlsFiles) diff --git a/agent/runner/actions/postgresql_query_select_action.go b/agent/runner/actions/postgresql_query_select_action.go index 4e21de6582..196cb6ab02 100644 --- a/agent/runner/actions/postgresql_query_select_action.go +++ b/agent/runner/actions/postgresql_query_select_action.go @@ -29,21 +29,38 @@ import ( "github.com/percona/pmm/utils/sqlrows" ) +const postgreSQLQuerySelectActionType = "postgresql-query-select" + type postgresqlQuerySelectAction struct { id string timeout time.Duration params *agentpb.StartActionRequest_PostgreSQLQuerySelectParams - tempDir string + dsn string } // NewPostgreSQLQuerySelectAction creates PostgreSQL SELECT query Action. -func NewPostgreSQLQuerySelectAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_PostgreSQLQuerySelectParams, tempDir string) Action { +func NewPostgreSQLQuerySelectAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_PostgreSQLQuerySelectParams, tempDir string) (Action, error) { + // A very basic check that there is a single SELECT query. It has oblivious false positives (`SELECT ';'`), + // but PostgreSQL query lexical structure (https://www.postgresql.org/docs/current/sql-syntax-lexical.html) + // does not allow false negatives. + // If we decide to improve it, we could use our existing query parser from pg_stat_statement agent, + // or use a simple hand-made parser similar to + // https://github.com/mc2soft/pq-types/blob/ada769d4011a027a5385b9c4e47976fe327350a6/string_array.go#L82-L116 + if strings.Contains(params.Query, ";") { + return nil, errors.New("query contains ';'") + } + + dsn, err := templates.RenderDSN(params.Dsn, params.TlsFiles, filepath.Join(tempDir, postgreSQLQuerySelectActionType, id)) + if err != nil { + return nil, errors.WithStack(err) + } + return &postgresqlQuerySelectAction{ id: id, timeout: timeout, params: params, - tempDir: tempDir, - } + dsn: dsn, + }, nil } // ID returns an Action ID. @@ -58,33 +75,23 @@ func (a *postgresqlQuerySelectAction) Timeout() time.Duration { // Type returns an Action type. func (a *postgresqlQuerySelectAction) Type() string { - return "postgresql-query-select" + return postgreSQLQuerySelectActionType +} + +// DSN returns the DSN for the Action. +func (a *postgresqlQuerySelectAction) DSN() string { + return a.dsn } // Run runs an Action and returns output and error. func (a *postgresqlQuerySelectAction) Run(ctx context.Context) ([]byte, error) { - dsn, err := templates.RenderDSN(a.params.Dsn, a.params.TlsFiles, filepath.Join(a.tempDir, strings.ToLower(a.Type()), a.id)) - if err != nil { - return nil, errors.WithStack(err) - } - - connector, err := pq.NewConnector(dsn) + connector, err := pq.NewConnector(a.dsn) if err != nil { return nil, errors.WithStack(err) } db := sql.OpenDB(connector) defer db.Close() //nolint:errcheck - // A very basic check that there is a single SELECT query. It has oblivious false positives (`SELECT ';'`), - // but PostgreSQL query lexical structure (https://www.postgresql.org/docs/current/sql-syntax-lexical.html) - // does not allow false negatives. - // If we decide to improve it, we could use our existing query parser from pg_stat_statement agent, - // or use a simple hand-made parser similar to - // https://github.com/mc2soft/pq-types/blob/ada769d4011a027a5385b9c4e47976fe327350a6/string_array.go#L82-L116 - if strings.Contains(a.params.Query, ";") { - return nil, errors.New("query contains ';'") - } - rows, err := db.QueryContext(ctx, "SELECT /* pmm-agent */ "+a.params.Query) //nolint:gosec if err != nil { return nil, errors.WithStack(err) diff --git a/agent/runner/actions/postgresql_query_select_action_test.go b/agent/runner/actions/postgresql_query_select_action_test.go index 3f52823a17..c62f559c82 100644 --- a/agent/runner/actions/postgresql_query_select_action_test.go +++ b/agent/runner/actions/postgresql_query_select_action_test.go @@ -41,7 +41,9 @@ func TestPostgreSQLQuerySelect(t *testing.T) { Dsn: dsn, Query: "* FROM pg_extension", } - a := NewPostgreSQLQuerySelectAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLQuerySelectAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -74,7 +76,9 @@ func TestPostgreSQLQuerySelect(t *testing.T) { Dsn: dsn, Query: `'\x0001feff'::bytea AS bytes`, } - a := NewPostgreSQLQuerySelectAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLQuerySelectAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -98,17 +102,8 @@ func TestPostgreSQLQuerySelect(t *testing.T) { Dsn: dsn, Query: "* FROM city; DROP TABLE city CASCADE; --", } - a := NewPostgreSQLQuerySelectAction("", 0, params, os.TempDir()) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - b, err := a.Run(ctx) + a, err := NewPostgreSQLQuerySelectAction("", 0, params, os.TempDir()) assert.EqualError(t, err, "query contains ';'") - assert.Nil(t, b) - - var count int - err = db.QueryRow("SELECT COUNT(*) FROM city").Scan(&count) - require.NoError(t, err) - assert.Equal(t, 4079, count) + assert.Nil(t, a) }) } diff --git a/agent/runner/actions/postgresql_query_show_action.go b/agent/runner/actions/postgresql_query_show_action.go index c4ab4e2fdb..8e365e160d 100644 --- a/agent/runner/actions/postgresql_query_show_action.go +++ b/agent/runner/actions/postgresql_query_show_action.go @@ -18,7 +18,6 @@ import ( "context" "database/sql" "path/filepath" - "strings" "time" "github.com/lib/pq" @@ -29,21 +28,26 @@ import ( "github.com/percona/pmm/utils/sqlrows" ) +const postgreSQLQueryShowActionType = "postgresql-query-show" + type postgresqlQueryShowAction struct { id string timeout time.Duration - params *agentpb.StartActionRequest_PostgreSQLQueryShowParams - tempDir string + dsn string } // NewPostgreSQLQueryShowAction creates PostgreSQL SHOW query Action. -func NewPostgreSQLQueryShowAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_PostgreSQLQueryShowParams, tempDir string) Action { +func NewPostgreSQLQueryShowAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_PostgreSQLQueryShowParams, tempDir string) (Action, error) { + dsn, err := templates.RenderDSN(params.Dsn, params.TlsFiles, filepath.Join(tempDir, postgreSQLQueryShowActionType, id)) + if err != nil { + return nil, errors.WithStack(err) + } + return &postgresqlQueryShowAction{ id: id, timeout: timeout, - params: params, - tempDir: tempDir, - } + dsn: dsn, + }, nil } // ID returns an Action ID. @@ -58,17 +62,17 @@ func (a *postgresqlQueryShowAction) Timeout() time.Duration { // Type returns an Action type. func (a *postgresqlQueryShowAction) Type() string { - return "postgresql-query-show" + return postgreSQLQueryShowActionType +} + +// DSN returns a DSN for the Action. +func (a *postgresqlQueryShowAction) DSN() string { + return a.dsn } // Run runs an Action and returns output and error. func (a *postgresqlQueryShowAction) Run(ctx context.Context) ([]byte, error) { - dsn, err := templates.RenderDSN(a.params.Dsn, a.params.TlsFiles, filepath.Join(a.tempDir, strings.ToLower(a.Type()), a.id)) - if err != nil { - return nil, errors.WithStack(err) - } - - connector, err := pq.NewConnector(dsn) + connector, err := pq.NewConnector(a.dsn) if err != nil { return nil, errors.WithStack(err) } diff --git a/agent/runner/actions/postgresql_query_show_action_test.go b/agent/runner/actions/postgresql_query_show_action_test.go index afd1196b87..8b0fd5d5f8 100644 --- a/agent/runner/actions/postgresql_query_show_action_test.go +++ b/agent/runner/actions/postgresql_query_show_action_test.go @@ -40,7 +40,9 @@ func TestPostgreSQLQueryShow(t *testing.T) { params := &agentpb.StartActionRequest_PostgreSQLQueryShowParams{ Dsn: dsn, } - a := NewPostgreSQLQueryShowAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLQueryShowAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() diff --git a/agent/runner/actions/postgresql_show_create_table_action.go b/agent/runner/actions/postgresql_show_create_table_action.go index ad15129432..7decddedc3 100644 --- a/agent/runner/actions/postgresql_show_create_table_action.go +++ b/agent/runner/actions/postgresql_show_create_table_action.go @@ -34,6 +34,8 @@ import ( "github.com/percona/pmm/api/agentpb" ) +const postgreSQLShowCreateTableActionType = "postgresql-show-create-table" + type columnInfo struct { Attname string FormatType string @@ -66,18 +68,28 @@ type postgresqlShowCreateTableAction struct { id string timeout time.Duration params *agentpb.StartActionRequest_PostgreSQLShowCreateTableParams - tempDir string + dsn string } // NewPostgreSQLShowCreateTableAction creates PostgreSQL SHOW CREATE TABLE Action. // This is an Action that can run `\d+ table` command analog on PostgreSQL service with given DSN. -func NewPostgreSQLShowCreateTableAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_PostgreSQLShowCreateTableParams, tempDir string) Action { +func NewPostgreSQLShowCreateTableAction( + id string, + timeout time.Duration, + params *agentpb.StartActionRequest_PostgreSQLShowCreateTableParams, + tempDir string, +) (Action, error) { + dsn, err := templates.RenderDSN(params.Dsn, params.TlsFiles, filepath.Join(tempDir, postgreSQLShowCreateTableActionType, id)) + if err != nil { + return nil, errors.WithStack(err) + } + return &postgresqlShowCreateTableAction{ id: id, timeout: timeout, params: params, - tempDir: tempDir, - } + dsn: dsn, + }, nil } // ID returns an Action ID. @@ -92,17 +104,17 @@ func (a *postgresqlShowCreateTableAction) Timeout() time.Duration { // Type returns an Action type. func (a *postgresqlShowCreateTableAction) Type() string { - return "postgresql-show-create-table" + return postgreSQLShowCreateTableActionType +} + +// DSN returns a DSN for the Action. +func (a *postgresqlShowCreateTableAction) DSN() string { + return a.dsn } // Run runs an Action and returns output and error. func (a *postgresqlShowCreateTableAction) Run(ctx context.Context) ([]byte, error) { - dsn, err := templates.RenderDSN(a.params.Dsn, a.params.TlsFiles, filepath.Join(a.tempDir, strings.ToLower(a.Type()), a.id)) - if err != nil { - return nil, errors.WithStack(err) - } - - connector, err := pq.NewConnector(dsn) + connector, err := pq.NewConnector(a.dsn) if err != nil { return nil, errors.WithStack(err) } diff --git a/agent/runner/actions/postgresql_show_create_table_action_test.go b/agent/runner/actions/postgresql_show_create_table_action_test.go index 176b70ba9e..af5317e4a5 100644 --- a/agent/runner/actions/postgresql_show_create_table_action_test.go +++ b/agent/runner/actions/postgresql_show_create_table_action_test.go @@ -40,7 +40,9 @@ func TestPostgreSQLShowCreateTable(t *testing.T) { Dsn: dsn, Table: "public.country", } - a := NewPostgreSQLShowCreateTableAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLShowCreateTableAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -83,7 +85,9 @@ Referenced by: Dsn: dsn, Table: "city", } - a := NewPostgreSQLShowCreateTableAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLShowCreateTableAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -112,7 +116,9 @@ Referenced by: Dsn: dsn, Table: "countrylanguage", } - a := NewPostgreSQLShowCreateTableAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLShowCreateTableAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -140,11 +146,13 @@ Foreign-key constraints: Dsn: dsn, Table: `city; DROP TABLE city; --`, } - a := NewPostgreSQLShowCreateTableAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLShowCreateTableAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _, err := a.Run(ctx) + _, err = a.Run(ctx) expected := "Table not found: sql: no rows in result set" assert.EqualError(t, err, expected) diff --git a/agent/runner/actions/postgresql_show_index_action.go b/agent/runner/actions/postgresql_show_index_action.go index b91d823d96..d93eb8d493 100644 --- a/agent/runner/actions/postgresql_show_index_action.go +++ b/agent/runner/actions/postgresql_show_index_action.go @@ -30,22 +30,29 @@ import ( "github.com/percona/pmm/utils/sqlrows" ) +const postgreSQLShowIndexActionType = "postgresql-show-index" + type postgresqlShowIndexAction struct { id string timeout time.Duration params *agentpb.StartActionRequest_PostgreSQLShowIndexParams - tempDir string + dsn string } // NewPostgreSQLShowIndexAction creates PostgreSQL SHOW INDEX Action. // This is an Action that can run `SHOW INDEX` command on PostgreSQL service with given DSN. -func NewPostgreSQLShowIndexAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_PostgreSQLShowIndexParams, tempDir string) Action { +func NewPostgreSQLShowIndexAction(id string, timeout time.Duration, params *agentpb.StartActionRequest_PostgreSQLShowIndexParams, tempDir string) (Action, error) { + dsn, err := templates.RenderDSN(params.Dsn, params.TlsFiles, filepath.Join(tempDir, postgreSQLShowIndexActionType, id)) + if err != nil { + return nil, errors.WithStack(err) + } + return &postgresqlShowIndexAction{ id: id, timeout: timeout, params: params, - tempDir: tempDir, - } + dsn: dsn, + }, nil } // ID returns an Action ID. @@ -60,17 +67,17 @@ func (a *postgresqlShowIndexAction) Timeout() time.Duration { // Type returns an Action type. func (a *postgresqlShowIndexAction) Type() string { - return "postgresql-show-index" + return postgreSQLShowIndexActionType +} + +// DSN returns a DSN for the Action. +func (a *postgresqlShowIndexAction) DSN() string { + return a.dsn } // Run runs an Action and returns output and error. func (a *postgresqlShowIndexAction) Run(ctx context.Context) ([]byte, error) { - dsn, err := templates.RenderDSN(a.params.Dsn, a.params.TlsFiles, filepath.Join(a.tempDir, strings.ToLower(a.Type()), a.id)) - if err != nil { - return nil, errors.WithStack(err) - } - - connector, err := pq.NewConnector(dsn) + connector, err := pq.NewConnector(a.dsn) if err != nil { return nil, errors.WithStack(err) } diff --git a/agent/runner/actions/postgresql_show_index_action_test.go b/agent/runner/actions/postgresql_show_index_action_test.go index 38516865d1..a80c397cc0 100644 --- a/agent/runner/actions/postgresql_show_index_action_test.go +++ b/agent/runner/actions/postgresql_show_index_action_test.go @@ -42,7 +42,9 @@ func TestPostgreSQLShowIndex(t *testing.T) { Dsn: dsn, Table: "city", } - a := NewPostgreSQLShowIndexAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLShowIndexAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -68,7 +70,9 @@ func TestPostgreSQLShowIndex(t *testing.T) { Dsn: dsn, Table: "public.city", } - a := NewPostgreSQLShowIndexAction("", 0, params, os.TempDir()) + a, err := NewPostgreSQLShowIndexAction("", 0, params, os.TempDir()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() diff --git a/agent/runner/actions/process_action.go b/agent/runner/actions/process_action.go index 2b0d16baa9..5d36242483 100644 --- a/agent/runner/actions/process_action.go +++ b/agent/runner/actions/process_action.go @@ -59,6 +59,11 @@ func (a *processAction) Type() string { return a.command } +// DSN returns a DSN for the Action. +func (a *processAction) DSN() string { + return "" // no DSN for process action +} + // Run runs an Action and returns output and error. func (a *processAction) Run(ctx context.Context) ([]byte, error) { cmd := exec.CommandContext(ctx, a.command, a.arg...) //nolint:gosec diff --git a/agent/runner/actions/pt_mysql_summary_action.go b/agent/runner/actions/pt_mysql_summary_action.go index bce06cc268..43a3e0c18c 100644 --- a/agent/runner/actions/pt_mysql_summary_action.go +++ b/agent/runner/actions/pt_mysql_summary_action.go @@ -17,6 +17,7 @@ package actions import ( "context" "fmt" + "net" "os" "os/exec" "strconv" @@ -63,6 +64,15 @@ func (a *ptMySQLSummaryAction) Type() string { return a.command } +// DSN returns a DSN for the Action. +func (a *ptMySQLSummaryAction) DSN() string { + if a.params.Socket != "" { + return a.params.Socket + } + + return net.JoinHostPort(a.params.Host, strconv.FormatUint(uint64(a.params.Port), 10)) +} + // Run runs an Action and returns output and error. func (a *ptMySQLSummaryAction) Run(ctx context.Context) ([]byte, error) { cmd := exec.CommandContext(ctx, a.command, a.ListFromMySQLParams()...) //nolint:gosec diff --git a/agent/runner/jobs/common.go b/agent/runner/jobs/common.go index f110bd309d..f9802723f0 100644 --- a/agent/runner/jobs/common.go +++ b/agent/runner/jobs/common.go @@ -31,22 +31,22 @@ type DBConnConfig struct { Socket string } -func createDBURL(dbConfig DBConnConfig) *url.URL { +func (c *DBConnConfig) createDBURL() *url.URL { var host string switch { - case dbConfig.Address != "": - if dbConfig.Port > 0 { - host = net.JoinHostPort(dbConfig.Address, strconv.Itoa(dbConfig.Port)) + case c.Address != "": + if c.Port > 0 { + host = net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) } else { - host = dbConfig.Address + host = c.Address } - case dbConfig.Socket != "": - host = dbConfig.Socket + case c.Socket != "": + host = c.Socket } var user *url.Userinfo - if dbConfig.User != "" { - user = url.UserPassword(dbConfig.User, dbConfig.Password) + if c.User != "" { + user = url.UserPassword(c.User, c.Password) } return &url.URL{ diff --git a/agent/runner/jobs/job.go b/agent/runner/jobs/job.go index 6d4c2eabdf..81e76582b1 100644 --- a/agent/runner/jobs/job.go +++ b/agent/runner/jobs/job.go @@ -44,6 +44,8 @@ type Job interface { Type() JobType // Timeout returns Job timeout. Timeout() time.Duration + // DSN returns Data Source Name required for the Action. + DSN() string // Run starts Job execution. Run(ctx context.Context, send Send) error } diff --git a/agent/runner/jobs/mongodb_backup_job.go b/agent/runner/jobs/mongodb_backup_job.go index 7c321844b7..afa68f08f1 100644 --- a/agent/runner/jobs/mongodb_backup_job.go +++ b/agent/runner/jobs/mongodb_backup_job.go @@ -44,7 +44,7 @@ type MongoDBBackupJob struct { timeout time.Duration l logrus.FieldLogger name string - dbURL *string + dsn string locationConfig BackupLocationConfig pitr bool dataModel backuppb.DataModel @@ -57,7 +57,7 @@ func NewMongoDBBackupJob( id string, timeout time.Duration, name string, - dbConfig *string, + dsn string, locationConfig BackupLocationConfig, pitr bool, dataModel backuppb.DataModel, @@ -75,11 +75,11 @@ func NewMongoDBBackupJob( timeout: timeout, l: logrus.WithFields(logrus.Fields{"id": id, "type": "mongodb_backup", "name": name}), name: name, - dbURL: dbConfig, + dsn: dsn, locationConfig: locationConfig, pitr: pitr, dataModel: dataModel, - jobLogger: newPbmJobLogger(id, pbmBackupJob, dbConfig), + jobLogger: newPbmJobLogger(id, pbmBackupJob, dsn), folder: folder, }, nil } @@ -99,6 +99,11 @@ func (j *MongoDBBackupJob) Timeout() time.Duration { return j.timeout } +// DSN returns DSN for the Job. +func (j *MongoDBBackupJob) DSN() string { + return j.dsn +} + // Run starts Job execution. func (j *MongoDBBackupJob) Run(ctx context.Context, send Send) error { defer j.jobLogger.sendLog(send, "", true) @@ -121,14 +126,14 @@ func (j *MongoDBBackupJob) Run(ctx context.Context, send Send) error { configParams := pbmConfigParams{ configFilePath: confFile, forceResync: false, - dbURL: j.dbURL, + dsn: j.dsn, } if err := pbmConfigure(ctx, j.l, configParams); err != nil { return errors.Wrap(err, "failed to configure pbm") } rCtx, cancel := context.WithTimeout(ctx, resyncTimeout) - if err := waitForPBMNoRunningOperations(rCtx, j.l, j.dbURL); err != nil { + if err := waitForPBMNoRunningOperations(rCtx, j.l, j.dsn); err != nil { cancel() return errors.Wrap(err, "failed to wait configuration completion") } @@ -148,17 +153,17 @@ func (j *MongoDBBackupJob) Run(ctx context.Context, send Send) error { } }() - if err := waitForPBMBackup(ctx, j.l, j.dbURL, pbmBackupOut.Name); err != nil { + if err := waitForPBMBackup(ctx, j.l, j.dsn, pbmBackupOut.Name); err != nil { j.jobLogger.sendLog(send, err.Error(), false) return errors.Wrap(err, "failed to wait backup completion") } - sharded, err := isShardedCluster(ctx, j.dbURL) + sharded, err := isShardedCluster(ctx, j.dsn) if err != nil { return err } - backupTimestamp, err := pbmGetSnapshotTimestamp(ctx, j.l, j.dbURL, pbmBackupOut.Name) + backupTimestamp, err := pbmGetSnapshotTimestamp(ctx, j.l, j.dsn, pbmBackupOut.Name) if err != nil { return err } @@ -211,7 +216,7 @@ func (j *MongoDBBackupJob) startBackup(ctx context.Context) (*pbmBackup, error) return nil, errors.Errorf("'%s' is not a supported data model for backups", j.dataModel) } - if err := execPBMCommand(ctx, j.dbURL, &result, pbmArgs...); err != nil { + if err := execPBMCommand(ctx, j.dsn, &result, pbmArgs...); err != nil { return nil, err } diff --git a/agent/runner/jobs/mongodb_backup_job_test.go b/agent/runner/jobs/mongodb_backup_job_test.go index 669b3f6580..2748ec5736 100644 --- a/agent/runner/jobs/mongodb_backup_job_test.go +++ b/agent/runner/jobs/mongodb_backup_job_test.go @@ -72,7 +72,7 @@ func TestCreateDBURL(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, test.url, createDBURL(test.dbConfig).String()) + assert.Equal(t, test.url, test.dbConfig.createDBURL().String()) }) } } @@ -80,7 +80,6 @@ func TestCreateDBURL(t *testing.T) { func TestNewMongoDBBackupJob(t *testing.T) { t.Parallel() testJobDuration := 1 * time.Second - var dbConfig string tests := []struct { name string @@ -115,7 +114,7 @@ func TestNewMongoDBBackupJob(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, err := NewMongoDBBackupJob(t.Name(), testJobDuration, t.Name(), &dbConfig, BackupLocationConfig{}, tc.pitr, tc.dataModel, "artifact_folder") + _, err := NewMongoDBBackupJob(t.Name(), testJobDuration, t.Name(), "", BackupLocationConfig{}, tc.pitr, tc.dataModel, "artifact_folder") if tc.errMsg == "" { assert.NoError(t, err) } else { diff --git a/agent/runner/jobs/mongodb_restore_job.go b/agent/runner/jobs/mongodb_restore_job.go index c11ae699d5..2d5e22474f 100644 --- a/agent/runner/jobs/mongodb_restore_job.go +++ b/agent/runner/jobs/mongodb_restore_job.go @@ -42,7 +42,7 @@ type MongoDBRestoreJob struct { l *logrus.Entry name string pitrTimestamp time.Time - dbURL *string + dbURL string locationConfig BackupLocationConfig agentsRestarter agentsRestarter jobLogger *pbmJobLogger @@ -56,7 +56,7 @@ func NewMongoDBRestoreJob( timeout time.Duration, name string, pitrTimestamp time.Time, - dbConfig *string, + dbConfig string, locationConfig BackupLocationConfig, restarter agentsRestarter, folder string, @@ -92,6 +92,11 @@ func (j *MongoDBRestoreJob) Timeout() time.Duration { return j.timeout } +// DSN returns DSN required for the Job. +func (j *MongoDBRestoreJob) DSN() string { + return j.dbURL +} + // Run starts Job execution. func (j *MongoDBRestoreJob) Run(ctx context.Context, send Send) error { defer j.jobLogger.sendLog(send, "", true) @@ -121,7 +126,7 @@ func (j *MongoDBRestoreJob) Run(ctx context.Context, send Send) error { configParams := pbmConfigParams{ configFilePath: confFile, forceResync: true, - dbURL: j.dbURL, + dsn: j.dbURL, } if err := pbmConfigure(ctx, j.l, configParams); err != nil { return errors.Wrap(err, "failed to configure pbm") diff --git a/agent/runner/jobs/mysql_backup_job.go b/agent/runner/jobs/mysql_backup_job.go index fd8b613aee..b27e40e5c8 100644 --- a/agent/runner/jobs/mysql_backup_job.go +++ b/agent/runner/jobs/mysql_backup_job.go @@ -76,6 +76,11 @@ func (j *MySQLBackupJob) Timeout() time.Duration { return j.timeout } +// DSN returns DSN for the Job. +func (j *MySQLBackupJob) DSN() string { + return j.connConf.createDBURL().String() +} + // Run starts Job execution. func (j *MySQLBackupJob) Run(ctx context.Context, send Send) error { if err := j.binariesInstalled(); err != nil { diff --git a/agent/runner/jobs/mysql_restore_job.go b/agent/runner/jobs/mysql_restore_job.go index d34e20b802..297ebd55ca 100644 --- a/agent/runner/jobs/mysql_restore_job.go +++ b/agent/runner/jobs/mysql_restore_job.go @@ -82,6 +82,11 @@ func (j *MySQLRestoreJob) Timeout() time.Duration { return j.timeout } +// DSN returns DSN for the Job. +func (j *MySQLRestoreJob) DSN() string { + return "" // not used for MySQL restore +} + // Run executes backup restore steps. func (j *MySQLRestoreJob) Run(ctx context.Context, send Send) error { if j.locationConfig.S3Config == nil { diff --git a/agent/runner/jobs/pbm_helpers.go b/agent/runner/jobs/pbm_helpers.go index 4161c5816a..859c6827ce 100644 --- a/agent/runner/jobs/pbm_helpers.go +++ b/agent/runner/jobs/pbm_helpers.go @@ -173,15 +173,15 @@ type pbmError struct { type pbmConfigParams struct { configFilePath string forceResync bool - dbURL *string + dsn string } -func execPBMCommand(ctx context.Context, dbURL *string, to interface{}, args ...string) error { +func execPBMCommand(ctx context.Context, dsn string, to interface{}, args ...string) error { nCtx, cancel := context.WithTimeout(ctx, cmdTimeout) defer cancel() - args = append(args, "--out=json", "--mongodb-uri="+*dbURL) - cmd := exec.CommandContext(nCtx, pbmBin, args...) // #nosec G204 + args = append(args, "--out=json", "--mongodb-uri="+dsn) //nolint:goconst + cmd := exec.CommandContext(nCtx, pbmBin, args...) // #nosec G204 b, err := cmd.Output() if err != nil { @@ -198,17 +198,17 @@ func execPBMCommand(ctx context.Context, dbURL *string, to interface{}, args ... return json.Unmarshal(b, to) } -func retrieveLogs(ctx context.Context, dbURL *string, event string) ([]pbmLogEntry, error) { +func retrieveLogs(ctx context.Context, dsn string, event string) ([]pbmLogEntry, error) { var logs []pbmLogEntry - if err := execPBMCommand(ctx, dbURL, &logs, "logs", "--event="+event, "--tail=0"); err != nil { + if err := execPBMCommand(ctx, dsn, &logs, "logs", "--event="+event, "--tail=0"); err != nil { return nil, err } return logs, nil } -func waitForPBMNoRunningOperations(ctx context.Context, l logrus.FieldLogger, dbURL *string) error { +func waitForPBMNoRunningOperations(ctx context.Context, l logrus.FieldLogger, dsn string) error { l.Info("Waiting for no running pbm operations.") ticker := time.NewTicker(statusCheckInterval) @@ -217,7 +217,7 @@ func waitForPBMNoRunningOperations(ctx context.Context, l logrus.FieldLogger, db for { select { case <-ticker.C: - status, err := getPBMStatus(ctx, dbURL) + status, err := getPBMStatus(ctx, dsn) if err != nil { return err } @@ -230,8 +230,8 @@ func waitForPBMNoRunningOperations(ctx context.Context, l logrus.FieldLogger, db } } -func isShardedCluster(ctx context.Context, dbURL *string) (bool, error) { - status, err := getPBMStatus(ctx, dbURL) +func isShardedCluster(ctx context.Context, dsn string) (bool, error) { + status, err := getPBMStatus(ctx, dsn) if err != nil { return false, err } @@ -243,15 +243,15 @@ func isShardedCluster(ctx context.Context, dbURL *string) (bool, error) { return false, nil } -func getPBMStatus(ctx context.Context, dbURL *string) (*pbmStatus, error) { +func getPBMStatus(ctx context.Context, dsn string) (*pbmStatus, error) { var status pbmStatus - if err := execPBMCommand(ctx, dbURL, &status, "status"); err != nil { + if err := execPBMCommand(ctx, dsn, &status, "status"); err != nil { return nil, errors.Wrap(err, "pbm status error") } return &status, nil } -func waitForPBMBackup(ctx context.Context, l logrus.FieldLogger, dbURL *string, name string) error { +func waitForPBMBackup(ctx context.Context, l logrus.FieldLogger, dsn string, name string) error { l.Infof("waiting for pbm backup: %s", name) ticker := time.NewTicker(statusCheckInterval) defer ticker.Stop() @@ -262,7 +262,7 @@ func waitForPBMBackup(ctx context.Context, l logrus.FieldLogger, dbURL *string, select { case <-ticker.C: var info describeInfo - err := execPBMCommand(ctx, dbURL, &info, "describe-backup", name) + err := execPBMCommand(ctx, dsn, &info, "describe-backup", name) if err != nil { // for the first couple of seconds after backup process starts describe-backup command may return this error if (strings.HasSuffix(err.Error(), "no such file") || @@ -311,7 +311,7 @@ func findPITRRestore(list []pbmListRestore, restoreInfoPITRTime int64, startedAt return nil } -func findPITRRestoreName(ctx context.Context, dbURL *string, restoreInfo *pbmRestore) (string, error) { +func findPITRRestoreName(ctx context.Context, dsn string, restoreInfo *pbmRestore) (string, error) { restoreInfoPITRTime, err := time.Parse("2006-01-02T15:04:05", restoreInfo.PITR) if err != nil { return "", err @@ -326,7 +326,7 @@ func findPITRRestoreName(ctx context.Context, dbURL *string, restoreInfo *pbmRes case <-ticker.C: checks++ var list []pbmListRestore - if err := execPBMCommand(ctx, dbURL, &list, "list", "--restore"); err != nil { + if err := execPBMCommand(ctx, dsn, &list, "list", "--restore"); err != nil { return "", errors.Wrapf(err, "pbm status error") } entry := findPITRRestore(list, restoreInfoPITRTime.Unix(), restoreInfo.StartedAt) @@ -344,14 +344,14 @@ func findPITRRestoreName(ctx context.Context, dbURL *string, restoreInfo *pbmRes } } -func waitForPBMRestore(ctx context.Context, l logrus.FieldLogger, dbURL *string, restoreInfo *pbmRestore, backupType, confFile string) error { +func waitForPBMRestore(ctx context.Context, l logrus.FieldLogger, dsn string, restoreInfo *pbmRestore, backupType, confFile string) error { l.Infof("Detecting restore name") var name string var err error // @TODO Do like this until https://jira.percona.com/browse/PBM-723 is not done. if restoreInfo.PITR != "" { // TODO add more checks of PBM responses. - name, err = findPITRRestoreName(ctx, dbURL, restoreInfo) + name, err = findPITRRestoreName(ctx, dsn, restoreInfo) if err != nil { return err } @@ -370,9 +370,9 @@ func waitForPBMRestore(ctx context.Context, l logrus.FieldLogger, dbURL *string, case <-ticker.C: var info describeInfo if backupType == "physical" { - err = execPBMCommand(ctx, dbURL, &info, "describe-restore", name, "--config="+confFile) + err = execPBMCommand(ctx, dsn, &info, "describe-restore", name, "--config="+confFile) } else { - err = execPBMCommand(ctx, dbURL, &info, "describe-restore", name) + err = execPBMCommand(ctx, dsn, &info, "describe-restore", name) } if err != nil { if maxRetryCount > 0 { @@ -412,7 +412,7 @@ func pbmConfigure(ctx context.Context, l logrus.FieldLogger, params pbmConfigPar args := []string{ "config", "--out=json", - "--mongodb-uri=" + *params.dbURL, + "--mongodb-uri=" + params.dsn, "--file=" + params.configFilePath, } @@ -425,7 +425,7 @@ func pbmConfigure(ctx context.Context, l logrus.FieldLogger, params pbmConfigPar args := []string{ "config", "--out=json", - "--mongodb-uri=" + *params.dbURL, + "--mongodb-uri=" + params.dsn, "--force-resync", } output, err := exec.CommandContext(nCtx, pbmBin, args...).CombinedOutput() //nolint:gosec @@ -549,8 +549,8 @@ func groupPartlyDoneErrors(info describeInfo) error { } // pbmGetSnapshotTimestamp returns time the backup restores target db to. -func pbmGetSnapshotTimestamp(ctx context.Context, l logrus.FieldLogger, dbURL *string, backupName string) (*time.Time, error) { - snapshots, err := getSnapshots(ctx, l, dbURL) +func pbmGetSnapshotTimestamp(ctx context.Context, l logrus.FieldLogger, dsn string, backupName string) (*time.Time, error) { + snapshots, err := getSnapshots(ctx, l, dsn) if err != nil { return nil, err } @@ -565,7 +565,7 @@ func pbmGetSnapshotTimestamp(ctx context.Context, l logrus.FieldLogger, dbURL *s } // getSnapshots returns all PBM snapshots found in configured location. -func getSnapshots(ctx context.Context, l logrus.FieldLogger, dbURL *string) ([]pbmSnapshot, error) { +func getSnapshots(ctx context.Context, l logrus.FieldLogger, dsn string) ([]pbmSnapshot, error) { // Sometimes PBM returns empty list of snapshots, that's why we're trying to get them several times. ticker := time.NewTicker(listCheckInterval) defer ticker.Stop() @@ -575,7 +575,7 @@ func getSnapshots(ctx context.Context, l logrus.FieldLogger, dbURL *string) ([]p select { case <-ticker.C: checks++ - status, err := getPBMStatus(ctx, dbURL) + status, err := getPBMStatus(ctx, dsn) if err != nil { return nil, err } diff --git a/agent/runner/jobs/pbm_job_logger.go b/agent/runner/jobs/pbm_job_logger.go index e1a1511018..32b70d58ba 100644 --- a/agent/runner/jobs/pbm_job_logger.go +++ b/agent/runner/jobs/pbm_job_logger.go @@ -34,13 +34,13 @@ const ( ) type pbmJobLogger struct { - dbURL *string + dbURL string jobID string jobType pbmJob logChunkID uint32 } -func newPbmJobLogger(jobID string, jobType pbmJob, mongoURL *string) *pbmJobLogger { +func newPbmJobLogger(jobID string, jobType pbmJob, mongoURL string) *pbmJobLogger { return &pbmJobLogger{ jobID: jobID, jobType: jobType, diff --git a/agent/runner/runner.go b/agent/runner/runner.go index 8b89afe18b..f21d2ff93d 100644 --- a/agent/runner/runner.go +++ b/agent/runner/runner.go @@ -17,8 +17,12 @@ package runner import ( "context" + "crypto/sha256" + "encoding/base64" + "net/url" "runtime/pprof" "sync" + "sync/atomic" "time" "github.com/pkg/errors" @@ -35,7 +39,8 @@ import ( const ( bufferSize = 256 defaultActionTimeout = 10 * time.Second // default timeout for compatibility with an older server - defaultCapacity = 32 + defaultTotalCapacity = 32 // how many concurrent operations are allowed in total + defaultTokenCapacity = 2 // how many concurrent operations on a single resource (usually DB instance) are allowed ) // Runner executes jobs and actions. @@ -48,33 +53,125 @@ type Runner struct { actionsMessages chan agentpb.AgentRequestPayload jobsMessages chan agentpb.AgentResponsePayload - sem *semaphore.Weighted - wg sync.WaitGroup + wg sync.WaitGroup - rw sync.RWMutex - rCancel map[string]context.CancelFunc + // cancels holds cancel functions for running actions and jobs. + cancelsM sync.RWMutex + cancels map[string]context.CancelFunc + + // running holds IDs of running actions and jobs. + runningM sync.RWMutex + running map[string]struct{} + + // gSem is a global semaphore to limit total number of concurrent operations performed by the runner. + gSem *semaphore.Weighted + + // tokenCapacity is a limit of concurrent operations on a single resource, usually database instance. + tokenCapacity uint16 + + // lSems is a map of local semaphores to limit number of concurrent operations on a single database instance. + // Key is a token which is typically is a hash of DSN(only host:port pair), value is a semaphore. + lSemsM sync.Mutex + lSems map[string]*entry +} + +// entry stores local semaphore and its counter. +type entry struct { + count atomic.Int32 + sem *semaphore.Weighted } // New creates new runner. If capacity is 0 then default value is used. -func New(capacity uint16) *Runner { +func New(totalCapacity, tokenCapacity uint16) *Runner { l := logrus.WithField("component", "runner") - if capacity == 0 { - capacity = defaultCapacity + if totalCapacity == 0 { + totalCapacity = defaultTotalCapacity + } + + if tokenCapacity == 0 { + tokenCapacity = defaultTokenCapacity } - l.Infof("Runner capacity set to %d.", capacity) + l.Infof("Runner capacity set to %d, token capacity set to %d", totalCapacity, tokenCapacity) return &Runner{ l: l, actions: make(chan actions.Action, bufferSize), jobs: make(chan jobs.Job, bufferSize), - sem: semaphore.NewWeighted(int64(capacity)), - rCancel: make(map[string]context.CancelFunc), + cancels: make(map[string]context.CancelFunc), + running: make(map[string]struct{}), jobsMessages: make(chan agentpb.AgentResponsePayload), actionsMessages: make(chan agentpb.AgentRequestPayload), + tokenCapacity: tokenCapacity, + gSem: semaphore.NewWeighted(int64(totalCapacity)), + lSems: make(map[string]*entry), + } +} + +// acquire acquires global and local semaphores. +func (r *Runner) acquire(ctx context.Context, token string) error { + if err := r.acquireL(ctx, token); err != nil { + return err + } + + if err := r.gSem.Acquire(ctx, 1); err != nil { + r.releaseL(token) + return err + } + + return nil +} + +// release releases global and local semaphores. +func (r *Runner) release(token string) { + r.gSem.Release(1) + + r.releaseL(token) +} + +// acquireL acquires local semaphore for given token. +func (r *Runner) acquireL(ctx context.Context, token string) error { + if token != "" { + r.lSemsM.Lock() + + e, ok := r.lSems[token] + if !ok { + e = &entry{sem: semaphore.NewWeighted(int64(r.tokenCapacity))} + r.lSems[token] = e + } + r.lSemsM.Unlock() + + if err := e.sem.Acquire(ctx, 1); err != nil { + return err + } + e.count.Add(1) + } + + return nil +} + +// releaseL releases local semaphore for given token. +func (r *Runner) releaseL(token string) { + if token != "" { + r.lSemsM.Lock() + + if e, ok := r.lSems[token]; ok { + e.sem.Release(1) + if v := e.count.Add(-1); v == 0 { + delete(r.lSems, token) + } + } + r.lSemsM.Unlock() } } +// lSemsLen returns number of local semaphores in use. +func (r *Runner) lSemsLen() int { + r.lSemsM.Lock() + defer r.lSemsM.Unlock() + return len(r.lSems) +} + // Run starts jobs execution loop. It reads jobs from the channel and starts them in separate goroutines. func (r *Runner) Run(ctx context.Context) { for { @@ -124,65 +221,100 @@ func (r *Runner) ActionsResults() <-chan agentpb.AgentRequestPayload { // Stop stops running Action or Job. func (r *Runner) Stop(id string) { - r.rw.RLock() - defer r.rw.RUnlock() + r.cancelsM.RLock() + defer r.cancelsM.RUnlock() - // Job removes itself from rCancel map. So here we only invoke cancel. - if cancel, ok := r.rCancel[id]; ok { + // Job removes itself from cancels map. So here we only invoke cancel. + if cancel, ok := r.cancels[id]; ok { cancel() } } // IsRunning returns true if Action or Job with given ID still running. func (r *Runner) IsRunning(id string) bool { - r.rw.RLock() - defer r.rw.RUnlock() - _, ok := r.rCancel[id] + r.runningM.RLock() + defer r.runningM.RUnlock() + _, ok := r.running[id] return ok } +// createTokenFromDSN returns unique database instance id (token) calculated as a hash from host:port part of the DSN. +func createTokenFromDSN(dsn string) (string, error) { + if dsn == "" { + return "", nil + } + u, err := url.Parse(dsn) + if err != nil { + return "", errors.Wrap(err, "failed to parse DSN") + } + + host := u.Host + // If host is empty, use the whole DSN for hash calculation. + // It can give worse granularity, but it's better than nothing. + if host == "" { + host = dsn + } + + h := sha256.New() + h.Write([]byte(host)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)), nil +} + func (r *Runner) handleJob(ctx context.Context, job jobs.Job) { jobID, jobType := job.ID(), job.Type() l := r.l.WithFields(logrus.Fields{"id": jobID, "type": jobType}) - if err := r.sem.Acquire(ctx, 1); err != nil { - l.Errorf("Failed to acquire token for a job: %v", err) - r.sendJobsMessage(&agentpb.JobResult{ - JobId: job.ID(), - Timestamp: timestamppb.Now(), - Result: &agentpb.JobResult_Error_{ - Error: &agentpb.JobResult_Error{ - Message: err.Error(), - }, - }, - }) - return + token, err := createTokenFromDSN(job.DSN()) + if err != nil { + r.l.Warnf("Failed to get token for job: %v", err) } - var nCtx context.Context - var cancel context.CancelFunc - if timeout := job.Timeout(); timeout != 0 { - nCtx, cancel = context.WithTimeout(ctx, timeout) - } else { - nCtx, cancel = context.WithCancel(ctx) - } + ctx, cancel := context.WithCancel(ctx) r.addCancel(jobID, cancel) r.wg.Add(1) run := func(ctx context.Context) { - l.Infof("Job started.") - defer func(start time.Time) { l.WithField("duration", time.Since(start).String()).Info("Job finished.") }(time.Now()) - defer r.sem.Release(1) defer r.wg.Done() defer cancel() defer r.removeCancel(jobID) - err := job.Run(ctx, r.sendJobsMessage) + l.Debug("Acquiring tokens for a job.") + if err := r.acquire(ctx, token); err != nil { + l.Errorf("Failed to acquire token for a job: %v", err) + r.sendJobsMessage(&agentpb.JobResult{ + JobId: job.ID(), + Timestamp: timestamppb.Now(), + Result: &agentpb.JobResult_Error_{ + Error: &agentpb.JobResult_Error{ + Message: err.Error(), + }, + }, + }) + return + } + defer r.release(token) + + var nCtx context.Context + var nCancel context.CancelFunc + if timeout := job.Timeout(); timeout != 0 { + nCtx, nCancel = context.WithTimeout(ctx, timeout) + defer nCancel() + } else { + // If timeout is not provided then use parent context + nCtx = ctx + } + + // Mark job as running. + r.addStarted(jobID) + defer r.removeStarted(jobID) + l.Info("Job started.") + + err := job.Run(nCtx, r.sendJobsMessage) if err != nil { r.sendJobsMessage(&agentpb.JobResult{ JobId: job.ID(), @@ -197,44 +329,56 @@ func (r *Runner) handleJob(ctx context.Context, job jobs.Job) { } } - go pprof.Do(nCtx, pprof.Labels("jobID", jobID, "type", string(jobType)), run) + go pprof.Do(ctx, pprof.Labels("jobID", jobID, "type", string(jobType)), run) } func (r *Runner) handleAction(ctx context.Context, action actions.Action) { actionID, actionType := action.ID(), action.Type() l := r.l.WithFields(logrus.Fields{"id": actionID, "type": actionType}) - if err := r.sem.Acquire(ctx, 1); err != nil { - l.Errorf("Failed to acquire token for an action: %v", err) - r.sendActionsMessage(&agentpb.ActionResultRequest{ - ActionId: actionID, - Done: true, - Error: err.Error(), - }) - return - } - - var timeout time.Duration - if timeout = action.Timeout(); timeout == 0 { - timeout = defaultActionTimeout + instanceID, err := createTokenFromDSN(action.DSN()) + if err != nil { + r.l.Warnf("Failed to get instance ID for action: %v", err) } - nCtx, cancel := context.WithTimeout(ctx, timeout) + ctx, cancel := context.WithCancel(ctx) r.addCancel(actionID, cancel) r.wg.Add(1) - run := func(_ context.Context) { - l.Infof("Action started.") - + run := func(ctx context.Context) { defer func(start time.Time) { l.WithField("duration", time.Since(start).String()).Info("Action finished.") }(time.Now()) - defer r.sem.Release(1) defer r.wg.Done() defer cancel() defer r.removeCancel(actionID) + l.Debug("Acquiring tokens for an action.") + if err := r.acquire(ctx, instanceID); err != nil { + l.Errorf("Failed to acquire token for an action: %v", err) + r.sendActionsMessage(&agentpb.ActionResultRequest{ + ActionId: actionID, + Done: true, + Error: err.Error(), + }) + return + } + defer r.release(instanceID) + + var timeout time.Duration + if timeout = action.Timeout(); timeout == 0 { + timeout = defaultActionTimeout + } + + nCtx, nCancel := context.WithTimeout(ctx, timeout) + defer nCancel() + + // Mark action as running. + r.addStarted(actionID) + defer r.removeStarted(actionID) + l.Infof("Action started.") + output, err := action.Run(nCtx) var errMsg string if err != nil { @@ -249,7 +393,7 @@ func (r *Runner) handleAction(ctx context.Context, action actions.Action) { Error: errMsg, }) } - go pprof.Do(nCtx, pprof.Labels("actionID", actionID, "type", actionType), run) + go pprof.Do(ctx, pprof.Labels("actionID", actionID, "type", actionType), run) } func (r *Runner) sendJobsMessage(payload agentpb.AgentResponsePayload) { @@ -261,13 +405,25 @@ func (r *Runner) sendActionsMessage(payload agentpb.AgentRequestPayload) { } func (r *Runner) addCancel(jobID string, cancel context.CancelFunc) { - r.rw.Lock() - defer r.rw.Unlock() - r.rCancel[jobID] = cancel + r.cancelsM.Lock() + defer r.cancelsM.Unlock() + r.cancels[jobID] = cancel } func (r *Runner) removeCancel(jobID string) { - r.rw.Lock() - defer r.rw.Unlock() - delete(r.rCancel, jobID) + r.cancelsM.Lock() + defer r.cancelsM.Unlock() + delete(r.cancels, jobID) +} + +func (r *Runner) addStarted(actionID string) { + r.runningM.Lock() + defer r.runningM.Unlock() + r.running[actionID] = struct{}{} +} + +func (r *Runner) removeStarted(actionID string) { + r.runningM.Lock() + defer r.runningM.Unlock() + delete(r.running, actionID) } diff --git a/agent/runner/runner_test.go b/agent/runner/runner_test.go index a98f160164..76ba0ad887 100644 --- a/agent/runner/runner_test.go +++ b/agent/runner/runner_test.go @@ -42,7 +42,7 @@ func assertActionResults(t *testing.T, cr *Runner, expected ...*agentpb.ActionRe func TestConcurrentRunnerRun(t *testing.T) { t.Parallel() - cr := New(0) + cr := New(0, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -62,29 +62,33 @@ func TestConcurrentRunnerRun(t *testing.T) { } assertActionResults(t, cr, expected...) cr.wg.Wait() - assert.Empty(t, cr.rCancel) + assert.Empty(t, cr.cancels) } func TestCapacityLimit(t *testing.T) { t.Parallel() - cr := New(2) + cr := New(2, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go cr.Run(ctx) j1 := testJob{id: "test-1", timeout: time.Second} - j2 := testJob{id: "test-2", timeout: 2 * time.Second} + j2 := testJob{id: "test-2", timeout: time.Second} j3 := testJob{id: "test-3", timeout: 2 * time.Second} - j4 := testJob{id: "test-4", timeout: time.Second} + j4 := testJob{id: "test-4", timeout: 2 * time.Second} require.NoError(t, cr.StartJob(j1)) require.NoError(t, cr.StartJob(j2)) + + // Let first and second jobs start + time.Sleep(200 * time.Millisecond) + require.NoError(t, cr.StartJob(j3)) require.NoError(t, cr.StartJob(j4)) - // Let first jobs start - time.Sleep(500 * time.Millisecond) + // Let third and forth jobs to reach semaphores + time.Sleep(300 * time.Millisecond) // First two jobs are started assert.True(t, cr.IsRunning(j1.ID())) @@ -94,23 +98,15 @@ func TestCapacityLimit(t *testing.T) { time.Sleep(time.Second) - // After second first job terminated and third job started - assert.False(t, cr.IsRunning(j1.ID())) - assert.True(t, cr.IsRunning(j2.ID())) - assert.True(t, cr.IsRunning(j3.ID())) - assert.False(t, cr.IsRunning(j4.ID())) - - time.Sleep(time.Second) - // After one more second job terminated and third started assert.False(t, cr.IsRunning(j1.ID())) assert.False(t, cr.IsRunning(j2.ID())) assert.True(t, cr.IsRunning(j3.ID())) assert.True(t, cr.IsRunning(j4.ID())) - time.Sleep(time.Second) + time.Sleep(2 * time.Second) - // After another second all jobs are terminated + // After two seconds all jobs are terminated assert.False(t, cr.IsRunning(j1.ID())) assert.False(t, cr.IsRunning(j2.ID())) assert.False(t, cr.IsRunning(j3.ID())) @@ -121,28 +117,127 @@ func TestDefaultCapacityLimit(t *testing.T) { t.Parallel() // Use default capacity - cr := New(0) + cr := New(0, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go cr.Run(ctx) - totalJobs := 2 * defaultCapacity + totalJobs := 2 * defaultTotalCapacity for i := 0; i < totalJobs; i++ { require.NoError(t, cr.StartJob(testJob{id: fmt.Sprintf("test-%d", i), timeout: time.Second})) } - // Let first jobs start + // Let jobs to start time.Sleep(500 * time.Millisecond) + var running int for i := 0; i < totalJobs; i++ { // Check that running jobs amount is not exceeded default capacity. - assert.Equal(t, i < defaultCapacity, cr.IsRunning(fmt.Sprintf("test-%d", i))) + if cr.IsRunning(fmt.Sprintf("test-%d", i)) { + running++ + } } + + assert.Equal(t, defaultTotalCapacity, running) +} + +func TestPerDBInstanceLimit(t *testing.T) { + t.Parallel() + + cr := New(10, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go cr.Run(ctx) + + db1j1 := testJob{id: "test-1", timeout: time.Second, dsn: "postgresql://db1"} + db1j2 := testJob{id: "test-2", timeout: time.Second, dsn: "postgresql://db1"} + db1j3 := testJob{id: "test-3", timeout: time.Second, dsn: "postgresql://db1"} + db2j1 := testJob{id: "test-4", timeout: time.Second, dsn: "postgresql://db2"} + db2j2 := testJob{id: "test-5", timeout: time.Second, dsn: "postgresql://db2"} + db2j3 := testJob{id: "test-6", timeout: time.Second, dsn: "postgresql://db2"} + + require.NoError(t, cr.StartJob(db1j1)) + require.NoError(t, cr.StartJob(db2j1)) + + // Let jobs to start + time.Sleep(200 * time.Millisecond) + + require.NoError(t, cr.StartJob(db1j2)) + require.NoError(t, cr.StartJob(db2j2)) + require.NoError(t, cr.StartJob(db1j3)) + require.NoError(t, cr.StartJob(db2j3)) + + // Let rest jobs to reach semaphores + time.Sleep(300 * time.Millisecond) + + assert.True(t, cr.IsRunning(db1j1.ID())) + assert.True(t, cr.IsRunning(db2j1.ID())) + assert.False(t, cr.IsRunning(db1j2.ID())) + assert.False(t, cr.IsRunning(db2j2.ID())) + assert.False(t, cr.IsRunning(db1j3.ID())) + assert.False(t, cr.IsRunning(db2j3.ID())) + + // Over time all jobs are terminated + time.Sleep(2 * time.Second) + + assert.False(t, cr.IsRunning(db1j1.ID())) + assert.False(t, cr.IsRunning(db2j1.ID())) + assert.False(t, cr.IsRunning(db1j2.ID())) + assert.False(t, cr.IsRunning(db2j2.ID())) + assert.False(t, cr.IsRunning(db1j3.ID())) + assert.False(t, cr.IsRunning(db2j3.ID())) +} + +func TestDefaultPerDBInstanceLimit(t *testing.T) { + t.Parallel() + + cr := New(10, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go cr.Run(ctx) + + db1j1 := testJob{id: "test-1", timeout: time.Second, dsn: "postgresql://db1"} + db1j2 := testJob{id: "test-2", timeout: time.Second, dsn: "postgresql://db1"} + db1j3 := testJob{id: "test-3", timeout: time.Second, dsn: "postgresql://db1"} + db2j1 := testJob{id: "test-4", timeout: time.Second, dsn: "postgresql://db2"} + db2j2 := testJob{id: "test-5", timeout: time.Second, dsn: "postgresql://db2"} + db2j3 := testJob{id: "test-6", timeout: time.Second, dsn: "postgresql://db2"} + + require.NoError(t, cr.StartJob(db1j1)) + require.NoError(t, cr.StartJob(db2j1)) + require.NoError(t, cr.StartJob(db1j2)) + require.NoError(t, cr.StartJob(db2j2)) + + // Let jobs to start + time.Sleep(200 * time.Millisecond) + + require.NoError(t, cr.StartJob(db1j3)) + require.NoError(t, cr.StartJob(db2j3)) + + // Let rest jobs to reach semaphores + time.Sleep(300 * time.Millisecond) + + assert.True(t, cr.IsRunning(db1j1.ID())) + assert.True(t, cr.IsRunning(db2j1.ID())) + assert.True(t, cr.IsRunning(db1j2.ID())) + assert.True(t, cr.IsRunning(db2j2.ID())) + assert.False(t, cr.IsRunning(db1j3.ID())) + assert.False(t, cr.IsRunning(db2j3.ID())) + + // Over time all jobs are terminated + time.Sleep(2 * time.Second) + + assert.False(t, cr.IsRunning(db1j1.ID())) + assert.False(t, cr.IsRunning(db2j1.ID())) + assert.False(t, cr.IsRunning(db1j2.ID())) + assert.False(t, cr.IsRunning(db2j2.ID())) + assert.False(t, cr.IsRunning(db1j3.ID())) + assert.False(t, cr.IsRunning(db2j3.ID())) } func TestConcurrentRunnerTimeout(t *testing.T) { t.Parallel() - cr := New(0) + cr := New(0, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -163,12 +258,12 @@ func TestConcurrentRunnerTimeout(t *testing.T) { } assertActionResults(t, cr, expected...) cr.wg.Wait() - assert.Empty(t, cr.rCancel) + assert.Empty(t, cr.cancels) } func TestConcurrentRunnerStop(t *testing.T) { t.Parallel() - cr := New(0) + cr := New(0, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -194,12 +289,12 @@ func TestConcurrentRunnerStop(t *testing.T) { } assertActionResults(t, cr, expected...) cr.wg.Wait() - assert.Empty(t, cr.rCancel) + assert.Empty(t, cr.cancels) } func TestConcurrentRunnerCancel(t *testing.T) { t.Parallel() - cr := New(0) + cr := New(0, 0) ctx, cancel := context.WithCancel(context.Background()) go cr.Run(ctx) @@ -231,12 +326,46 @@ func TestConcurrentRunnerCancel(t *testing.T) { assert.Contains(t, []string{"signal: killed", context.Canceled.Error()}, expected[0].(*agentpb.ActionResultRequest).Error) assert.True(t, expected[1].(*agentpb.ActionResultRequest).Done) cr.wg.Wait() - assert.Empty(t, cr.rCancel) + assert.Empty(t, cr.cancels) +} + +func TestSemaphoresReleasing(t *testing.T) { + t.Parallel() + cr := New(1, 1) + err := cr.gSem.Acquire(context.TODO(), 1) // Acquire global semaphore to block all jobs + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go cr.Run(ctx) + + j := testJob{id: "test-1", timeout: time.Second, dsn: "test"} + + require.NoError(t, cr.StartJob(j)) + + // Let job to start + time.Sleep(200 * time.Millisecond) + + // Check that job is started and local semaphore was acquired + assert.Equal(t, cr.lSemsLen(), 1) + + // Check that job is not running, because it's waiting for global semaphore to be acquired + assert.False(t, cr.IsRunning(j.ID())) + + // Cancel context to stop job + cancel() + + // Let job to start and release resources + time.Sleep(200 * time.Millisecond) + + // Check that local samaphore was released + assert.Equal(t, cr.lSemsLen(), 0) } type testJob struct { id string timeout time.Duration + dsn string } func (t testJob) ID() string { @@ -251,6 +380,10 @@ func (t testJob) Timeout() time.Duration { return t.timeout } +func (t testJob) DSN() string { + return t.dsn +} + func (t testJob) Run(ctx context.Context, send jobs.Send) error { //nolint:revive <-ctx.Done() return nil