Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attack Tree render code #125

Merged
merged 12 commits into from
Jun 28, 2024
19 changes: 17 additions & 2 deletions src/attack_flow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,15 @@ def graphviz(args):
"""
path = Path(args.attack_flow)
flow_bundle = attack_flow.model.load_attack_flow_bundle(path)
converted = attack_flow.graphviz.convert(flow_bundle)

if (
flow_bundle.get("objects", "")
and attack_flow.model.get_flow_object(flow_bundle).scope == "attack-tree"
):
converted = attack_flow.graphviz.convert_attack_tree(flow_bundle)
else:
converted = attack_flow.graphviz.convert_attack_flow(flow_bundle)

with open(args.output, "w") as out:
out.write(converted)
return 0
Expand All @@ -106,7 +114,14 @@ def mermaid(args):
"""
path = Path(args.attack_flow)
flow_bundle = attack_flow.model.load_attack_flow_bundle(path)
converted = attack_flow.mermaid.convert(flow_bundle)
if (
flow_bundle.get("objects", "")
and attack_flow.model.get_flow_object(flow_bundle).scope == "attack-tree"
):
converted = attack_flow.mermaid.convert_attack_tree(flow_bundle)
else:
converted = attack_flow.mermaid.convert_attack_flow(flow_bundle)

with open(args.output, "w") as out:
out.write(converted)
return 0
Expand Down
141 changes: 140 additions & 1 deletion src/attack_flow/graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def label_escape(text):
return graphviz.escape(html.escape(text))


def convert(bundle):
def convert_attack_flow(bundle):
"""
Convert an Attack Flow STIX bundle into Graphviz format.

Expand Down Expand Up @@ -71,6 +71,87 @@ def convert(bundle):
return gv.source


def convert_attack_tree(bundle):
"""
Convert an Attack Flow STIX bundle into Graphviz format.

:param stix2.Bundle flow:
:rtype: str
"""

gv = graphviz.Digraph(graph_attr={"rankdir": "BT"})
gv.body = _get_body_label(bundle)
ignored_ids = get_viz_ignored_ids(bundle)

objects = bundle.objects

id_to_remove = []
ids = []
for i, o in enumerate(objects):
if o.type == "attack-operator":
id_to_remove.append(
{
"id": o.id,
"prev_id": objects[i - 1].id,
"next_id": o.effect_refs[0],
"type": o.operator,
}
)

ids = [i["id"] for i in id_to_remove]
objects = [item for item in objects if item.id not in ids]
new_operator_ids = [i["next_id"] for i in id_to_remove]
for operator in id_to_remove:
for i, o in enumerate(objects):
if o.type == "relationship" and o.source_ref == operator["id"]:
o.source_ref = operator.prev_id
if o.type == "relationship" and o.target_ref == operator["id"]:
o.target_ref = operator.next_id
if o.get("effect_refs") and operator["id"] in o.effect_refs:
for i, j in enumerate(o.effect_refs):
if j == operator["id"]:
o.effect_refs[i] = operator["next_id"]

for o in objects:
logger.debug("Processing object id=%s", o.id)
if o.type == "attack-action":
if o.id in new_operator_ids:
operator_type = [
item["type"] for item in id_to_remove if item["next_id"] == o.id
][0]
gv.node(
o.id,
label=_get_operator_label(o, operator_type),
shape="plaintext",
)
else:
gv.node(
o.id,
_get_attack_tree_action_label(o),
shape="plaintext",
)
for ref in o.get("asset_refs", []):
gv.edge(o.id, ref)
for ref in o.get("effect_refs", []):
gv.edge(o.id, ref)
elif o.type == "attack-asset":
gv.node(o.id, _get_asset_label(o), shape="plaintext")
if object_ref := o.get("object_ref"):
gv.edge(o.id, object_ref, "object")
elif o.type == "attack-condition":
gv.node(o.id, _get_condition_label(o), shape="plaintext")
for ref in o.get("on_true_refs", []):
gv.edge(o.id, ref, "on_true")
for ref in o.get("on_false_refs", []):
gv.edge(o.id, ref, "on_false")
elif o.type == "relationship":
gv.edge(o.source_ref, o.target_ref, o.relationship_type)
elif o.id not in ignored_ids:
gv.node(o.id, _get_builtin_label(o), shape="plaintext")

return gv.source


def _get_body_label(bundle):
flow = get_flow_object(bundle)
author = bundle.get_obj(flow.created_by_ref)[0]
Expand Down Expand Up @@ -119,6 +200,33 @@ def _get_action_label(action):
)


def _get_attack_tree_action_label(action):
"""
Generate the GraphViz label for an action node as a table.

:param action:
:rtype: str
"""
if tid := action.get("technique_id", None):
heading = f"Action: {tid}"
else:
heading = "Action"
description = "<br/>".join(
textwrap.wrap(label_escape(action.get("description", "")), width=40)
)
confidence = confidence_num_to_label(action.get("confidence", 95))
return "".join(
[
'<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="5">',
f'<TR><TD BGCOLOR="#B40000" COLSPAN="2"><font color="white"><B>{heading}</B></font></TD></TR>',
f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Name</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{label_escape(action.name)}</TD></TR>',
f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Description</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{description}</TD></TR>',
f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Confidence</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{confidence}</TD></TR>',
"</TABLE>>",
]
)


def _get_asset_label(asset):
"""
Generate the GraphViz label for an asset node as a table.
Expand Down Expand Up @@ -184,3 +292,34 @@ def _get_condition_label(condition):
"</TABLE>>",
]
)


