diff --git a/asset.go b/asset.go index de6184e..d703da0 100644 --- a/asset.go +++ b/asset.go @@ -4,10 +4,6 @@ package open_asset_model -import ( - "strings" -) - type Asset interface { Key() string AssetType() AssetType @@ -41,216 +37,3 @@ var AssetList = []AssetType{ AutnumRecord, AutonomousSystem, ContactRecord, DomainRecord, EmailAddress, File, Fingerprint, FQDN, IPAddress, IPNetRecord, Location, Netblock, Organization, Person, Phone, Service, TLSCertificate, URL, } - -var locationRels = map[string]map[RelationType][]AssetType{} - -var phoneRels = map[string]map[RelationType][]AssetType{} - -var emailRels = map[string]map[RelationType][]AssetType{ - "domain": {SimpleRelation: {FQDN}}, -} - -var domainRecordRels = map[string]map[RelationType][]AssetType{ - "name_server": {SimpleRelation: {FQDN}}, - "whois_server": {SimpleRelation: {FQDN}}, - "registrar_contact": {SimpleRelation: {ContactRecord}}, - "registrant_contact": {SimpleRelation: {ContactRecord}}, - "admin_contact": {SimpleRelation: {ContactRecord}}, - "technical_contact": {SimpleRelation: {ContactRecord}}, - "billing_contact": {SimpleRelation: {ContactRecord}}, - "associated_with": {SimpleRelation: {AutnumRecord, DomainRecord, IPNetRecord}}, -} - -var autnumRecordRels = map[string]map[RelationType][]AssetType{ - "whois_server": {SimpleRelation: {FQDN}}, - "registrant": {SimpleRelation: {ContactRecord}}, - "admin_contact": {SimpleRelation: {ContactRecord}}, - "abuse_contact": {SimpleRelation: {ContactRecord}}, - "technical_contact": {SimpleRelation: {ContactRecord}}, - "rdap_url": {SimpleRelation: {URL}}, - "associated_with": {SimpleRelation: {AutnumRecord, DomainRecord, IPNetRecord}}, -} - -var ipnetRecordRels = map[string]map[RelationType][]AssetType{ - "whois_server": {SimpleRelation: {FQDN}}, - "registrant": {SimpleRelation: {ContactRecord}}, - "admin_contact": {SimpleRelation: {ContactRecord}}, - "abuse_contact": {SimpleRelation: {ContactRecord}}, - "technical_contact": {SimpleRelation: {ContactRecord}}, - "rdap_url": {SimpleRelation: {URL}}, - "associated_with": {SimpleRelation: {AutnumRecord, DomainRecord, IPNetRecord}}, -} - -var personRels = map[string]map[RelationType][]AssetType{} - -var orgRels = map[string]map[RelationType][]AssetType{} - -var ipRels = map[string]map[RelationType][]AssetType{ - "port": {PortRelation: {Service}}, - "ptr_record": {SimpleRelation: {FQDN}}, -} - -var netblockRels = map[string]map[RelationType][]AssetType{ - "contains": {SimpleRelation: {IPAddress}}, - "registration": {SimpleRelation: {IPNetRecord}}, -} - -var autonomousSystemRels = map[string]map[RelationType][]AssetType{ - "announces": {SimpleRelation: {Netblock}}, - "registration": {SimpleRelation: {AutnumRecord}}, -} - -var fileRels = map[string]map[RelationType][]AssetType{ - "url": {SimpleRelation: {URL}}, - "contains": {SimpleRelation: {ContactRecord, URL}}, -} - -var fqdnRels = map[string]map[RelationType][]AssetType{ - "port": {PortRelation: {Service}}, - "dns_record": { - BasicDNSRelation: {FQDN, IPAddress}, - PrefDNSRelation: {FQDN}, - SRVDNSRelation: {FQDN}, - }, - "node": {SimpleRelation: {FQDN}}, - "registration": {SimpleRelation: {DomainRecord}}, -} - -var tlscertRels = map[string]map[RelationType][]AssetType{ - "common_name": {SimpleRelation: {FQDN}}, - "subject_contact": {SimpleRelation: {ContactRecord}}, - "issuer_contact": {SimpleRelation: {ContactRecord}}, - "san_dns_name": {SimpleRelation: {FQDN}}, - "san_email_address": {SimpleRelation: {EmailAddress}}, - "san_ip_address": {SimpleRelation: {IPAddress}}, - "san_url": {SimpleRelation: {URL}}, - "issuing_certificate": {SimpleRelation: {TLSCertificate}}, - "issuing_certificate_url": {SimpleRelation: {URL}}, - "ocsp_server": {SimpleRelation: {URL}}, - "associated_with": {SimpleRelation: {AutnumRecord, DomainRecord, IPNetRecord}}, -} - -var urlRels = map[string]map[RelationType][]AssetType{ - "domain": {SimpleRelation: {FQDN}}, - "ip_address": {SimpleRelation: {IPAddress}}, - "port": {PortRelation: {Service}}, - "file": {SimpleRelation: {File}}, -} - -var fingerprintRels = map[string]map[RelationType][]AssetType{} - -var contactRecordRels = map[string]map[RelationType][]AssetType{ - "person": {SimpleRelation: {Person}}, - "organization": {SimpleRelation: {Organization}}, - "location": {SimpleRelation: {Location}}, - "email": {SimpleRelation: {EmailAddress}}, - "phone": {SimpleRelation: {Phone}}, - "url": {SimpleRelation: {URL}}, -} - -var serviceRels = map[string]map[RelationType][]AssetType{ - "fingerprint": {SimpleRelation: {Fingerprint}}, - "certificate": {SimpleRelation: {TLSCertificate}}, -} - -// GetAssetOutgoingRelations returns the relation types allowed to be used -// when the subject is the asset type provided in the parameter. -// Providing an invalid subject causes a return value of nil. -func GetAssetOutgoingRelations(subject AssetType) []string { - relations := assetTypeRelations(subject) - if relations == nil { - return nil - } - - var rtypes []string - for k := range relations { - rtypes = append(rtypes, k) - } - return rtypes -} - -// GetAssetOutgoingRelations returns the relation types allowed to be used -// when the subject is the asset type provided in the parameter. -// Providing an invalid subject causes a return value of nil. -func GetTransformAssetTypes(subject AssetType, relation string) []AssetType { - relations := assetTypeRelations(subject) - if relations == nil { - return nil - } - - var results []AssetType - rtype := strings.ToLower(relation) - m := make(map[AssetType]struct{}) - for _, atypes := range relations[rtype] { - for _, t := range atypes { - if _, found := m[t]; !found { - m[t] = struct{}{} - results = append(results, t) - } - } - } - return results -} - -func assetTypeRelations(atype AssetType) map[string]map[RelationType][]AssetType { - var relations map[string]map[RelationType][]AssetType - - switch atype { - case IPAddress: - relations = ipRels - case Netblock: - relations = netblockRels - case AutonomousSystem: - relations = autonomousSystemRels - case File: - relations = fileRels - case FQDN: - relations = fqdnRels - case DomainRecord: - relations = domainRecordRels - case AutnumRecord: - relations = autnumRecordRels - case IPNetRecord: - relations = ipnetRecordRels - case Location: - relations = locationRels - case Phone: - relations = phoneRels - case EmailAddress: - relations = emailRels - case Person: - relations = personRels - case Organization: - relations = orgRels - case TLSCertificate: - relations = tlscertRels - case URL: - relations = urlRels - case Fingerprint: - relations = fingerprintRels - case ContactRecord: - relations = contactRecordRels - case Service: - relations = serviceRels - default: - return nil - } - - return relations -} - -// ValidRelationship returns true if the relation is valid in the taxonomy -// when outgoing from the source asset type to the destination asset type. -func ValidRelationship(src AssetType, relation string, destination AssetType) bool { - atypes := GetTransformAssetTypes(src, relation) - if atypes == nil { - return false - } - - for _, atype := range atypes { - if atype == destination { - return true - } - } - return false -} diff --git a/relation.go b/relation.go index 5b9a0e1..d8bcff3 100644 --- a/relation.go +++ b/relation.go @@ -4,7 +4,12 @@ package open_asset_model +import ( + "strings" +) + type Relation interface { + Label() string RelationType() RelationType JSON() ([]byte, error) } @@ -22,3 +27,220 @@ const ( var RelationList = []RelationType{ BasicDNSRelation, PortRelation, PrefDNSRelation, SimpleRelation, SRVDNSRelation, } + +var locationRels = map[string]map[RelationType][]AssetType{} + +var phoneRels = map[string]map[RelationType][]AssetType{} + +var emailRels = map[string]map[RelationType][]AssetType{ + "domain": {SimpleRelation: {FQDN}}, +} + +var domainRecordRels = map[string]map[RelationType][]AssetType{ + "name_server": {SimpleRelation: {FQDN}}, + "whois_server": {SimpleRelation: {FQDN}}, + "registrar_contact": {SimpleRelation: {ContactRecord}}, + "registrant_contact": {SimpleRelation: {ContactRecord}}, + "admin_contact": {SimpleRelation: {ContactRecord}}, + "technical_contact": {SimpleRelation: {ContactRecord}}, + "billing_contact": {SimpleRelation: {ContactRecord}}, + "associated_with": {SimpleRelation: {AutnumRecord, DomainRecord, IPNetRecord}}, +} + +var autnumRecordRels = map[string]map[RelationType][]AssetType{ + "whois_server": {SimpleRelation: {FQDN}}, + "registrant": {SimpleRelation: {ContactRecord}}, + "admin_contact": {SimpleRelation: {ContactRecord}}, + "abuse_contact": {SimpleRelation: {ContactRecord}}, + "technical_contact": {SimpleRelation: {ContactRecord}}, + "rdap_url": {SimpleRelation: {URL}}, + "associated_with": {SimpleRelation: {AutnumRecord, DomainRecord, IPNetRecord}}, +} + +var ipnetRecordRels = map[string]map[RelationType][]AssetType{ + "whois_server": {SimpleRelation: {FQDN}}, + "registrant": {SimpleRelation: {ContactRecord}}, + "admin_contact": {SimpleRelation: {ContactRecord}}, + "abuse_contact": {SimpleRelation: {ContactRecord}}, + "technical_contact": {SimpleRelation: {ContactRecord}}, + "rdap_url": {SimpleRelation: {URL}}, + "associated_with": {SimpleRelation: {AutnumRecord, DomainRecord, IPNetRecord}}, +} + +var personRels = map[string]map[RelationType][]AssetType{} + +var orgRels = map[string]map[RelationType][]AssetType{} + +var ipRels = map[string]map[RelationType][]AssetType{ + "port": {PortRelation: {Service}}, + "ptr_record": {SimpleRelation: {FQDN}}, +} + +var netblockRels = map[string]map[RelationType][]AssetType{ + "contains": {SimpleRelation: {IPAddress}}, + "registration": {SimpleRelation: {IPNetRecord}}, +} + +var autonomousSystemRels = map[string]map[RelationType][]AssetType{ + "announces": {SimpleRelation: {Netblock}}, + "registration": {SimpleRelation: {AutnumRecord}}, +} + +var fileRels = map[string]map[RelationType][]AssetType{ + "url": {SimpleRelation: {URL}}, + "contains": {SimpleRelation: {ContactRecord, URL}}, +} + +var fqdnRels = map[string]map[RelationType][]AssetType{ + "port": {PortRelation: {Service}}, + "dns_record": { + BasicDNSRelation: {FQDN, IPAddress}, + PrefDNSRelation: {FQDN}, + SRVDNSRelation: {FQDN}, + }, + "node": {SimpleRelation: {FQDN}}, + "registration": {SimpleRelation: {DomainRecord}}, +} + +var tlscertRels = map[string]map[RelationType][]AssetType{ + "common_name": {SimpleRelation: {FQDN}}, + "subject_contact": {SimpleRelation: {ContactRecord}}, + "issuer_contact": {SimpleRelation: {ContactRecord}}, + "san_dns_name": {SimpleRelation: {FQDN}}, + "san_email_address": {SimpleRelation: {EmailAddress}}, + "san_ip_address": {SimpleRelation: {IPAddress}}, + "san_url": {SimpleRelation: {URL}}, + "issuing_certificate": {SimpleRelation: {TLSCertificate}}, + "issuing_certificate_url": {SimpleRelation: {URL}}, + "ocsp_server": {SimpleRelation: {URL}}, + "associated_with": {SimpleRelation: {AutnumRecord, DomainRecord, IPNetRecord}}, +} + +var urlRels = map[string]map[RelationType][]AssetType{ + "domain": {SimpleRelation: {FQDN}}, + "ip_address": {SimpleRelation: {IPAddress}}, + "port": {PortRelation: {Service}}, + "file": {SimpleRelation: {File}}, +} + +var fingerprintRels = map[string]map[RelationType][]AssetType{} + +var contactRecordRels = map[string]map[RelationType][]AssetType{ + "person": {SimpleRelation: {Person}}, + "organization": {SimpleRelation: {Organization}}, + "location": {SimpleRelation: {Location}}, + "email": {SimpleRelation: {EmailAddress}}, + "phone": {SimpleRelation: {Phone}}, + "url": {SimpleRelation: {URL}}, +} + +var serviceRels = map[string]map[RelationType][]AssetType{ + "fingerprint": {SimpleRelation: {Fingerprint}}, + "certificate": {SimpleRelation: {TLSCertificate}}, +} + +// GetAssetOutgoingRelations returns the relation types allowed to be used +// when the subject is the asset type provided in the parameter. +// Providing an invalid subject causes a return value of nil. +func GetAssetOutgoingRelations(subject AssetType) []string { + relations := assetTypeRelations(subject) + if relations == nil { + return nil + } + + var rtypes []string + for k := range relations { + rtypes = append(rtypes, k) + } + return rtypes +} + +// GetTransformAssetTypes returns the asset types allowed to be assigned +// when the subject is the asset type provided in the parameter, along +// with the provided label and RelationType. +// Providing an invalid subject causes a return value of nil. +func GetTransformAssetTypes(subject AssetType, label string, rtype RelationType) []AssetType { + relations := assetTypeRelations(subject) + if relations == nil { + return nil + } + + var results []AssetType + label = strings.ToLower(label) + m := make(map[AssetType]struct{}) + for r, atypes := range relations[label] { + if r != rtype { + continue + } + for _, t := range atypes { + if _, found := m[t]; !found { + m[t] = struct{}{} + results = append(results, t) + } + } + } + return results +} + +func assetTypeRelations(atype AssetType) map[string]map[RelationType][]AssetType { + var relations map[string]map[RelationType][]AssetType + + switch atype { + case IPAddress: + relations = ipRels + case Netblock: + relations = netblockRels + case AutonomousSystem: + relations = autonomousSystemRels + case File: + relations = fileRels + case FQDN: + relations = fqdnRels + case DomainRecord: + relations = domainRecordRels + case AutnumRecord: + relations = autnumRecordRels + case IPNetRecord: + relations = ipnetRecordRels + case Location: + relations = locationRels + case Phone: + relations = phoneRels + case EmailAddress: + relations = emailRels + case Person: + relations = personRels + case Organization: + relations = orgRels + case TLSCertificate: + relations = tlscertRels + case URL: + relations = urlRels + case Fingerprint: + relations = fingerprintRels + case ContactRecord: + relations = contactRecordRels + case Service: + relations = serviceRels + default: + return nil + } + + return relations +} + +// ValidRelationship returns true if the relation is valid in the taxonomy +// when outgoing from the source asset type to the destination asset type. +func ValidRelationship(src AssetType, label string, rtype RelationType, destination AssetType) bool { + atypes := GetTransformAssetTypes(src, label, rtype) + if atypes == nil { + return false + } + + for _, atype := range atypes { + if atype == destination { + return true + } + } + return false +} diff --git a/relation/dns.go b/relation/dns.go index 0df7c74..9b76d59 100644 --- a/relation/dns.go +++ b/relation/dns.go @@ -18,9 +18,15 @@ type RRHeader struct { // BasicDNSRelation is a relation in the graph representing a basic DNS resource record. type BasicDNSRelation struct { + Name string `json:"label"` Header RRHeader `json:"header"` } +// RelationType implements the Relation interface. +func (r BasicDNSRelation) Label() string { + return r.Name +} + // RelationType implements the Relation interface. func (r BasicDNSRelation) RelationType() model.RelationType { return model.BasicDNSRelation @@ -33,10 +39,16 @@ func (r BasicDNSRelation) JSON() ([]byte, error) { // PrefDNSRelation is a relation in the graph representing a DNS resource record with preference information. type PrefDNSRelation struct { + Name string `json:"label"` Header RRHeader `json:"header"` Preference int `json:"preference"` } +// RelationType implements the Relation interface. +func (r PrefDNSRelation) Label() string { + return r.Name +} + // RelationType implements the Relation interface. func (r PrefDNSRelation) RelationType() model.RelationType { return model.PrefDNSRelation @@ -49,12 +61,18 @@ func (r PrefDNSRelation) JSON() ([]byte, error) { // SRVDNSRelation is a relation in the graph representing a DNS SRV resource record. type SRVDNSRelation struct { + Name string `json:"label"` Header RRHeader `json:"header"` Priority int `json:"priority"` Weight int `json:"weight"` Port int `json:"port"` } +// RelationType implements the Relation interface. +func (r SRVDNSRelation) Label() string { + return r.Name +} + // RelationType implements the Relation interface. func (r SRVDNSRelation) RelationType() model.RelationType { return model.SRVDNSRelation diff --git a/relation/dns_test.go b/relation/dns_test.go index 51ccda1..011dbdc 100644 --- a/relation/dns_test.go +++ b/relation/dns_test.go @@ -11,6 +11,15 @@ import ( "github.com/stretchr/testify/require" ) +func TestBasicDNSRelationName(t *testing.T) { + want := "dns_record" + br := BasicDNSRelation{Name: want} + + if got := br.Label(); got != want { + t.Errorf("BasicDNSRelation.Label() = %v, want %v", got, want) + } +} + func TestBasicDNSRelationImplementsRelation(t *testing.T) { var _ model.Relation = BasicDNSRelation{} // Verify proper implementation of the Relation interface var _ model.Relation = (*BasicDNSRelation)(nil) // Verify *BasicDNSRelation properly implements the Relation interface. @@ -19,6 +28,7 @@ func TestBasicDNSRelationImplementsRelation(t *testing.T) { func TestBasicDNSRelation(t *testing.T) { t.Run("Test successful creation of BasicDNSRelation with valid resource record header", func(t *testing.T) { dr := BasicDNSRelation{ + Name: "dns_record", Header: RRHeader{ RRType: 1, Class: 1, @@ -26,6 +36,7 @@ func TestBasicDNSRelation(t *testing.T) { }, } + require.Equal(t, "dns_record", dr.Name) require.Equal(t, 1, dr.Header.RRType) require.Equal(t, 1, dr.Header.Class) require.Equal(t, 86400, dr.Header.TTL) @@ -34,6 +45,7 @@ func TestBasicDNSRelation(t *testing.T) { t.Run("Test successful JSON serialization of BasicDNSRelation with valid resource record header", func(t *testing.T) { dr := BasicDNSRelation{ + Name: "dns_record", Header: RRHeader{ RRType: 1, Class: 1, @@ -44,10 +56,19 @@ func TestBasicDNSRelation(t *testing.T) { jsonData, err := dr.JSON() require.NoError(t, err) - require.JSONEq(t, `{"header":{"rr_type":1, "class":1, "ttl":86400}}`, string(jsonData)) + require.JSONEq(t, `{"label":"dns_record", "header":{"rr_type":1, "class":1, "ttl":86400}}`, string(jsonData)) }) } +func TestPrefDNSRelationName(t *testing.T) { + want := "dns_record" + br := PrefDNSRelation{Name: want} + + if got := br.Label(); got != want { + t.Errorf("PrefDNSRelation.Label() = %v, want %v", got, want) + } +} + func TestPrefDNSRelationImplementsRelation(t *testing.T) { var _ model.Relation = PrefDNSRelation{} // Verify proper implementation of the Relation interface var _ model.Relation = (*PrefDNSRelation)(nil) // Verify *PrefDNSRelation properly implements the Relation interface. @@ -56,6 +77,7 @@ func TestPrefDNSRelationImplementsRelation(t *testing.T) { func TestPrefDNSRelation(t *testing.T) { t.Run("Test successful creation of PrefDNSRelation with valid resource record header and preference", func(t *testing.T) { pr := PrefDNSRelation{ + Name: "dns_record", Header: RRHeader{ RRType: 1, Class: 1, @@ -64,6 +86,7 @@ func TestPrefDNSRelation(t *testing.T) { Preference: 5, } + require.Equal(t, "dns_record", pr.Name) require.Equal(t, 1, pr.Header.RRType) require.Equal(t, 1, pr.Header.Class) require.Equal(t, 86400, pr.Header.TTL) @@ -73,6 +96,7 @@ func TestPrefDNSRelation(t *testing.T) { t.Run("Test successful JSON serialization of PrefDNSRelation with valid resource record header and preference", func(t *testing.T) { pr := PrefDNSRelation{ + Name: "dns_record", Header: RRHeader{ RRType: 1, Class: 1, @@ -84,10 +108,19 @@ func TestPrefDNSRelation(t *testing.T) { jsonData, err := pr.JSON() require.NoError(t, err) - require.JSONEq(t, `{"header":{"rr_type":1, "class":1, "ttl":86400}, "preference":5}`, string(jsonData)) + require.JSONEq(t, `{"label":"dns_record", "header":{"rr_type":1, "class":1, "ttl":86400}, "preference":5}`, string(jsonData)) }) } +func TestSRVDNSRelationName(t *testing.T) { + want := "dns_record" + br := SRVDNSRelation{Name: want} + + if got := br.Label(); got != want { + t.Errorf("SRVDNSRelation.Label() = %v, want %v", got, want) + } +} + func TestSRVDNSRelationImplementsRelation(t *testing.T) { var _ model.Relation = SRVDNSRelation{} // Verify proper implementation of the Relation interface var _ model.Relation = (*SRVDNSRelation)(nil) // Verify *SRVDNSRelation properly implements the Relation interface. @@ -96,6 +129,7 @@ func TestSRVDNSRelationImplementsRelation(t *testing.T) { func TestSRVDNSRelation(t *testing.T) { t.Run("Test successful creation of SRVDNSRelation", func(t *testing.T) { sr := SRVDNSRelation{ + Name: "dns_record", Header: RRHeader{ RRType: 1, Class: 1, @@ -106,6 +140,7 @@ func TestSRVDNSRelation(t *testing.T) { Port: 80, } + require.Equal(t, "dns_record", sr.Name) require.Equal(t, 1, sr.Header.RRType) require.Equal(t, 1, sr.Header.Class) require.Equal(t, 86400, sr.Header.TTL) @@ -117,6 +152,7 @@ func TestSRVDNSRelation(t *testing.T) { t.Run("Test successful JSON serialization of SRVDNSRelation", func(t *testing.T) { sr := SRVDNSRelation{ + Name: "dns_record", Header: RRHeader{ RRType: 1, Class: 1, @@ -130,6 +166,6 @@ func TestSRVDNSRelation(t *testing.T) { jsonData, err := sr.JSON() require.NoError(t, err) - require.JSONEq(t, `{"header":{"rr_type":1, "class":1, "ttl":86400}, "priority":10, "weight":5, "port":80}`, string(jsonData)) + require.JSONEq(t, `{"label":"dns_record", "header":{"rr_type":1, "class":1, "ttl":86400}, "priority":10, "weight":5, "port":80}`, string(jsonData)) }) } diff --git a/relation/port.go b/relation/port.go index 40473de..1e21bf9 100644 --- a/relation/port.go +++ b/relation/port.go @@ -12,10 +12,16 @@ import ( // PortRelation is a relation in the graph representing an open port. type PortRelation struct { + Name string `json:"label"` PortNumber int `json:"port_number"` Protocol string `json:"protocol"` } +// RelationType implements the Relation interface. +func (r PortRelation) Label() string { + return r.Name +} + // RelationType implements the Relation interface. func (r PortRelation) RelationType() model.RelationType { return model.PortRelation diff --git a/relation/port_test.go b/relation/port_test.go index 731816e..7295b50 100644 --- a/relation/port_test.go +++ b/relation/port_test.go @@ -11,6 +11,15 @@ import ( "github.com/stretchr/testify/require" ) +func TestPortRelationName(t *testing.T) { + want := "port" + pr := PortRelation{Name: want} + + if got := pr.Label(); got != want { + t.Errorf("PortRelation.Label() = %v, want %v", got, want) + } +} + func TestPortRelationImplementsRelation(t *testing.T) { var _ model.Relation = PortRelation{} // Verify proper implementation of the Relation interface var _ model.Relation = (*PortRelation)(nil) // Verify *PortRelation properly implements the Relation interface. @@ -19,10 +28,12 @@ func TestPortRelationImplementsRelation(t *testing.T) { func TestPortRelation(t *testing.T) { t.Run("Test successful creation of PortRelation with valid port number and protocol", func(t *testing.T) { pr := PortRelation{ + Name: "port", PortNumber: 80, Protocol: "tcp", } + require.Equal(t, "port", pr.Name) require.Equal(t, 80, pr.PortNumber) require.Equal(t, "tcp", pr.Protocol) require.Equal(t, pr.RelationType(), model.PortRelation) @@ -30,6 +41,7 @@ func TestPortRelation(t *testing.T) { t.Run("Test successful JSON serialization of PortRelation with valid port number and protocol", func(t *testing.T) { pr := PortRelation{ + Name: "port", PortNumber: 80, Protocol: "tcp", } @@ -37,6 +49,6 @@ func TestPortRelation(t *testing.T) { jsonData, err := pr.JSON() require.NoError(t, err) - require.JSONEq(t, `{"port_number":80, "protocol":"tcp"}`, string(jsonData)) + require.JSONEq(t, `{"label":"port", "port_number":80, "protocol":"tcp"}`, string(jsonData)) }) } diff --git a/relation/simple.go b/relation/simple.go index a3c8769..f2b8dd9 100644 --- a/relation/simple.go +++ b/relation/simple.go @@ -11,7 +11,14 @@ import ( ) // SimpleRelation represents a simple relation in the graph with no additional data required. -type SimpleRelation struct{} +type SimpleRelation struct { + Name string `json:"label"` +} + +// RelationType implements the Relation interface. +func (r SimpleRelation) Label() string { + return r.Name +} // RelationType implements the Relation interface. func (r SimpleRelation) RelationType() model.RelationType { diff --git a/relation/simple_test.go b/relation/simple_test.go index 08642f2..9c2bb8b 100644 --- a/relation/simple_test.go +++ b/relation/simple_test.go @@ -11,6 +11,15 @@ import ( "github.com/stretchr/testify/require" ) +func TestSimpleRelationName(t *testing.T) { + want := "anything" + sr := SimpleRelation{Name: want} + + if got := sr.Label(); got != want { + t.Errorf("SimpleRelation.Label() = %v, want %v", got, want) + } +} + func TestSimpleRelationImplementsRelation(t *testing.T) { var _ model.Relation = SimpleRelation{} // Verify proper implementation of the Relation interface var _ model.Relation = (*SimpleRelation)(nil) // Verify *SimpleRelation properly implements the Relation interface. @@ -18,17 +27,18 @@ func TestSimpleRelationImplementsRelation(t *testing.T) { func TestSimpleRelation(t *testing.T) { t.Run("Test successful creation of SimpleRelation", func(t *testing.T) { - sr := SimpleRelation{} + sr := SimpleRelation{Name: "anything"} + require.Equal(t, "anything", sr.Name) require.Equal(t, sr.RelationType(), model.SimpleRelation) }) t.Run("Test successful JSON serialization of SimpleRelation", func(t *testing.T) { - sr := SimpleRelation{} + sr := SimpleRelation{Name: "anything"} jsonData, err := sr.JSON() require.NoError(t, err) - require.JSONEq(t, `{}`, string(jsonData)) + require.JSONEq(t, `{"label":"anything"}`, string(jsonData)) }) }