From 0476b00c70f9cad8fc4f4fdd44ec9c75174c2673 Mon Sep 17 00:00:00 2001 From: Clement Michaud Date: Mon, 22 Feb 2021 14:45:06 +0100 Subject: [PATCH] Spread the request context throughout up to db queries. --- cmd/go-graphkb/main.go | 10 ++--- internal/database/mariadb.go | 55 +++++++++++------------ internal/handlers/handler_update_graph.go | 25 ++++++----- internal/knowledge/graph_updater.go | 22 ++++----- internal/knowledge/graphdb.go | 16 +++---- internal/server/server.go | 8 ++-- 6 files changed, 68 insertions(+), 68 deletions(-) diff --git a/cmd/go-graphkb/main.go b/cmd/go-graphkb/main.go index f001cd1..458bdd3 100644 --- a/cmd/go-graphkb/main.go +++ b/cmd/go-graphkb/main.go @@ -82,7 +82,7 @@ func logLevelParamToSeverity(level string) logrus.Level { case "error": return logrus.ErrorLevel } - logrus.Fatal("Provided level %s is not a valid option") + logrus.Fatal("Provided level %s is not a valid option", level) // This should never be reached but needed by the compiler return logrus.InfoLevel } @@ -112,12 +112,12 @@ func onInit() { } func count(cmd *cobra.Command, args []string) { - countAssets, err := Database.CountAssets() + countAssets, err := Database.CountAssets(context.Background()) if err != nil { logrus.Fatal(err) } - countRelations, err := Database.CountRelations() + countRelations, err := Database.CountRelations(context.Background()) if err != nil { logrus.Fatal(err) } @@ -125,7 +125,7 @@ func count(cmd *cobra.Command, args []string) { } func flush(cmd *cobra.Command, args []string) { - if err := Database.FlushAll(); err != nil { + if err := Database.FlushAll(context.Background()); err != nil { logrus.Fatal(err) } logrus.Info("Successul flush") @@ -147,7 +147,7 @@ func listen(cmd *cobra.Command, args []string) { func read(cmd *cobra.Command, args []string) { g := knowledge.NewGraph() - err := Database.ReadGraph(args[0], g) + err := Database.ReadGraph(context.Background(), args[0], g) if err != nil { logrus.Fatal(err) } diff --git a/internal/database/mariadb.go b/internal/database/mariadb.go index 0484c42..accdf67 100644 --- a/internal/database/mariadb.go +++ b/internal/database/mariadb.go @@ -249,7 +249,7 @@ func hashRelation(relation knowledge.Relation) uint64 { } // InsertAssets insert multiple assets into the graph of the given source -func (m *MariaDB) InsertAssets(source string, assets []knowledge.Asset) error { +func (m *MariaDB) InsertAssets(ctx context.Context, source string, assets []knowledge.Asset) error { sourceID, err := m.resolveSourceID(source) if err != nil { return fmt.Errorf("Unable to resolve source ID of source %s for inserting assets: %v", source, err) @@ -263,7 +263,7 @@ func (m *MariaDB) InsertAssets(source string, assets []knowledge.Asset) error { for _, asset := range assets { h := hashAsset(asset) - _, err = tx.ExecContext(context.Background(), + _, err = tx.ExecContext(ctx, `INSERT INTO assets (id, type, value) VALUES (?, ?, ?)`, h, asset.Type, asset.Key) if err != nil { @@ -275,7 +275,7 @@ func (m *MariaDB) InsertAssets(source string, assets []knowledge.Asset) error { } } - _, err = tx.ExecContext(context.Background(), + _, err = tx.ExecContext(ctx, `INSERT INTO assets_by_source (source_id, asset_id) VALUES (?, ?)`, sourceID, h) if err != nil { if driverErr, ok := err.(*mysql.MySQLError); ok && driverErr.Number == mysqlerr.ER_DUP_ENTRY { @@ -294,7 +294,7 @@ func (m *MariaDB) InsertAssets(source string, assets []knowledge.Asset) error { } // InsertRelations upsert one relation into the graph of the given source -func (m *MariaDB) InsertRelations(source string, relations []knowledge.Relation) error { +func (m *MariaDB) InsertRelations(ctx context.Context, source string, relations []knowledge.Relation) error { sourceID, err := m.resolveSourceID(source) if err != nil { return fmt.Errorf("Unable to resolve source ID of source %s for inserting relations: %v", source, err) @@ -311,7 +311,7 @@ func (m *MariaDB) InsertRelations(source string, relations []knowledge.Relation) aTo := hashAsset(knowledge.Asset(relation.To)) rH := hashRelation(relation) - _, err = tx.ExecContext(context.Background(), + _, err = tx.ExecContext(ctx, "INSERT INTO relations (id, from_id, to_id, type) VALUES (?, ?, ?, ?)", rH, aFrom, aTo, relation.Type) if err != nil { @@ -323,7 +323,7 @@ func (m *MariaDB) InsertRelations(source string, relations []knowledge.Relation) } } - _, err = tx.ExecContext(context.Background(), + _, err = tx.ExecContext(ctx, `INSERT INTO relations_by_source (source_id, relation_id) VALUES (?, ?)`, sourceID, rH) if err != nil { if driverErr, ok := err.(*mysql.MySQLError); ok && driverErr.Number == mysqlerr.ER_DUP_ENTRY { @@ -342,7 +342,7 @@ func (m *MariaDB) InsertRelations(source string, relations []knowledge.Relation) } // RemoveAssets remove one asset from the graph of the given source -func (m *MariaDB) RemoveAssets(source string, assets []knowledge.Asset) error { +func (m *MariaDB) RemoveAssets(ctx context.Context, source string, assets []knowledge.Asset) error { sourceID, err := m.resolveSourceID(source) if err != nil { return fmt.Errorf("Unable to resolve source ID of source %s for removing assets: %v", source, err) @@ -356,7 +356,7 @@ func (m *MariaDB) RemoveAssets(source string, assets []knowledge.Asset) error { for _, asset := range assets { h := hashAsset(asset) - _, err = tx.ExecContext(context.Background(), + _, err = tx.ExecContext(ctx, `DELETE FROM assets_by_source WHERE asset_id = ? AND source_id = ?`, h, sourceID) if err != nil { @@ -364,7 +364,7 @@ func (m *MariaDB) RemoveAssets(source string, assets []knowledge.Asset) error { return fmt.Errorf("Unable to remove binding between asset %v (%d) and source %s: %v", asset, h, source, err) } - _, err = tx.ExecContext(context.Background(), + _, err = tx.ExecContext(ctx, `DELETE FROM assets WHERE id = ? AND NOT EXISTS ( SELECT * FROM assets_by_source WHERE asset_id = ? )`, @@ -383,7 +383,7 @@ func (m *MariaDB) RemoveAssets(source string, assets []knowledge.Asset) error { } // RemoveRelations remove relations from the graph of the given source -func (m *MariaDB) RemoveRelations(source string, relations []knowledge.Relation) error { +func (m *MariaDB) RemoveRelations(ctx context.Context, source string, relations []knowledge.Relation) error { sourceID, err := m.resolveSourceID(source) if err != nil { return fmt.Errorf("Unable to resolve source ID of source %s for removing relations: %v", source, err) @@ -397,7 +397,7 @@ func (m *MariaDB) RemoveRelations(source string, relations []knowledge.Relation) for _, relation := range relations { rH := hashRelation(relation) - _, err = tx.ExecContext(context.Background(), + _, err = tx.ExecContext(ctx, `DELETE FROM relations_by_source WHERE relation_id = ? AND source_id = ?`, rH, sourceID) if err != nil { @@ -405,11 +405,10 @@ func (m *MariaDB) RemoveRelations(source string, relations []knowledge.Relation) return fmt.Errorf("Unable to remove binding between relation %v (%d) and source %s: %v", relation, rH, source, err) } - _, err = tx.ExecContext(context.Background(), + _, err = tx.ExecContext(ctx, `DELETE FROM relations WHERE id = ? AND NOT EXISTS ( SELECT * FROM relations_by_source WHERE relation_id = ? - )`, - rH, rH) + )`, rH, rH) if err != nil { tx.Rollback() return fmt.Errorf("Unable to remove relation %v (%d) from source %s: %v", relation, rH, source, err) @@ -423,7 +422,7 @@ func (m *MariaDB) RemoveRelations(source string, relations []knowledge.Relation) } // ReadGraph read source subgraph -func (m *MariaDB) ReadGraph(sourceName string, graph *knowledge.Graph) error { +func (m *MariaDB) ReadGraph(ctx context.Context, sourceName string, graph *knowledge.Graph) error { logrus.Debugf("Start reading graph of data source with name %s", sourceName) sourceID, err := m.resolveSourceID(sourceName) if err != nil { @@ -439,7 +438,7 @@ func (m *MariaDB) ReadGraph(sourceName string, graph *knowledge.Graph) error { { // Select all relations produced by this source - rows, err := tx.QueryContext(context.Background(), ` + rows, err := tx.QueryContext(ctx, ` SELECT a.type, a.value, b.type, b.value, r.type FROM relations_by_source rbs INNER JOIN relations r ON rbs.relation_id = r.id INNER JOIN assets a ON a.id=r.from_id @@ -477,7 +476,7 @@ WHERE rbs.source_id = ? { // Select all assets produced by this source. This is useful in case there are some standalone nodes in the graph of the source. // TODO(c.michaud): optimization could be done by only selecting assets without any relation since the others have already have been retrieved in the previous query. - rows, err := tx.QueryContext(context.Background(), ` + rows, err := tx.QueryContext(ctx, ` SELECT a.type, a.value FROM assets_by_source abs INNER JOIN assets a ON a.id=abs.asset_id WHERE abs.source_id = ? @@ -510,13 +509,13 @@ WHERE abs.source_id = ? } // FlushAll flush the database -func (m *MariaDB) FlushAll() error { +func (m *MariaDB) FlushAll(ctx context.Context) error { tx, err := m.db.Begin() if err != nil { return err } - _, err = tx.ExecContext(context.Background(), "DROP TABLE relations_by_source") + _, err = tx.ExecContext(ctx, "DROP TABLE relations_by_source") if err != nil { if !isUnknownTableError(err) { tx.Rollback() @@ -524,7 +523,7 @@ func (m *MariaDB) FlushAll() error { } } - _, err = tx.ExecContext(context.Background(), "DROP TABLE assets_by_source") + _, err = tx.ExecContext(ctx, "DROP TABLE assets_by_source") if err != nil { if !isUnknownTableError(err) { tx.Rollback() @@ -532,7 +531,7 @@ func (m *MariaDB) FlushAll() error { } } - _, err = tx.ExecContext(context.Background(), "DROP TABLE relations") + _, err = tx.ExecContext(ctx, "DROP TABLE relations") if err != nil { if !isUnknownTableError(err) { tx.Rollback() @@ -540,7 +539,7 @@ func (m *MariaDB) FlushAll() error { } } - _, err = tx.ExecContext(context.Background(), "DROP TABLE assets") + _, err = tx.ExecContext(ctx, "DROP TABLE assets") if err != nil { if !isUnknownTableError(err) { tx.Rollback() @@ -548,7 +547,7 @@ func (m *MariaDB) FlushAll() error { } } - _, err = tx.ExecContext(context.Background(), "DROP TABLE graph_schema") + _, err = tx.ExecContext(ctx, "DROP TABLE graph_schema") if err != nil { if !isUnknownTableError(err) { tx.Rollback() @@ -556,7 +555,7 @@ func (m *MariaDB) FlushAll() error { } } - _, err = tx.ExecContext(context.Background(), "DROP TABLE query_history") + _, err = tx.ExecContext(ctx, "DROP TABLE query_history") if err != nil { if !isUnknownTableError(err) { tx.Rollback() @@ -568,9 +567,9 @@ func (m *MariaDB) FlushAll() error { } // CountAssets count the total number of assets in db. -func (m *MariaDB) CountAssets() (int64, error) { +func (m *MariaDB) CountAssets(ctx context.Context) (int64, error) { var count int64 - row := m.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM assets") + row := m.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM assets") err := row.Scan(&count) if err != nil { @@ -580,9 +579,9 @@ func (m *MariaDB) CountAssets() (int64, error) { } // CountRelations count the total number of relations in db. -func (m *MariaDB) CountRelations() (int64, error) { +func (m *MariaDB) CountRelations(ctx context.Context) (int64, error) { var count int64 - row := m.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM relations") + row := m.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM relations") err := row.Scan(&count) if err != nil { diff --git a/internal/handlers/handler_update_graph.go b/internal/handlers/handler_update_graph.go index 415b115..220f0bb 100644 --- a/internal/handlers/handler_update_graph.go +++ b/internal/handlers/handler_update_graph.go @@ -1,6 +1,7 @@ package handlers import ( + "context" "encoding/json" "fmt" "io" @@ -15,7 +16,7 @@ import ( "golang.org/x/sync/semaphore" ) -func handleUpdate(registry sources.Registry, fn func(source string, body io.Reader) error, sem *semaphore.Weighted, operationDescriptor string) http.HandlerFunc { +func handleUpdate(registry sources.Registry, fn func(ctx context.Context, source string, body io.Reader) error, sem *semaphore.Weighted, operationDescriptor string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ok, source, err := IsTokenValid(registry, r) if err != nil { @@ -51,7 +52,7 @@ func handleUpdate(registry sources.Registry, fn func(source string, body io.Read With(promLabels). Inc() - if err = fn(source, r.Body); err != nil { + if err = fn(r.Context(), source, r.Body); err != nil { metrics.GraphUpdateRequestsFailedCounter. With(promLabels). Inc() @@ -78,14 +79,14 @@ func handleUpdate(registry sources.Registry, fn func(source string, body io.Read // PutSchema upsert an asset into the graph of the data source func PutSchema(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc { - return handleUpdate(registry, func(source string, body io.Reader) error { + return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error { requestBody := client.PutGraphSchemaRequestBody{} if err := json.NewDecoder(body).Decode(&requestBody); err != nil { return err } // TODO(c.michaud): verify compatibility of the schema with graph updates - err := graphUpdater.UpdateSchema(source, requestBody.Schema) + err := graphUpdater.UpdateSchema(ctx, source, requestBody.Schema) if err != nil { return fmt.Errorf("Unable to update the schema: %v", err) } @@ -100,14 +101,14 @@ func PutSchema(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, // PutAssets upsert several assets into the graph of the data source func PutAssets(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc { - return handleUpdate(registry, func(source string, body io.Reader) error { + return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error { requestBody := client.PutGraphAssetRequestBody{} if err := json.NewDecoder(body).Decode(&requestBody); err != nil { return err } // TODO(c.michaud): verify compatibility of the schema with graph updates - err := graphUpdater.InsertAssets(source, requestBody.Assets) + err := graphUpdater.InsertAssets(ctx, source, requestBody.Assets) if err != nil { return fmt.Errorf("Unable to insert assets: %v", err) } @@ -122,14 +123,14 @@ func PutAssets(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, // PutRelations upsert multiple relations into the graph of the data source func PutRelations(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc { - return handleUpdate(registry, func(source string, body io.Reader) error { + return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error { requestBody := client.PutGraphRelationRequestBody{} if err := json.NewDecoder(body).Decode(&requestBody); err != nil { return err } // TODO(c.michaud): verify compatibility of the schema with graph updates - err := graphUpdater.InsertRelations(source, requestBody.Relations) + err := graphUpdater.InsertRelations(ctx, source, requestBody.Relations) if err != nil { return fmt.Errorf("Unable to insert relation: %v", err) } @@ -144,14 +145,14 @@ func PutRelations(registry sources.Registry, graphUpdater *knowledge.GraphUpdate // DeleteAssets delete multiple assets from the graph of the data source func DeleteAssets(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc { - return handleUpdate(registry, func(source string, body io.Reader) error { + return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error { requestBody := client.DeleteGraphAssetRequestBody{} if err := json.NewDecoder(body).Decode(&requestBody); err != nil { return err } // TODO(c.michaud): verify compatibility of the schema with graph updates - err := graphUpdater.RemoveAssets(source, requestBody.Assets) + err := graphUpdater.RemoveAssets(ctx, source, requestBody.Assets) if err != nil { return fmt.Errorf("Unable to remove assets: %v", err) } @@ -166,14 +167,14 @@ func DeleteAssets(registry sources.Registry, graphUpdater *knowledge.GraphUpdate // DeleteRelations remove multiple relations from the graph of the data source func DeleteRelations(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc { - return handleUpdate(registry, func(source string, body io.Reader) error { + return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error { requestBody := client.DeleteGraphRelationRequestBody{} if err := json.NewDecoder(body).Decode(&requestBody); err != nil { return err } // TODO(c.michaud): verify compatibility of the schema with graph updates - err := graphUpdater.RemoveRelations(source, requestBody.Relations) + err := graphUpdater.RemoveRelations(ctx, source, requestBody.Relations) if err != nil { return fmt.Errorf("Unable to remove relation: %v", err) } diff --git a/internal/knowledge/graph_updater.go b/internal/knowledge/graph_updater.go index e1eef1b..47c16c9 100644 --- a/internal/knowledge/graph_updater.go +++ b/internal/knowledge/graph_updater.go @@ -27,8 +27,8 @@ func NewGraphUpdater(graphDB GraphDB, schemaPersistor schema.Persistor) *GraphUp } // UpdateSchema update the schema for the source with the one provided in the request -func (sl *GraphUpdater) UpdateSchema(source string, sg schema.SchemaGraph) error { - previousSchema, err := sl.schemaPersistor.LoadSchema(context.Background(), source) +func (sl *GraphUpdater) UpdateSchema(ctx context.Context, source string, sg schema.SchemaGraph) error { + previousSchema, err := sl.schemaPersistor.LoadSchema(ctx, source) if err != nil { return fmt.Errorf("Unable to read schema from DB: %v", err) } @@ -37,7 +37,7 @@ func (sl *GraphUpdater) UpdateSchema(source string, sg schema.SchemaGraph) error if !schemaEqual { logrus.Debug("The schema needs an update") - if err := sl.schemaPersistor.SaveSchema(context.Background(), source, sg); err != nil { + if err := sl.schemaPersistor.SaveSchema(ctx, source, sg); err != nil { return fmt.Errorf("Unable to write schema in DB: %v", err) } } @@ -45,32 +45,32 @@ func (sl *GraphUpdater) UpdateSchema(source string, sg schema.SchemaGraph) error } // InsertAssets insert multiple assets in the graph of the data source -func (sl *GraphUpdater) InsertAssets(source string, assets []Asset) error { - if err := sl.graphDB.InsertAssets(source, assets); err != nil { +func (sl *GraphUpdater) InsertAssets(ctx context.Context, source string, assets []Asset) error { + if err := sl.graphDB.InsertAssets(ctx, source, assets); err != nil { return fmt.Errorf("Unable to insert assets from source %s: %v", source, err) } return nil } // InsertRelations insert multiple relations in the graph of the data source -func (sl *GraphUpdater) InsertRelations(source string, relations []Relation) error { - if err := sl.graphDB.InsertRelations(source, relations); err != nil { +func (sl *GraphUpdater) InsertRelations(ctx context.Context, source string, relations []Relation) error { + if err := sl.graphDB.InsertRelations(ctx, source, relations); err != nil { return fmt.Errorf("Unable to insert relations from source %s: %v", source, err) } return nil } // RemoveAssets remove multiple assets from the graph of the data source -func (sl *GraphUpdater) RemoveAssets(source string, assets []Asset) error { - if err := sl.graphDB.RemoveAssets(source, assets); err != nil { +func (sl *GraphUpdater) RemoveAssets(ctx context.Context, source string, assets []Asset) error { + if err := sl.graphDB.RemoveAssets(ctx, source, assets); err != nil { return fmt.Errorf("Unable to remove assets from source %s: %v", source, err) } return nil } // RemoveRelations remove multiple relations from the graph of the data source -func (sl *GraphUpdater) RemoveRelations(source string, relations []Relation) error { - if err := sl.graphDB.RemoveRelations(source, relations); err != nil { +func (sl *GraphUpdater) RemoveRelations(ctx context.Context, source string, relations []Relation) error { + if err := sl.graphDB.RemoveRelations(ctx, source, relations); err != nil { return fmt.Errorf("Unable to remove relations from source %s: %v", source, err) } return nil diff --git a/internal/knowledge/graphdb.go b/internal/knowledge/graphdb.go index 9679238..9f1430c 100644 --- a/internal/knowledge/graphdb.go +++ b/internal/knowledge/graphdb.go @@ -19,18 +19,18 @@ type GraphDB interface { InitializeSchema() error - ReadGraph(source string, graph *Graph) error + ReadGraph(ctx context.Context, source string, graph *Graph) error // Atomic operations on the graph - InsertAssets(source string, assets []Asset) error - InsertRelations(source string, relations []Relation) error - RemoveAssets(source string, assets []Asset) error - RemoveRelations(source string, relations []Relation) error + InsertAssets(ctx context.Context, source string, assets []Asset) error + InsertRelations(ctx context.Context, source string, relations []Relation) error + RemoveAssets(ctx context.Context, source string, assets []Asset) error + RemoveRelations(ctx context.Context, source string, relations []Relation) error - FlushAll() error + FlushAll(ctx context.Context) error - CountAssets() (int64, error) - CountRelations() (int64, error) + CountAssets(ctx context.Context) (int64, error) + CountRelations(ctx context.Context) (int64, error) Query(ctx context.Context, query SQLTranslation) (*GraphQueryResult, error) } diff --git a/internal/server/server.go b/internal/server/server.go index 717b4c7..38ac2fa 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -94,13 +94,13 @@ func getDatabaseDetails(database knowledge.GraphDB) http.HandlerFunc { RelationsCount int64 `json:"relations_count"` } - assetsCount, err := database.CountAssets() + assetsCount, err := database.CountAssets(r.Context()) if err != nil { handlers.ReplyWithInternalError(w, err) return } - relationsCount, err := database.CountRelations() + relationsCount, err := database.CountRelations(r.Context()) if err != nil { handlers.ReplyWithInternalError(w, err) return @@ -131,7 +131,7 @@ func getGraphRead(registry sources.Registry, graphDB knowledge.GraphDB) http.Han } g := knowledge.NewGraph() - if err := graphDB.ReadGraph(source, g); err != nil { + if err := graphDB.ReadGraph(r.Context(), source, g); err != nil { handlers.ReplyWithInternalError(w, err) return } @@ -150,7 +150,7 @@ func getGraphRead(registry sources.Registry, graphDB knowledge.GraphDB) http.Han func flushDatabase(graphDB knowledge.GraphDB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if err := graphDB.FlushAll(); err != nil { + if err := graphDB.FlushAll(r.Context()); err != nil { handlers.ReplyWithInternalError(w, err) return }