def _get_operator_label(action, operator_type):
"""
Generate the GraphViz label for an action node as a table.

:param action:
:rtype: str
"""
if tid := action.get("technique_id", None):
heading = f"{operator_type} {tid}"
else:
heading = f"{operator_type}"
description = "<br/>".join(
textwrap.wrap(label_escape(action.get("description", "")), width=40)
)
confidence = confidence_num_to_label(action.get("confidence", 95))
if operator_type == "AND":
color = "#99ccff"
else:
color = "#9CE67E"
return "".join(
[
'<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="5">',
f'<TR><TD BGCOLOR="{color}" COLSPAN="2"><B>{heading}</B></TD></TR>',
f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Name</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{label_escape(action.name)}</TD></TR>',
f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Description</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{description}</TD></TR>',
f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Confidence</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{confidence}</TD></TR>',
"</TABLE>>",
]
)
102 changes: 100 additions & 2 deletions src/attack_flow/mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self):
self.classes = dict()
self.nodes = list()
self.edges = list()
self.direction = ""

def add_class(self, class_, shape, style):
self.classes[class_] = (shape, style)
Expand All @@ -33,7 +34,10 @@ def add_edge(self, src_id, target_id, text):
def render(self):
# Mermaid can't handle IDs with hyphens in them:
convert_id = lambda id_: id_.replace("-", "_")
lines = ["graph TB"]
if self.direction:
lines = [f"graph {self.direction}"]
else:
lines = ["graph TB"]

for class_, (_, style) in self.classes.items():
lines.append(f" classDef {class_} {style}")
Expand All @@ -46,6 +50,9 @@ def render(self):
if self.classes[node_class][0] == "circle":
shape_start = "(("
shape_end = "))"
elif self.classes[node_class][0] == "trap":
shape_start = "[/"
shape_end = "\]"
else:
shape_start = "["
shape_end = "]"
Expand All @@ -64,7 +71,7 @@ def render(self):
return "\n".join(lines)


def convert(bundle):
def convert_attack_flow(bundle):
"""
Convert an Attack Flow STIX bundle into Mermaid format.

Expand Down Expand Up @@ -119,3 +126,94 @@ def convert(bundle):
graph.add_node(o.id, "builtin", " - ".join(label_lines))

return graph.render()


def convert_attack_tree(bundle):

"""
Convert an Attack Flow STIX bundle into Mermaid format.

:param stix2.Bundle flow:
:rtype: str
"""
graph = MermaidGraph()
graph.direction = "BT"
graph.add_class("action", "rect", "fill:#B40000, color:white")
graph.add_class("AND", "rect", "fill:#99ccff")
graph.add_class("OR", "trap", "fill:#9CE67E")
graph.add_class("condition", "rect", "fill:#99ff99")
graph.add_class("builtin", "rect", "fill:#cccccc")
ignored_ids = get_viz_ignored_ids(bundle)

objects = bundle.objects
id_to_remove = []
ids = []

for i, o in enumerate(objects):
if o.type == "attack-operator":
id_to_remove.append(
{
"id": o.id,
"prev_id": objects[i - 1].id,
"next_id": o.effect_refs[0],
"type": o.operator,
}
)
ids = [i["id"] for i in id_to_remove]
objects = [item for item in objects if item.id not in ids]
new_operator_ids = [i["next_id"] for i in id_to_remove]
for operator in id_to_remove:
for i, o in enumerate(objects):
if o.type == "relationship" and o.source_ref == operator["id"]:
o.source_ref = operator.prev_id
if o.type == "relationship" and o.target_ref == operator["id"]:
o.target_ref = operator.next_id
if o.get("effect_refs") and operator["id"] in o.effect_refs:
for i, j in enumerate(o.effect_refs):
if j == operator["id"]:
o.effect_refs[i] = operator["next_id"]

for o in bundle.objects:
if o.type == "attack-action":
if tid := o.get("technique_id", None):
name = f"{tid} {o.name}"
else:
name = o.name
if o.id in new_operator_ids:
operator_type = [
item["type"] for item in id_to_remove if item["next_id"] == o.id
][0]
label_lines = [
f"<b>{operator_type}</b>",
f"<b>{name}</b>",
]
graph.add_node(o.id, operator_type, " - ".join(label_lines))
else:
label_lines = [
"<b>Action</b>",
f"<b>{name}</b>",
]
graph.add_node(o.id, "action", " - ".join(label_lines))
for ref in o.get("effect_refs", []):
graph.add_edge(o.id, ref, " ")
elif o.type == "attack-condition":
graph.add_node(o.id, "condition", f"<b>Condition:</b> {o.description}")
for ref in o.get("on_true_refs", []):
graph.add_edge(o.id, ref, "on_true")
for ref in o.get("on_false_refs", []):
graph.add_edge(o.id, ref, "on_false")
elif o.type == "relationship":
graph.add_edge(o.source_ref, o.target_ref, o.relationship_type)
elif o.id not in ignored_ids and o.id not in ids:
type_ = o.type.replace("-", " ").title()
label_lines = [f"<b>{type_}</b>"]
for key, value in o.items():
if key in VIZ_IGNORE_COMMON_PROPERTIES:
continue
key = key.replace("_", " ").title()
if isinstance(value, list):
value = ", ".join(str(v) for v in value)
label_lines.append(f"<b>{key}</b>: {value}")
graph.add_node(o.id, "builtin", " - ".join(label_lines))

return graph.render()
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ const config: AppConfiguration = {
["campaign", "Campaign"],
["threat-actor", "Threat Actor"],
["malware", "Malware"],
["attack-tree", "Attack Tree"],
["other", "Other"]
]
},
Expand Down
Loading