From 5bd3749fc596d36a73d81b75ae7cb4f043578b6c Mon Sep 17 00:00:00 2001 From: Clement Michaud Date: Thu, 9 Jan 2020 14:40:18 +0100 Subject: [PATCH] Create public interface by exporting internal types. And also pass options in Start instead of relying on viper in the implementation. --- .gitignore | 2 + cmd/importer-csv/main.go | 9 +- graphkb/graph_importer.go | 2 + graphkb/importer.go | 30 +++- graphkb/tasks.go | 2 + internal/database/mariadb.go | 214 +++++++++++++---------- internal/knowledge/graph_api.go | 7 +- internal/knowledge/graph_updater.go | 19 +- internal/knowledge/graph_updates.go | 133 ++++++++++++-- internal/knowledge/graph_updates_test.go | 74 ++++---- internal/knowledge/transaction.go | 2 +- internal/schema/graph.go | 2 +- internal/server/server.go | 2 +- internal/utils/task.go | 37 ++-- 14 files changed, 348 insertions(+), 187 deletions(-) diff --git a/.gitignore b/.gitignore index 5997da4..5fde4f3 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,5 @@ *.tar.gz go-graphkb importer-csv + +.config.yml diff --git a/cmd/importer-csv/main.go b/cmd/importer-csv/main.go index 71f0083..eab1658 100644 --- a/cmd/importer-csv/main.go +++ b/cmd/importer-csv/main.go @@ -98,7 +98,14 @@ func main() { rootCmd := &cobra.Command{ Use: "source-csv [opts]", Run: func(cmd *cobra.Command, args []string) { - if err := graphkb.Start(NewCSVSource(), nil); err != nil { + + options := graphkb.ImporterOptions{ + URL: viper.GetString("graphkb.url"), + AuthToken: viper.GetString("graphkb.auth_token"), + SkipVerify: viper.GetBool("graphkb.skip_verify"), + } + + if err := graphkb.Start(NewCSVSource(), options); err != nil { log.Fatal(err) } }, diff --git a/graphkb/graph_importer.go b/graphkb/graph_importer.go index 77fe757..948a8f8 100644 --- a/graphkb/graph_importer.go +++ b/graphkb/graph_importer.go @@ -3,3 +3,5 @@ package graphkb import "github.com/clems4ever/go-graphkb/internal/knowledge" type GraphImporter = knowledge.GraphImporter + +type Transaction = 
knowledge.Transaction diff --git a/graphkb/importer.go b/graphkb/importer.go index 200915f..33e54dd 100644 --- a/graphkb/importer.go +++ b/graphkb/importer.go @@ -2,28 +2,28 @@ package graphkb import ( "fmt" + "github.com/clems4ever/go-graphkb/internal/knowledge" + "github.com/clems4ever/go-graphkb/internal/schema" "github.com/clems4ever/go-graphkb/internal/sources" - "github.com/spf13/viper" ) type ImporterOptions struct { - CacheGraph bool + URL string + AuthToken string + SkipVerify bool } -func Start(source sources.Source, options *ImporterOptions) error { - url := viper.GetString("graphkb.url") - if url == "" { +func Start(source sources.Source, options ImporterOptions) error { + if options.URL == "" { return fmt.Errorf("Please provide graphkb URL in configuration file") } - - authToken := viper.GetString("graphkb.auth_token") - if authToken == "" { + if options.AuthToken == "" { return fmt.Errorf("Please provide a graphkb auth token to communicate with GraphKB") } observableSource := sources.NewObservableSource(source) - api := knowledge.NewGraphAPI(url, authToken) + api := knowledge.NewGraphAPI(options.URL, options.AuthToken, options.SkipVerify) graphImporter := knowledge.NewGraphImporter(api) if err := observableSource.Start(graphImporter); err != nil { @@ -32,3 +32,15 @@ func Start(source sources.Source, options *ImporterOptions) error { return nil } + +func CreateRelation(fromType schema.AssetType, relation, toType schema.AssetType) RelationType { + return schema.RelationType{ + FromType: fromType, + Type: RelationKeyType(relation), + ToType: toType, + } +} + +func CreateAsset(fromType string) AssetType { + return schema.AssetType(fromType) +} diff --git a/graphkb/tasks.go b/graphkb/tasks.go index f476fc3..fff806e 100644 --- a/graphkb/tasks.go +++ b/graphkb/tasks.go @@ -3,3 +3,5 @@ package graphkb import "github.com/clems4ever/go-graphkb/internal/utils" type RecurrentTask = utils.RecurrentTask + +var NewRecurrentTask = utils.NewRecurrentTask diff --git 
a/internal/database/mariadb.go b/internal/database/mariadb.go index 9a1cd9a..6bc786d 100644 --- a/internal/database/mariadb.go +++ b/internal/database/mariadb.go @@ -15,6 +15,7 @@ import ( "github.com/clems4ever/go-graphkb/internal/query" "github.com/clems4ever/go-graphkb/internal/schema" "github.com/clems4ever/go-graphkb/internal/utils" + mapset "github.com/deckarep/golang-set" mysql "github.com/go-sql-driver/mysql" "github.com/golang-collections/go-datastructures/queue" ) @@ -43,7 +44,7 @@ func NewMariaDB(username string, password string, host string, databaseName stri // InitializeSchema initialize the schema of the database func (m *MariaDB) InitializeSchema() error { // type must be part of the primary key to be a partition key - _, err := m.db.QueryContext(context.Background(), ` + q, err := m.db.QueryContext(context.Background(), ` CREATE TABLE IF NOT EXISTS assets ( id INT NOT NULL AUTO_INCREMENT, value VARCHAR(255) NOT NULL, @@ -56,9 +57,10 @@ CREATE TABLE IF NOT EXISTS assets ( if err != nil { return err } + defer q.Close() // type must be part of the primary key to be a partition key - _, err = m.db.QueryContext(context.Background(), ` + q, err = m.db.QueryContext(context.Background(), ` CREATE TABLE IF NOT EXISTS relations ( id INT NOT NULL AUTO_INCREMENT, from_id INT NOT NULL, @@ -69,21 +71,20 @@ CREATE TABLE IF NOT EXISTS relations ( CONSTRAINT pk_relation PRIMARY KEY (id), CONSTRAINT fk_from FOREIGN KEY (from_id) REFERENCES assets (id), CONSTRAINT fk_to FOREIGN KEY (to_id) REFERENCES assets (id), - - INDEX full_relation_type_from_to_idx (type, from_id, to_id), - INDEX full_relation_type_to_from_idx (type, to_id, from_id), - - INDEX full_relation_from_type_to_idx (from_id, type, to_id), - INDEX full_relation_from_to_type_idx (from_id, to_id, type), - INDEX full_relation_to_from_type_idx (to_id, from_id, type), - INDEX full_relation_to_type_from_idx (to_id, type, from_id))`) + INDEX full_relation_type_from_to_idx (type, from_id, to_id), + INDEX 
full_relation_type_to_from_idx (type, to_id, from_id), + INDEX full_relation_from_type_to_idx (from_id, type, to_id), + INDEX full_relation_from_to_type_idx (from_id, to_id, type), + INDEX full_relation_to_from_type_idx (to_id, from_id, type), + INDEX full_relation_to_type_from_idx (to_id, type, from_id))`) if err != nil { return err } + defer q.Close() // Create the table storing the schema graphs - _, err = m.db.QueryContext(context.Background(), ` + q, err = m.db.QueryContext(context.Background(), ` CREATE TABLE IF NOT EXISTS graph_schema ( id INTEGER AUTO_INCREMENT NOT NULL, source_name VARCHAR(64) NOT NULL, @@ -93,48 +94,22 @@ CONSTRAINT pk_schema PRIMARY KEY (id))`) if err != nil { return err } - + defer q.Close() return nil } // AssetIDResolver store ID assets in a cache -type AssetIDResolver struct { +type AssetRegistry struct { cache map[knowledge.AssetKey]int64 - db *sql.DB } -func (c *AssetIDResolver) Set(a knowledge.AssetKey, idx int64) { - c.cache[a] = idx +func (ar *AssetRegistry) Set(a knowledge.AssetKey, idx int64) { + ar.cache[a] = idx } -func (c *AssetIDResolver) Get(a knowledge.AssetKey) (int64, bool, error) { - idx, ok := c.cache[a] - - if ok { - return idx, true, nil - } - - if !ok { - q, err := c.db.PrepareContext(context.Background(), ` -SELECT id FROM assets WHERE type = ? 
AND value = ?`) - if err != nil { - return 0, false, fmt.Errorf("Unable to prepare asset select query: %v", err) - } - - res, err := q.QueryContext(context.Background(), a.Type, a.Key) - if err != nil { - return 0, false, err - } - - for res.Next() { - if err := res.Scan(&idx); err != nil { - return 0, false, err - } - c.cache[a] = idx - return idx, true, nil - } - } - return 0, false, nil +func (ar *AssetRegistry) Get(a knowledge.AssetKey) (int64, bool) { + idx, ok := ar.cache[a] + return idx, ok } func isDuplicateEntryError(err error) bool { @@ -147,15 +122,60 @@ func isUnknownTableError(err error) bool { return ok && driverErr.Number == 1051 } -func (m *MariaDB) upsertAssets(assets []knowledge.Asset, assetResolver *AssetIDResolver) (int64, error) { +func (m *MariaDB) resolveAssets(assets []knowledge.AssetKey, registry *AssetRegistry) error { + bar := pb.StartNew(len(assets)) + defer bar.Finish() + + tx, err := m.db.Begin() + if err != nil { + return err + } + + stmt, err := tx.PrepareContext(context.Background(), "SELECT id FROM assets WHERE type = ? 
AND value = ?") + if err != nil { + return err + } + + for _, a := range assets { + q, err := stmt.QueryContext(context.Background(), a.Type, a.Key) + if err != nil { + return err + } + defer q.Close() + + for q.Next() { + var idx int64 + if err := q.Scan(&idx); err != nil { + return err + } + registry.Set(knowledge.AssetKey(a), idx) + } + + bar.Increment() + } + return tx.Commit() +} + +func (m *MariaDB) upsertAssets(assets []knowledge.Asset, registry *AssetRegistry) (int64, error) { if len(assets) == 0 { return 0, nil } - bar := pb.StartNew(len(assets)) + + unresolved := []knowledge.Asset{} + for _, a := range assets { + _, ok := registry.Get(knowledge.AssetKey(a)) + if !ok { + unresolved = append(unresolved, a) + } + } + + bar := pb.StartNew(len(unresolved)) + defer bar.Finish() insertedCount := int64(0) - assetChunks := utils.ChunkSlice(assets, 10000).([][]interface{}) + assetChunks := utils.ChunkSlice(unresolved, 10000).([][]interface{}) + // A chunk of assets to store in a transaction for _, assetChunk := range assetChunks { tx, err := m.db.Begin() if err != nil { @@ -163,7 +183,7 @@ func (m *MariaDB) upsertAssets(assets []knowledge.Asset, assetResolver *AssetIDR } insertQuery, err := tx.PrepareContext(context.Background(), ` -REPLACE INTO assets (type, value) VALUES (?, ?)`) +INSERT INTO assets (type, value) VALUES (?, ?)`) if err != nil { log.Fatal(fmt.Errorf("Unable to prepare asset insertion query: %v", err)) } @@ -171,28 +191,18 @@ REPLACE INTO assets (type, value) VALUES (?, ?)`) for _, aC := range assetChunk { a := aC.(knowledge.Asset) - _, found, err := assetResolver.Get(knowledge.AssetKey(a)) + res, err := insertQuery.ExecContext(context.Background(), a.Type, a.Key) + if err != nil { + return 0, fmt.Errorf("Unable to insert asset %v: %v", a, err) + } + idx, err := res.LastInsertId() if err != nil { return 0, err } + registry.Set(knowledge.AssetKey(a), idx) - if !found { - res, err := insertQuery.ExecContext(context.Background(), a.Type, a.Key) - if 
err != nil { - if !isDuplicateEntryError(err) { - log.Fatal(fmt.Errorf("Unable to insert asset %v: %v", a, err)) - return 0, err - } - } - idx, err := res.LastInsertId() - if err != nil { - return 0, err - } - assetResolver.Set(knowledge.AssetKey(a), idx) - - atomic.AddInt64(&insertedCount, 1) - bar.Increment() - } + atomic.AddInt64(&insertedCount, 1) + bar.Increment() } err = tx.Commit() @@ -200,19 +210,18 @@ REPLACE INTO assets (type, value) VALUES (?, ?)`) log.Fatal(err) } } - - bar.Finish() return insertedCount, nil } -func (m *MariaDB) upsertRelations(source string, relations []knowledge.Relation, assetResolver *AssetIDResolver) (int64, error) { +func (m *MariaDB) upsertRelations(source string, relations []knowledge.Relation, registry *AssetRegistry) (int64, error) { if len(relations) == 0 { return 0, nil } bar := pb.StartNew(len(relations)) + defer bar.Finish() insertedCount := int64(0) - relationChunks := utils.ChunkSlice(relations, 1000).([][]interface{}) + relationChunks := utils.ChunkSlice(relations, 10000).([][]interface{}) for _, relationChunk := range relationChunks { tx, err := m.db.Begin() @@ -225,24 +234,17 @@ func (m *MariaDB) upsertRelations(source string, relations []knowledge.Relation, if err != nil { log.Fatal(fmt.Errorf("Unable to prepare relation insertion query: %v", err)) } + defer q.Close() for _, rC := range relationChunk { r := rC.(knowledge.Relation) - idxFrom, ok, err := assetResolver.Get(r.From) - if err != nil { - return 0, err - } - + idxFrom, ok := registry.Get(r.From) if !ok { fmt.Printf("[WARNING] ID of asset %v (from) has not been found in cache\n", r.From) continue } - idxTo, ok, err := assetResolver.Get(r.To) - if err != nil { - return 0, err - } - + idxTo, ok := registry.Get(r.To) if !ok { fmt.Printf("[WARNING] ID of asset %v (to) has not been found in cache\n", r.To) continue @@ -254,7 +256,7 @@ func (m *MariaDB) upsertRelations(source string, relations []knowledge.Relation, bar.Increment() continue } - 
log.Fatal(fmt.Errorf("Unable to insert relation %v: %v", r, err)) + log.Fatal(fmt.Errorf("Unable to insert relation %v (%d -> %d): %v", r, idxFrom, idxTo, err)) } bar.Increment() insertedCount++ @@ -288,6 +290,7 @@ WHERE a.type = ? AND a.value = ? AND b.type = ? AND b.value = ? AND r.type = ?`) if err != nil { return 0, 0, err } + defer stmt.Close() for _, r := range relations { rel := SourceRelation{ @@ -316,10 +319,6 @@ AND id NOT IN (select to_id from relations)`) return 0, 0, err } - if err != nil { - return 0, 0, err - } - removedAssetsCount, err := res.RowsAffected() if err != nil { return 0, 0, err @@ -334,24 +333,54 @@ AND id NOT IN (select to_id from relations)`) // UpdateGraph update graph with bulk of operations func (m *MariaDB) UpdateGraph(source string, bulk *knowledge.GraphUpdatesBulk) error { - cache := AssetIDResolver{cache: make(map[knowledge.AssetKey]int64), db: m.db} + registry := AssetRegistry{cache: make(map[knowledge.AssetKey]int64)} + now := time.Now() - count, err := m.upsertAssets(bulk.AssetUpserts, &cache) + assetKeysSet := mapset.NewSet() + for _, a := range bulk.GetAssetUpserts() { + assetKeysSet.Add(knowledge.AssetKey(a)) + } + for _, r := range bulk.GetRelationUpserts() { + assetKeysSet.Add(r.From) + assetKeysSet.Add(r.To) + } + + assetKeys := []knowledge.AssetKey{} + for a := range assetKeysSet.Iter() { + assetKeys = append(assetKeys, a.(knowledge.AssetKey)) + } + + fmt.Println("Start resolving assets") + err := m.resolveAssets(assetKeys, &registry) if err != nil { return err } - fmt.Printf("%d assets inserted\n", count) - count, err = m.upsertRelations(source, bulk.RelationUpserts, &cache) + + fmt.Println("Start upserting assets") + count, err := m.upsertAssets(bulk.GetAssetUpserts(), &registry) + if err != nil { + return err + } + + nowAssetInsert := time.Now() + fmt.Printf("%d assets inserted in %fs\n", count, nowAssetInsert.Sub(now).Seconds()) + + fmt.Println("Start upserting relations") + count, err = m.upsertRelations(source, 
bulk.GetRelationUpserts(), &registry) if err != nil { return err } - fmt.Printf("%d relations inserted\n", count) + nowRelationInsert := time.Now() + fmt.Printf("%d relations inserted in %fs\n", count, nowRelationInsert.Sub(nowAssetInsert).Seconds()) - relCount, assetsCount, err := m.removeRelations(source, bulk.RelationRemovals) + relCount, assetsCount, err := m.removeRelations(source, bulk.GetRelationRemovals()) if err != nil { return err } - fmt.Printf("%d assets removed\n%d relations removed\n", assetsCount, relCount) + fmt.Printf("%d assets removed and %d relations removed in %fs\n", + assetsCount, + relCount, + time.Since(nowRelationInsert).Seconds()) return nil } @@ -515,6 +544,7 @@ func (m *MariaDB) ListSources(ctx context.Context) ([]string, error) { if err != nil { return nil, fmt.Errorf("Unable to read sources from database: %v", err) } + defer rows.Close() sources := make([]string, 0) for rows.Next() { diff --git a/internal/knowledge/graph_api.go b/internal/knowledge/graph_api.go index cebe322..c305cf3 100644 --- a/internal/knowledge/graph_api.go +++ b/internal/knowledge/graph_api.go @@ -9,7 +9,6 @@ import ( "net/http" "github.com/clems4ever/go-graphkb/internal/schema" - "github.com/spf13/viper" ) // GraphEmitter an emitter of full source graph @@ -22,9 +21,9 @@ type GraphAPI struct { } // NewGraphEmitter create an emitter of graph -func NewGraphAPI(url string, authToken string) *GraphAPI { +func NewGraphAPI(url, authToken string, skipVerify bool) *GraphAPI { tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: viper.GetBool("graphkb.skip_verify")}, + TLSClientConfig: &tls.Config{InsecureSkipVerify: skipVerify}, } client := &http.Client{Transport: tr} @@ -69,7 +68,7 @@ func (gapi *GraphAPI) ReadCurrentGraph() (*Graph, error) { func (gapi *GraphAPI) UpdateGraph(sg schema.SchemaGraph, updates GraphUpdatesBulk) error { requestBody := GraphUpdateRequestBody{} - requestBody.Updates = updates + requestBody.Updates = &updates 
requestBody.Schema = sg b, err := json.Marshal(requestBody) diff --git a/internal/knowledge/graph_updater.go b/internal/knowledge/graph_updater.go index baa7cb1..6fe1766 100644 --- a/internal/knowledge/graph_updater.go +++ b/internal/knowledge/graph_updater.go @@ -31,7 +31,7 @@ func (sl *GraphUpdater) appendObservedRelations(source string, updates *GraphUpd observedRelationsToRemove := []Relation{} observedRelationsToAdd := []Relation{} - for _, a := range updates.AssetUpserts { + for _, a := range updates.GetAssetUpserts() { observedRelationsToAdd = append(observedRelationsToAdd, Relation{ Type: "observed", From: AssetKey(assetsToAdd[0]), @@ -39,7 +39,7 @@ func (sl *GraphUpdater) appendObservedRelations(source string, updates *GraphUpd }) } - for _, a := range updates.AssetRemovals { + for _, a := range updates.GetAssetRemovals() { observedRelationsToRemove = append(observedRelationsToRemove, Relation{ Type: "observed", From: AssetKey(assetsToAdd[0]), @@ -47,9 +47,9 @@ func (sl *GraphUpdater) appendObservedRelations(source string, updates *GraphUpd }) } - updates.AssetUpserts = append(updates.AssetUpserts, assetsToAdd...) - updates.RelationUpserts = append(updates.RelationUpserts, observedRelationsToAdd...) - updates.RelationRemovals = append(updates.RelationRemovals, observedRelationsToRemove...) + updates.UpsertAssets(assetsToAdd...) + updates.UpsertRelations(observedRelationsToAdd...) + updates.RemoveRelations(observedRelationsToRemove...) 
} func (sl *GraphUpdater) updateSchema(source string, sg *schema.SchemaGraph) error { @@ -85,8 +85,15 @@ func (sl *GraphUpdater) doUpdate(updates SourceSubGraphUpdates) error { sl.appendObservedRelations(updates.Source, &updates.Updates) + fmt.Printf("Start updating the graph with:\n"+ + "\t%d assets to insert\n"+ + "\t%d assets to remove\n"+ + "\t%d relations to add\n"+ + "\t%d relations to remove\n", + len(updates.Updates.GetAssetUpserts()), len(updates.Updates.GetAssetRemovals()), + len(updates.Updates.GetRelationUpserts()), len(updates.Updates.GetAssetRemovals())) if err := sl.graphDB.UpdateGraph(updates.Source, &updates.Updates); err != nil { - fmt.Printf("[ERROR] Unable to write schema in graph DB: %v\n", err) + fmt.Printf("[ERROR] Unable to write data in graph DB: %v\n", err) return err } return nil diff --git a/internal/knowledge/graph_updates.go b/internal/knowledge/graph_updates.go index 4b62444..cfc4d0c 100644 --- a/internal/knowledge/graph_updates.go +++ b/internal/knowledge/graph_updates.go @@ -1,7 +1,21 @@ package knowledge +import ( + "encoding/json" + + mapset "github.com/deckarep/golang-set" +) + // GraphUpdatesBulk represent a bulk of asset and relation updates to perform on the graph type GraphUpdatesBulk struct { + assetUpserts mapset.Set + assetRemovals mapset.Set + relationUpserts mapset.Set + relationRemovals mapset.Set +} + +// GraphUpdatesBulkJSON represent a bulk in JSON form +type GraphUpdatesBulkJSON struct { AssetUpserts []Asset `json:"asset_upserts"` AssetRemovals []Asset `json:"asset_removals"` RelationUpserts []Relation `json:"relation_upserts"` @@ -11,51 +25,137 @@ type GraphUpdatesBulk struct { // NewGraphUpdatesBulk create an instance of graph updates func NewGraphUpdatesBulk() *GraphUpdatesBulk { return &GraphUpdatesBulk{ - AssetUpserts: make([]Asset, 0), - AssetRemovals: make([]Asset, 0), - RelationUpserts: make([]Relation, 0), - RelationRemovals: make([]Relation, 0), + assetUpserts: mapset.NewSet(), + assetRemovals: 
mapset.NewSet(), + relationUpserts: mapset.NewSet(), + relationRemovals: mapset.NewSet(), + } +} + +func (gub *GraphUpdatesBulk) Clear() { + gub.assetUpserts.Clear() + gub.assetRemovals.Clear() + gub.relationUpserts.Clear() + gub.relationRemovals.Clear() +} + +func (gub *GraphUpdatesBulk) GetAssetUpserts() []Asset { + assets := []Asset{} + for v := range gub.assetUpserts.Iter() { + assets = append(assets, v.(Asset)) } + return assets +} + +func (gub *GraphUpdatesBulk) HasAssetUpsert(asset Asset) bool { + return gub.assetUpserts.Contains(asset) } // UpsertAsset create an operation to upsert an asset func (gub *GraphUpdatesBulk) UpsertAsset(asset Asset) { - gub.AssetUpserts = append(gub.AssetUpserts, asset) + gub.assetUpserts.Add(asset) } // UpsertAssets append multiple assets to upsert -func (gub *GraphUpdatesBulk) UpsertAssets(asset ...Asset) { - gub.AssetUpserts = append(gub.AssetUpserts, asset...) +func (gub *GraphUpdatesBulk) UpsertAssets(assets ...Asset) { + for _, a := range assets { + gub.assetUpserts.Add(a) + } +} + +func (gub *GraphUpdatesBulk) GetAssetRemovals() []Asset { + assets := []Asset{} + for v := range gub.assetRemovals.Iter() { + assets = append(assets, v.(Asset)) + } + return assets +} + +func (gub *GraphUpdatesBulk) HasAssetRemoval(asset Asset) bool { + return gub.assetRemovals.Contains(asset) } // RemoveAsset create an operation to remove an asset func (gub *GraphUpdatesBulk) RemoveAsset(asset Asset) { - gub.AssetRemovals = append(gub.AssetRemovals, asset) + gub.assetRemovals.Add(asset) } // RemoveAssets create multiple asset removal operations -func (gub *GraphUpdatesBulk) RemoveAssets(asset ...Asset) { - gub.AssetRemovals = append(gub.AssetRemovals, asset...) 
+func (gub *GraphUpdatesBulk) RemoveAssets(assets ...Asset) { + for _, a := range assets { + gub.assetRemovals.Add(a) + } +} + +func (gub *GraphUpdatesBulk) GetRelationUpserts() []Relation { + relations := []Relation{} + for v := range gub.relationUpserts.Iter() { + relations = append(relations, v.(Relation)) + } + return relations +} + +func (gub *GraphUpdatesBulk) HasRelationUpsert(relation Relation) bool { + return gub.relationUpserts.Contains(relation) } // UpsertRelation create an operation to upsert an relation func (gub *GraphUpdatesBulk) UpsertRelation(relation Relation) { - gub.RelationUpserts = append(gub.RelationUpserts, relation) + gub.relationUpserts.Add(relation) } // UpsertRelations create multiple relation upsert operations -func (gub *GraphUpdatesBulk) UpsertRelations(relation ...Relation) { - gub.RelationUpserts = append(gub.RelationUpserts, relation...) +func (gub *GraphUpdatesBulk) UpsertRelations(relations ...Relation) { + for _, r := range relations { + gub.relationUpserts.Add(r) + } +} + +func (gub *GraphUpdatesBulk) GetRelationRemovals() []Relation { + relations := []Relation{} + for v := range gub.relationRemovals.Iter() { + relations = append(relations, v.(Relation)) + } + return relations +} + +func (gub *GraphUpdatesBulk) HasRelationRemoval(relation Relation) bool { + return gub.relationRemovals.Contains(relation) } // RemoveRelation create an operation to remove a relation func (gub *GraphUpdatesBulk) RemoveRelation(relation Relation) { - gub.RelationRemovals = append(gub.RelationRemovals, relation) + gub.relationRemovals.Add(relation) } // RemoveRelations create multiple relation removal operations -func (gub *GraphUpdatesBulk) RemoveRelations(relation ...Relation) { - gub.RelationRemovals = append(gub.RelationRemovals, relation...) 
+func (gub *GraphUpdatesBulk) RemoveRelations(relations ...Relation) { + for _, r := range relations { + gub.relationRemovals.Add(r) + } +} + +func (gub *GraphUpdatesBulk) MarshalJSON() ([]byte, error) { + j := &GraphUpdatesBulkJSON{} + j.AssetUpserts = gub.GetAssetUpserts() + j.AssetRemovals = gub.GetAssetRemovals() + j.RelationUpserts = gub.GetRelationUpserts() + j.RelationRemovals = gub.GetRelationRemovals() + return json.Marshal(j) +} + +func (gub *GraphUpdatesBulk) UnmarshalJSON(bytes []byte) error { + j := GraphUpdatesBulkJSON{} + if err := json.Unmarshal(bytes, &j); err != nil { + return err + } + + *gub = *NewGraphUpdatesBulk() + gub.UpsertAssets(j.AssetUpserts...) + gub.UpsertRelations(j.RelationUpserts...) + gub.RemoveAssets(j.AssetRemovals...) + gub.RemoveRelations(j.RelationRemovals...) + return nil } // GenerateGraphUpdatesBulk generate a graph update bulk by taking the difference between new graph @@ -98,5 +198,6 @@ func GenerateGraphUpdatesBulk(previousGraph *Graph, newGraph *Graph) *GraphUpdat bulk.RemoveRelation(r) } } + return bulk } diff --git a/internal/knowledge/graph_updates_test.go b/internal/knowledge/graph_updates_test.go index 6b86422..8cbb7bb 100644 --- a/internal/knowledge/graph_updates_test.go +++ b/internal/knowledge/graph_updates_test.go @@ -19,13 +19,13 @@ func (s *SourceUpdatesSuite) TestShouldUpsertForCreatingGraph() { bulk := GenerateGraphUpdatesBulk(nil, g) - s.Require().Len(bulk.AssetUpserts, 2) - s.Require().Len(bulk.RelationUpserts, 1) - s.Require().Len(bulk.AssetRemovals, 0) - s.Require().Len(bulk.RelationRemovals, 0) + s.Require().Len(bulk.GetAssetUpserts(), 2) + s.Require().Len(bulk.GetRelationUpserts(), 1) + s.Require().Len(bulk.GetAssetRemovals(), 0) + s.Require().Len(bulk.GetRelationRemovals(), 0) - s.Assert().ElementsMatch(bulk.AssetUpserts, []Asset{Asset(ip2), Asset(ip1)}) - s.Assert().ElementsMatch(bulk.RelationUpserts, []Relation{rel}) + s.Assert().ElementsMatch(bulk.GetAssetUpserts(), []Asset{Asset(ip2), 
Asset(ip1)}) + s.Assert().ElementsMatch(bulk.GetRelationUpserts(), []Relation{rel}) } func (s *SourceUpdatesSuite) TestShouldUpsertAssets() { @@ -41,12 +41,12 @@ func (s *SourceUpdatesSuite) TestShouldUpsertAssets() { bulk := GenerateGraphUpdatesBulk(g1, g2) - s.Require().Len(bulk.AssetUpserts, 2) - s.Require().Len(bulk.RelationUpserts, 0) - s.Require().Len(bulk.AssetRemovals, 0) - s.Require().Len(bulk.RelationRemovals, 0) + s.Require().Len(bulk.GetAssetUpserts(), 2) + s.Require().Len(bulk.GetRelationUpserts(), 0) + s.Require().Len(bulk.GetAssetRemovals(), 0) + s.Require().Len(bulk.GetRelationRemovals(), 0) - s.Assert().ElementsMatch(bulk.AssetUpserts, []Asset{Asset(ip3), Asset(ip4)}) + s.Assert().ElementsMatch(bulk.GetAssetUpserts(), []Asset{Asset(ip3), Asset(ip4)}) } func (s *SourceUpdatesSuite) TestShouldUpsertRelations() { @@ -63,13 +63,13 @@ func (s *SourceUpdatesSuite) TestShouldUpsertRelations() { bulk := GenerateGraphUpdatesBulk(g1, g2) - s.Require().Len(bulk.AssetUpserts, 1) - s.Require().Len(bulk.RelationUpserts, 2) - s.Require().Len(bulk.AssetRemovals, 0) - s.Require().Len(bulk.RelationRemovals, 0) + s.Require().Len(bulk.GetAssetUpserts(), 1) + s.Require().Len(bulk.GetRelationUpserts(), 2) + s.Require().Len(bulk.GetAssetRemovals(), 0) + s.Require().Len(bulk.GetRelationRemovals(), 0) - s.Assert().ElementsMatch(bulk.AssetUpserts, []Asset{Asset(ip3)}) - s.Assert().ElementsMatch(bulk.RelationUpserts, []Relation{r1, r2}) + s.Assert().ElementsMatch(bulk.GetAssetUpserts(), []Asset{Asset(ip3)}) + s.Assert().ElementsMatch(bulk.GetRelationUpserts(), []Relation{r1, r2}) } func (s *SourceUpdatesSuite) TestShouldRemoveGraph() { @@ -80,13 +80,13 @@ func (s *SourceUpdatesSuite) TestShouldRemoveGraph() { bulk := GenerateGraphUpdatesBulk(g1, nil) - s.Require().Len(bulk.AssetUpserts, 0) - s.Require().Len(bulk.RelationUpserts, 0) - s.Require().Len(bulk.AssetRemovals, 2) - s.Require().Len(bulk.RelationRemovals, 1) + s.Require().Len(bulk.GetAssetUpserts(), 0) + 
s.Require().Len(bulk.GetRelationUpserts(), 0) + s.Require().Len(bulk.GetAssetRemovals(), 2) + s.Require().Len(bulk.GetRelationRemovals(), 1) - s.Assert().ElementsMatch(bulk.AssetRemovals, []Asset{Asset(ip1), Asset(ip2)}) - s.Assert().ElementsMatch(bulk.RelationRemovals, []Relation{r}) + s.Assert().ElementsMatch(bulk.GetAssetRemovals(), []Asset{Asset(ip1), Asset(ip2)}) + s.Assert().ElementsMatch(bulk.GetRelationRemovals(), []Relation{r}) } func (s *SourceUpdatesSuite) TestShouldGenerateBulkOfSubgraph() { @@ -100,13 +100,13 @@ func (s *SourceUpdatesSuite) TestShouldGenerateBulkOfSubgraph() { bulk := GenerateGraphUpdatesBulk(g1, g2) - s.Require().Len(bulk.AssetUpserts, 0) - s.Require().Len(bulk.RelationUpserts, 0) - s.Require().Len(bulk.AssetRemovals, 1) - s.Require().Len(bulk.RelationRemovals, 1) + s.Require().Len(bulk.GetAssetUpserts(), 0) + s.Require().Len(bulk.GetRelationUpserts(), 0) + s.Require().Len(bulk.GetAssetRemovals(), 1) + s.Require().Len(bulk.GetRelationRemovals(), 1) - s.Assert().ElementsMatch(bulk.AssetRemovals, []Asset{Asset(ip2)}) - s.Assert().ElementsMatch(bulk.RelationRemovals, []Relation{r}) + s.Assert().ElementsMatch(bulk.GetAssetRemovals(), []Asset{Asset(ip2)}) + s.Assert().ElementsMatch(bulk.GetRelationRemovals(), []Relation{r}) } func (s *SourceUpdatesSuite) TestShouldGenerateBulkForMixedAdditionsAndRemovals() { @@ -122,15 +122,15 @@ func (s *SourceUpdatesSuite) TestShouldGenerateBulkForMixedAdditionsAndRemovals( bulk := GenerateGraphUpdatesBulk(g1, g2) - s.Require().Len(bulk.AssetUpserts, 1) - s.Require().Len(bulk.RelationUpserts, 1) - s.Require().Len(bulk.AssetRemovals, 1) - s.Require().Len(bulk.RelationRemovals, 1) + s.Require().Len(bulk.GetAssetUpserts(), 1) + s.Require().Len(bulk.GetRelationUpserts(), 1) + s.Require().Len(bulk.GetAssetRemovals(), 1) + s.Require().Len(bulk.GetRelationRemovals(), 1) - s.Assert().ElementsMatch(bulk.AssetUpserts, []Asset{Asset(ip3)}) - s.Assert().ElementsMatch(bulk.RelationUpserts, []Relation{r2}) - 
s.Assert().ElementsMatch(bulk.AssetRemovals, []Asset{Asset(ip2)}) - s.Assert().ElementsMatch(bulk.RelationRemovals, []Relation{r}) + s.Assert().ElementsMatch(bulk.GetAssetUpserts(), []Asset{Asset(ip3)}) + s.Assert().ElementsMatch(bulk.GetRelationUpserts(), []Relation{r2}) + s.Assert().ElementsMatch(bulk.GetAssetRemovals(), []Asset{Asset(ip2)}) + s.Assert().ElementsMatch(bulk.GetRelationRemovals(), []Relation{r}) } func TestGraphUpdatesSuite(t *testing.T) { diff --git a/internal/knowledge/transaction.go b/internal/knowledge/transaction.go index 5997ea6..e563610 100644 --- a/internal/knowledge/transaction.go +++ b/internal/knowledge/transaction.go @@ -7,7 +7,7 @@ import ( ) type GraphUpdateRequestBody struct { - Updates GraphUpdatesBulk `json:"updates"` + Updates *GraphUpdatesBulk `json:"updates"` Schema schema.SchemaGraph `json:"schema"` } diff --git a/internal/schema/graph.go b/internal/schema/graph.go index ca55b80..696e413 100644 --- a/internal/schema/graph.go +++ b/internal/schema/graph.go @@ -83,7 +83,7 @@ func (sg *SchemaGraph) Equal(other SchemaGraph) bool { } func (sg *SchemaGraph) MarshalJSON() ([]byte, error) { - schemaJson := new(SchemaGraphJSON) + schemaJson := SchemaGraphJSON{} schemaJson.Vertices = []AssetType{} schemaJson.Edges = []RelationType{} diff --git a/internal/server/server.go b/internal/server/server.go index 38a6986..1266813 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -298,7 +298,7 @@ func postGraphUpdates(graphUpdatesC chan knowledge.SourceSubGraphUpdates) http.H // TODO(c.michaud): verify compatibility of the schema with graph updates graphUpdatesC <- knowledge.SourceSubGraphUpdates{ - Updates: requestBody.Updates, + Updates: *requestBody.Updates, Schema: requestBody.Schema, Source: source, } diff --git a/internal/utils/task.go b/internal/utils/task.go index d59268b..1a34972 100644 --- a/internal/utils/task.go +++ b/internal/utils/task.go @@ -6,9 +6,9 @@ import ( // RecurrentTask represent a recurrent task 
type RecurrentTask struct { - cancelChannel chan bool - interval time.Duration - callback func() + finishC chan struct{} + interval time.Duration + callback func() RunAtStartup bool } @@ -16,31 +16,30 @@ type RecurrentTask struct { // NewRecurrentTask create a recurrent task func NewRecurrentTask(interval time.Duration, callback func()) RecurrentTask { return RecurrentTask{ - cancelChannel: make(chan bool), - interval: interval, - callback: callback, + interval: interval, + callback: callback, + finishC: make(chan struct{}), } } // Start a recurrent task func (rt *RecurrentTask) Start() { - go func() { - if rt.RunAtStartup { - rt.callback() - } - for { - select { - case <-rt.cancelChannel: - break - case <-time.After(rt.interval): - rt.callback() - } + if rt.RunAtStartup { + rt.callback() + } + + for { + select { + case <-rt.finishC: + return + case <-time.After(rt.interval): + rt.callback() } - }() + } } // Stop the recurrent task func (rt *RecurrentTask) Stop() { - rt.cancelChannel <- true + rt.finishC <- struct{}{} }