Skip to content

Commit

Permalink
Spread the request context throughout up to db queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
clems4ever committed Feb 22, 2021
1 parent 2e1bdee commit 0476b00
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 68 deletions.
10 changes: 5 additions & 5 deletions cmd/go-graphkb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func logLevelParamToSeverity(level string) logrus.Level {
case "error":
return logrus.ErrorLevel
}
logrus.Fatal("Provided level %s is not a valid option")
logrus.Fatal("Provided level %s is not a valid option", level)
// This should never be reached but needed by the compiler
return logrus.InfoLevel
}
Expand Down Expand Up @@ -112,20 +112,20 @@ func onInit() {
}

func count(cmd *cobra.Command, args []string) {
countAssets, err := Database.CountAssets()
countAssets, err := Database.CountAssets(context.Background())
if err != nil {
logrus.Fatal(err)
}

countRelations, err := Database.CountRelations()
countRelations, err := Database.CountRelations(context.Background())
if err != nil {
logrus.Fatal(err)
}
fmt.Printf("%d assets\n%d relations\n", countAssets, countRelations)
}

func flush(cmd *cobra.Command, args []string) {
if err := Database.FlushAll(); err != nil {
if err := Database.FlushAll(context.Background()); err != nil {
logrus.Fatal(err)
}
logrus.Info("Successul flush")
Expand All @@ -147,7 +147,7 @@ func listen(cmd *cobra.Command, args []string) {

func read(cmd *cobra.Command, args []string) {
g := knowledge.NewGraph()
err := Database.ReadGraph(args[0], g)
err := Database.ReadGraph(context.Background(), args[0], g)
if err != nil {
logrus.Fatal(err)
}
Expand Down
55 changes: 27 additions & 28 deletions internal/database/mariadb.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func hashRelation(relation knowledge.Relation) uint64 {
}

// InsertAssets insert multiple assets into the graph of the given source
func (m *MariaDB) InsertAssets(source string, assets []knowledge.Asset) error {
func (m *MariaDB) InsertAssets(ctx context.Context, source string, assets []knowledge.Asset) error {
sourceID, err := m.resolveSourceID(source)
if err != nil {
return fmt.Errorf("Unable to resolve source ID of source %s for inserting assets: %v", source, err)
Expand All @@ -263,7 +263,7 @@ func (m *MariaDB) InsertAssets(source string, assets []knowledge.Asset) error {
for _, asset := range assets {
h := hashAsset(asset)

_, err = tx.ExecContext(context.Background(),
_, err = tx.ExecContext(ctx,
`INSERT INTO assets (id, type, value) VALUES (?, ?, ?)`,
h, asset.Type, asset.Key)
if err != nil {
Expand All @@ -275,7 +275,7 @@ func (m *MariaDB) InsertAssets(source string, assets []knowledge.Asset) error {
}
}

_, err = tx.ExecContext(context.Background(),
_, err = tx.ExecContext(ctx,
`INSERT INTO assets_by_source (source_id, asset_id) VALUES (?, ?)`, sourceID, h)
if err != nil {
if driverErr, ok := err.(*mysql.MySQLError); ok && driverErr.Number == mysqlerr.ER_DUP_ENTRY {
Expand All @@ -294,7 +294,7 @@ func (m *MariaDB) InsertAssets(source string, assets []knowledge.Asset) error {
}

// InsertRelations upsert one relation into the graph of the given source
func (m *MariaDB) InsertRelations(source string, relations []knowledge.Relation) error {
func (m *MariaDB) InsertRelations(ctx context.Context, source string, relations []knowledge.Relation) error {
sourceID, err := m.resolveSourceID(source)
if err != nil {
return fmt.Errorf("Unable to resolve source ID of source %s for inserting relations: %v", source, err)
Expand All @@ -311,7 +311,7 @@ func (m *MariaDB) InsertRelations(source string, relations []knowledge.Relation)
aTo := hashAsset(knowledge.Asset(relation.To))
rH := hashRelation(relation)

_, err = tx.ExecContext(context.Background(),
_, err = tx.ExecContext(ctx,
"INSERT INTO relations (id, from_id, to_id, type) VALUES (?, ?, ?, ?)",
rH, aFrom, aTo, relation.Type)
if err != nil {
Expand All @@ -323,7 +323,7 @@ func (m *MariaDB) InsertRelations(source string, relations []knowledge.Relation)
}
}

_, err = tx.ExecContext(context.Background(),
_, err = tx.ExecContext(ctx,
`INSERT INTO relations_by_source (source_id, relation_id) VALUES (?, ?)`, sourceID, rH)
if err != nil {
if driverErr, ok := err.(*mysql.MySQLError); ok && driverErr.Number == mysqlerr.ER_DUP_ENTRY {
Expand All @@ -342,7 +342,7 @@ func (m *MariaDB) InsertRelations(source string, relations []knowledge.Relation)
}

// RemoveAssets remove one asset from the graph of the given source
func (m *MariaDB) RemoveAssets(source string, assets []knowledge.Asset) error {
func (m *MariaDB) RemoveAssets(ctx context.Context, source string, assets []knowledge.Asset) error {
sourceID, err := m.resolveSourceID(source)
if err != nil {
return fmt.Errorf("Unable to resolve source ID of source %s for removing assets: %v", source, err)
Expand All @@ -356,15 +356,15 @@ func (m *MariaDB) RemoveAssets(source string, assets []knowledge.Asset) error {
for _, asset := range assets {
h := hashAsset(asset)

_, err = tx.ExecContext(context.Background(),
_, err = tx.ExecContext(ctx,
`DELETE FROM assets_by_source WHERE asset_id = ? AND source_id = ?`,
h, sourceID)
if err != nil {
tx.Rollback()
return fmt.Errorf("Unable to remove binding between asset %v (%d) and source %s: %v", asset, h, source, err)
}

_, err = tx.ExecContext(context.Background(),
_, err = tx.ExecContext(ctx,
`DELETE FROM assets WHERE id = ? AND NOT EXISTS (
SELECT * FROM assets_by_source WHERE asset_id = ?
)`,
Expand All @@ -383,7 +383,7 @@ func (m *MariaDB) RemoveAssets(source string, assets []knowledge.Asset) error {
}

// RemoveRelations remove relations from the graph of the given source
func (m *MariaDB) RemoveRelations(source string, relations []knowledge.Relation) error {
func (m *MariaDB) RemoveRelations(ctx context.Context, source string, relations []knowledge.Relation) error {
sourceID, err := m.resolveSourceID(source)
if err != nil {
return fmt.Errorf("Unable to resolve source ID of source %s for removing relations: %v", source, err)
Expand All @@ -397,19 +397,18 @@ func (m *MariaDB) RemoveRelations(source string, relations []knowledge.Relation)
for _, relation := range relations {
rH := hashRelation(relation)

_, err = tx.ExecContext(context.Background(),
_, err = tx.ExecContext(ctx,
`DELETE FROM relations_by_source WHERE relation_id = ? AND source_id = ?`,
rH, sourceID)
if err != nil {
tx.Rollback()
return fmt.Errorf("Unable to remove binding between relation %v (%d) and source %s: %v", relation, rH, source, err)
}

_, err = tx.ExecContext(context.Background(),
_, err = tx.ExecContext(ctx,
`DELETE FROM relations WHERE id = ? AND NOT EXISTS (
SELECT * FROM relations_by_source WHERE relation_id = ?
)`,
rH, rH)
)`, rH, rH)
if err != nil {
tx.Rollback()
return fmt.Errorf("Unable to remove relation %v (%d) from source %s: %v", relation, rH, source, err)
Expand All @@ -423,7 +422,7 @@ func (m *MariaDB) RemoveRelations(source string, relations []knowledge.Relation)
}

// ReadGraph read source subgraph
func (m *MariaDB) ReadGraph(sourceName string, graph *knowledge.Graph) error {
func (m *MariaDB) ReadGraph(ctx context.Context, sourceName string, graph *knowledge.Graph) error {
logrus.Debugf("Start reading graph of data source with name %s", sourceName)
sourceID, err := m.resolveSourceID(sourceName)
if err != nil {
Expand All @@ -439,7 +438,7 @@ func (m *MariaDB) ReadGraph(sourceName string, graph *knowledge.Graph) error {

{
// Select all relations produced by this source
rows, err := tx.QueryContext(context.Background(), `
rows, err := tx.QueryContext(ctx, `
SELECT a.type, a.value, b.type, b.value, r.type FROM relations_by_source rbs
INNER JOIN relations r ON rbs.relation_id = r.id
INNER JOIN assets a ON a.id=r.from_id
Expand Down Expand Up @@ -477,7 +476,7 @@ WHERE rbs.source_id = ?
{
// Select all assets produced by this source. This is useful in case there are some standalone nodes in the graph of the source.
// TODO(c.michaud): optimization could be done by only selecting assets without any relation since the others have already have been retrieved in the previous query.
rows, err := tx.QueryContext(context.Background(), `
rows, err := tx.QueryContext(ctx, `
SELECT a.type, a.value FROM assets_by_source abs
INNER JOIN assets a ON a.id=abs.asset_id
WHERE abs.source_id = ?
Expand Down Expand Up @@ -510,53 +509,53 @@ WHERE abs.source_id = ?
}

// FlushAll flush the database
func (m *MariaDB) FlushAll() error {
func (m *MariaDB) FlushAll(ctx context.Context) error {
tx, err := m.db.Begin()
if err != nil {
return err
}

_, err = tx.ExecContext(context.Background(), "DROP TABLE relations_by_source")
_, err = tx.ExecContext(ctx, "DROP TABLE relations_by_source")
if err != nil {
if !isUnknownTableError(err) {
tx.Rollback()
return err
}
}

_, err = tx.ExecContext(context.Background(), "DROP TABLE assets_by_source")
_, err = tx.ExecContext(ctx, "DROP TABLE assets_by_source")
if err != nil {
if !isUnknownTableError(err) {
tx.Rollback()
return err
}
}

_, err = tx.ExecContext(context.Background(), "DROP TABLE relations")
_, err = tx.ExecContext(ctx, "DROP TABLE relations")
if err != nil {
if !isUnknownTableError(err) {
tx.Rollback()
return err
}
}

_, err = tx.ExecContext(context.Background(), "DROP TABLE assets")
_, err = tx.ExecContext(ctx, "DROP TABLE assets")
if err != nil {
if !isUnknownTableError(err) {
tx.Rollback()
return err
}
}

_, err = tx.ExecContext(context.Background(), "DROP TABLE graph_schema")
_, err = tx.ExecContext(ctx, "DROP TABLE graph_schema")
if err != nil {
if !isUnknownTableError(err) {
tx.Rollback()
return err
}
}

_, err = tx.ExecContext(context.Background(), "DROP TABLE query_history")
_, err = tx.ExecContext(ctx, "DROP TABLE query_history")
if err != nil {
if !isUnknownTableError(err) {
tx.Rollback()
Expand All @@ -568,9 +567,9 @@ func (m *MariaDB) FlushAll() error {
}

// CountAssets count the total number of assets in db.
func (m *MariaDB) CountAssets() (int64, error) {
func (m *MariaDB) CountAssets(ctx context.Context) (int64, error) {
var count int64
row := m.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM assets")
row := m.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM assets")

err := row.Scan(&count)
if err != nil {
Expand All @@ -580,9 +579,9 @@ func (m *MariaDB) CountAssets() (int64, error) {
}

// CountRelations count the total number of relations in db.
func (m *MariaDB) CountRelations() (int64, error) {
func (m *MariaDB) CountRelations(ctx context.Context) (int64, error) {
var count int64
row := m.db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM relations")
row := m.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM relations")

err := row.Scan(&count)
if err != nil {
Expand Down
25 changes: 13 additions & 12 deletions internal/handlers/handler_update_graph.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handlers

import (
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -15,7 +16,7 @@ import (
"golang.org/x/sync/semaphore"
)

func handleUpdate(registry sources.Registry, fn func(source string, body io.Reader) error, sem *semaphore.Weighted, operationDescriptor string) http.HandlerFunc {
func handleUpdate(registry sources.Registry, fn func(ctx context.Context, source string, body io.Reader) error, sem *semaphore.Weighted, operationDescriptor string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ok, source, err := IsTokenValid(registry, r)
if err != nil {
Expand Down Expand Up @@ -51,7 +52,7 @@ func handleUpdate(registry sources.Registry, fn func(source string, body io.Read
With(promLabels).
Inc()

if err = fn(source, r.Body); err != nil {
if err = fn(r.Context(), source, r.Body); err != nil {
metrics.GraphUpdateRequestsFailedCounter.
With(promLabels).
Inc()
Expand All @@ -78,14 +79,14 @@ func handleUpdate(registry sources.Registry, fn func(source string, body io.Read

// PutSchema upsert an asset into the graph of the data source
func PutSchema(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc {
return handleUpdate(registry, func(source string, body io.Reader) error {
return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error {
requestBody := client.PutGraphSchemaRequestBody{}
if err := json.NewDecoder(body).Decode(&requestBody); err != nil {
return err
}

// TODO(c.michaud): verify compatibility of the schema with graph updates
err := graphUpdater.UpdateSchema(source, requestBody.Schema)
err := graphUpdater.UpdateSchema(ctx, source, requestBody.Schema)
if err != nil {
return fmt.Errorf("Unable to update the schema: %v", err)
}
Expand All @@ -100,14 +101,14 @@ func PutSchema(registry sources.Registry, graphUpdater *knowledge.GraphUpdater,

// PutAssets upsert several assets into the graph of the data source
func PutAssets(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc {
return handleUpdate(registry, func(source string, body io.Reader) error {
return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error {
requestBody := client.PutGraphAssetRequestBody{}
if err := json.NewDecoder(body).Decode(&requestBody); err != nil {
return err
}

// TODO(c.michaud): verify compatibility of the schema with graph updates
err := graphUpdater.InsertAssets(source, requestBody.Assets)
err := graphUpdater.InsertAssets(ctx, source, requestBody.Assets)
if err != nil {
return fmt.Errorf("Unable to insert assets: %v", err)
}
Expand All @@ -122,14 +123,14 @@ func PutAssets(registry sources.Registry, graphUpdater *knowledge.GraphUpdater,

// PutRelations upsert multiple relations into the graph of the data source
func PutRelations(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc {
return handleUpdate(registry, func(source string, body io.Reader) error {
return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error {
requestBody := client.PutGraphRelationRequestBody{}
if err := json.NewDecoder(body).Decode(&requestBody); err != nil {
return err
}

// TODO(c.michaud): verify compatibility of the schema with graph updates
err := graphUpdater.InsertRelations(source, requestBody.Relations)
err := graphUpdater.InsertRelations(ctx, source, requestBody.Relations)
if err != nil {
return fmt.Errorf("Unable to insert relation: %v", err)
}
Expand All @@ -144,14 +145,14 @@ func PutRelations(registry sources.Registry, graphUpdater *knowledge.GraphUpdate

// DeleteAssets delete multiple assets from the graph of the data source
func DeleteAssets(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc {
return handleUpdate(registry, func(source string, body io.Reader) error {
return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error {
requestBody := client.DeleteGraphAssetRequestBody{}
if err := json.NewDecoder(body).Decode(&requestBody); err != nil {
return err
}

// TODO(c.michaud): verify compatibility of the schema with graph updates
err := graphUpdater.RemoveAssets(source, requestBody.Assets)
err := graphUpdater.RemoveAssets(ctx, source, requestBody.Assets)
if err != nil {
return fmt.Errorf("Unable to remove assets: %v", err)
}
Expand All @@ -166,14 +167,14 @@ func DeleteAssets(registry sources.Registry, graphUpdater *knowledge.GraphUpdate

// DeleteRelations remove multiple relations from the graph of the data source
func DeleteRelations(registry sources.Registry, graphUpdater *knowledge.GraphUpdater, sem *semaphore.Weighted) http.HandlerFunc {
return handleUpdate(registry, func(source string, body io.Reader) error {
return handleUpdate(registry, func(ctx context.Context, source string, body io.Reader) error {
requestBody := client.DeleteGraphRelationRequestBody{}
if err := json.NewDecoder(body).Decode(&requestBody); err != nil {
return err
}

// TODO(c.michaud): verify compatibility of the schema with graph updates
err := graphUpdater.RemoveRelations(source, requestBody.Relations)
err := graphUpdater.RemoveRelations(ctx, source, requestBody.Relations)
if err != nil {
return fmt.Errorf("Unable to remove relation: %v", err)
}
Expand Down
Loading

0 comments on commit 0476b00

Please sign in to comment.