diff --git a/pate_binja/pate.py b/pate_binja/pate.py index c1429091..289e9858 100644 --- a/pate_binja/pate.py +++ b/pate_binja/pate.py @@ -11,7 +11,8 @@ import shlex import signal import sys -from json import JSONDecodeError +import threading +from json import JSONDecodeError, JSONEncoder, JSONDecoder from subprocess import Popen, PIPE, STDOUT, TimeoutExpired from typing import IO, Any, Optional @@ -56,6 +57,8 @@ def __init__(self, filename: os.PathLike, self.trace_file = None self.last_cfar_graph = None + self.traceConstraintModeDone = threading.Event() + def run(self) -> None: if self.filename.endswith(".run-config.json"): self._run_live() @@ -85,7 +88,7 @@ def _run_live(self): script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run-pate.sh") # Need -l to make sure user's env is fully setup (e.g. access to docker and ghc tools). with open(os.path.join(cwd, "lastrun.replay"), "w", encoding='utf-8') as trace: - with Popen(['/bin/bash', '-l', script, '-o', original, '-p', patched, '--json-toplevel'] + args, + with Popen(['/bin/bash', '-l', script, '-o', original, '-p', patched, '--json-toplevel', '--add-trace-constraints'] + args, cwd=cwd, stdin=PIPE, stdout=PIPE, stderr=STDOUT, @@ -441,30 +444,38 @@ def _ask_user(self, prompt: str, choices: list[str]) -> Optional[str]: return choice def command_loop(self): - if self.config_callback: - self.config_callback(self.config) - rec = self.next_json() - self._command('goto_prompt') - while self.command_step(): - pass - self.user.show_message("Pate finished") - - def command_step(self): - # Process one json record from pate try: - rec = self.next_json(gotoPromptAfterNonJson=True) - return self.process_json(rec) + if self.config_callback: + self.config_callback(self.config) + rec = self.next_json() + self._command('goto_prompt') + while self.command_step(): + pass + # Enter trace constraint processing mode + self.traceConstraintModeDone.wait() + self.user.show_message("Pate finished") except EOFError: self.user.show_message("Pate terminated unexpectedly") return False + def command_step(self): + # Process one json record from pate + rec = self.next_json(gotoPromptAfterNonJson=True) + return self.process_json(rec) + def process_json(self, rec): if self.debug_json: print('\nProcessing JSON:') pp.pprint(rec) - if isinstance(rec, dict) and rec.get('this') and rec.get('trace_node_contents'): + if isinstance(rec, dict) and rec.get('this') == 'Regenerate result with new trace constraints?': + # Finish detected + self.user.show_message('\nProcessing verification results ...\n') + self.processFinalResult() + return False + + elif isinstance(rec, dict) and rec.get('this') and rec.get('trace_node_contents'): # Prompt User # TODO: Heuristic for when to update graph. Ask Dan. Maybe add flag to JSON? if rec['this'].startswith('Control flow desynchronization found at') \ @@ -472,6 +483,7 @@ def process_json(self, rec): # Extract flow graph cfar_graph = self.extract_graph() if cfar_graph: + #print('Update last cfar graph') self.last_cfar_graph = cfar_graph self.user.show_cfar_graph(cfar_graph) # Go back to prompt @@ -480,50 +492,6 @@ def process_json(self, rec): choice = self._ask_user_rec(rec) self._command(choice) - elif isinstance(rec, list) and len(rec) > 0 and rec[-1].get('content') == {'node_kind': 'final_result'}: - # Finish detected - self.user.show_message('\nProcessing verification results.\n') - cmd = rec[-1]['index'] - self._command(str(cmd)) - result = self.next_json() - with io.StringIO() as out: - for tnc in result['trace_node_contents']: - eqconds = tnc.get('content', {}).get('eq_conditions', {}).get('map') - if eqconds: - # Found eq conditions - for item in eqconds: - node = item['key'] - eqcond = item['val'] - - node_id = get_graph_node_id(node) - predicate = eqcond['predicate'] - trace_true = eqcond['trace_true'] - trace_false = eqcond['trace_false'] - - #print('CFAR id:', node_id) - - out.write(f'Equivalence condition for {node_id}\n') - pprint_symbolic(out, predicate) - out.write('\n') - - #out.write('\nTrace True\n') - #pprint_node_event_trace(trace_true, 'True Trace', out=out) - - #out.write('\nTrace False\n') - #pprint_node_event_trace(trace_false, 'False Trace', out=out) - - if self.last_cfar_graph: - cfar_node = self.last_cfar_graph.get(node_id) - cfar_node.predicate = predicate - cfar_node.trace_true = trace_true - cfar_node.trace_false = trace_false - - self.user.show_message(out.getvalue()) - if self.last_cfar_graph: - self.user.show_cfar_graph(self.last_cfar_graph) - - return False - # elif (isinstance(rec, dict) and rec.get('this') # and rec.get('trace_node_contents') is not None # and rec['this'].startswith('Assumed Equivalence Conditions')): @@ -569,6 +537,143 @@ def process_json(self, rec): return True + def processFinalResult(self, traceConstraints: list[tuple[TraceVar, str, str]] = None, cfarNode: CFARNode = None): + # TODO: add option to do with respect to a cfar node in which case missing should clear eq cond data? + self._command('up') + rec = self.next_json() + # isinstance(rec, dict) and rec.get('trace_node_kind') == 'final_result': + # Find the last "Toplevel Result" + lastTopLevelResult = None + for tnc in rec['trace_node_contents']: + if tnc.get('pretty') == "Toplevel Result": + lastTopLevelResult = tnc + with io.StringIO() as out: + if not lastTopLevelResult: + out.write(f'No equivalence conditions found\n') + else: + eqconds = lastTopLevelResult.get('content', {}).get('eq_conditions', {}).get('map') + if eqconds: + # Found eq conditions + for item in eqconds: + node = item['key'] + eqcond = item['val'] + + node_id = get_graph_node_id(node) + predicate = eqcond['predicate'] + trace_true = eqcond['trace_true'] + trace_false = eqcond['trace_false'] + trace_footprint = eqcond['trace_footprint'] + + # print('CFAR id:', node_id) + + out.write(f'Equivalence condition for {node_id}\n') + pprint_symbolic(out, predicate) + out.write('\n') + + # out.write('\nTrace True\n') + # pprint_node_event_trace(trace_true, 'True Trace', out=out) + + # out.write('\nTrace False\n') + # pprint_node_event_trace(trace_false, 'False Trace', out=out) + + if self.last_cfar_graph: + cfar_node = self.last_cfar_graph.get(node_id) + # Hack to get unconstrainedPredicate from first top level result + if cfar_node.unconstrainedPredicate is None: + cfar_node.unconstrainedPredicate = predicate + cfar_node.predicate = predicate + cfar_node.trace_true = trace_true + cfar_node.trace_false = trace_false + cfar_node.trace_footprint = trace_footprint + cfar_node.traceConstraints = traceConstraints + + else: + # no eq conditions - unsat constraints + cfarNode.trace_true = False + cfarNode.trace_false = False + cfarNode.traceConstraints = traceConstraints + cfarNode.predicate = cfarNode.unconstrainedPredicate + + self.user.show_message(out.getvalue()) + if self.last_cfar_graph: + self.user.show_cfar_graph(self.last_cfar_graph) + self._command('goto_prompt') + rec = self.next_json() + + def getReplayTraceConstraints(self) -> Optional[list[tuple[TraceVar, str, str]]]: + if self.trace_file is None: + # Read constraints from replay file + # Replay need to do this ahead of time to populate trace constraint dialog. + # Doing it here for now so replay works for debugging. + replay_line = self.pate_proc.stdout.readline() + if replay_line.startswith('Trace Constraints: '): + replay_line = replay_line[len('Trace Constraints: '):].strip() + # Parse JSON and return it + traceConstraints = json.loads(replay_line, object_hook=traceConstraintsJSONObjectHook) + # TODO: replace top level list[3] with tuple[3] + traceConstraints = [tuple(x) for x in traceConstraints] + #print('Replay constraints:', traceConstraints) + return traceConstraints + return None + + def processTraceConstraints(self, traceConstraints: list[tuple[TraceVar, str, str]], cfarNode: CFARNode) -> None: + + self.user.show_message('\nProcessing trace constraints ...\n') + + if self.trace_file: + # Write constraints to trace file for use in replay mode + tcl = [f'{tc[0].pretty} {tc[1]} {tc[2]}' for tc in traceConstraints] + self.trace_file.write('Trace Constraints: ') + json.dump(traceConstraints, self.trace_file, cls=TraceConstraintsJSONEncoder) + self.trace_file.write('\n') + self.trace_file.flush() + + # TODO: infrastructure to do this in the background on same thread as command loop + with io.StringIO() as out: + # input "[ [ { \"var\" : { \"symbolic_ident\" : 0 }, \"op\" : \"EQ\", \"const\" : \"128\"} ] ]" + # TODO: Handle multiple nodes in final result + out.write(r'input "[') + # TODO: Handle multiple eq conds + out.write(r'[') + for i, tc in enumerate(traceConstraints): + if i > 0: + out.write(r',') + out.write(r'{\"var\":{\"symbolic_ident\":') + # symbolic_ident + out.write(str(tc[0].symbolic_ident)) + out.write(r'},\"op\":\"') + # op + out.write(tc[1]) + out.write(r'\",\"const\":\"') + # int const + out.write(str(tc[2])) + out.write(r'\"}') + out.write(r']') + out.write(r']"') + verifierTraceConstraintInput = out.getvalue() + + #print('verifierTraceConstraintInput:', verifierTraceConstraintInput) + #self.debug_io = True + self._command('0') + # TODO: Consider generalizing command_loop rather than this processing? + #print('waiting for constraint prompt') + while True: + rec = self.next_json() + if isinstance(rec, dict) and rec['this'] == 'Waiting for constraints..': + break + else: + self.show_message(rec) + self._command(verifierTraceConstraintInput) + #print('waiting for regenerate result prompt') + while True: + rec = self.next_json() + if isinstance(rec, dict) and rec['this'] == 'Regenerate result with new trace constraints?': + break + else: + self.show_message(rec) + #print('waiting for constraint prompt') + self.processFinalResult(traceConstraints, cfarNode) + def show_message(self, rec: Any): if isinstance(rec, list): for m in rec: @@ -596,9 +701,12 @@ def __init__(self, id: str, desc: str, data: dict): self.external_postdomain = None self.addr = None self.finished = True + self.unconstrainedPredicate = None self.predicate = None self.trace_true = None self.trace_false = None + self.trace_footprint = None + self.traceConstraints = None self.instruction_trees = None def update_node(self, desc: str, data: dict): @@ -721,6 +829,70 @@ def get_parents(self, node: CFARNode) -> list[CFARNode]: return parents +class TraceVar: + def __init__(self, prefix: str, kind: str, raw: dict): + self.prefix = prefix + self.kind = kind + self.raw = raw + self.pretty = 'unknown' + self.numBits = 0 + self.type = None + self.symbolic_ident = None + + match self.kind: + case 'reg_op': + with io.StringIO() as out: + out.write(prefix) + out.write(" ") + pprint_reg(self.raw, out=out) + self.pretty = out.getvalue() + self.type = self.raw['val']['offset']['type'] + # TODO: parse numBits from type + self.numBits = 32 + case 'mem_op': + mem_op = raw['snd'] + with io.StringIO() as out: + out.write(prefix) + out.write(" ") + out.write(f'{get_addr_id(raw["fst"])}: {mem_op["direction"]} {get_value_id(mem_op["addr"])} ') + self.pretty = out.getvalue() + offset = mem_op['value']['offset'] + if isinstance(offset, dict): + self.symbolic_ident = offset['symbolic_ident'] + self.type = offset['type'] + self.numBits = mem_op['size'] * 8 + + +class TraceConstraintsJSONEncoder(JSONEncoder): + def default(self, o): + if isinstance(o, TraceVar): + return {'class': 'TraceVar', + 'prefix': o.prefix, + 'kind': o.kind, + 'raw': o.raw} + else: + return super().default(o) + + +def traceConstraintsJSONObjectHook(d: dict): + if d.get('class') == 'TraceVar': + return TraceVar(d['prefix'], d['kind'], d['raw']) + else: + return d + + +def extractTraceVars(rawFootprint) -> list[TraceVar]: + traceVars = [] + #for r in raw['fp_initial_regs']['reg_op']['map']: + # traceVars.append(TraceVar('reg_op', r)) + for r in rawFootprint['original']['fp_mem']: + traceVars.append(TraceVar('original', 'mem_op', r)) + for r in rawFootprint['patched']['fp_mem']: + traceVars.append(TraceVar('patched', 'mem_op', r)) + # TODO: sort by instruction addr, but reverse seems to work for now + traceVars.reverse() + return traceVars + def get_cfar_addr(cfar_id: str) -> tuple[Optional[int], Optional[int]]: """Get CFAR original and patched address""" parts = cfar_id.split(' vs ') @@ -1185,50 +1357,57 @@ def pprint_event_trace_initial_reg(initial_regs: dict, pre: str = '', out: IO = def pprint_reg_op(reg_op: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False): - for r in reg_op['map']: - val: dict = r['val'] - ppval = get_value_id(val) - key: dict = r['key'] - if (not isinstance(val, dict) - or not prune_zero - or not ppval.startswith('0x0:')): - match key: - case {'arch_reg': name}: - if name == '_PC' and ppval.startswith('0x0:'): - # TODO: is this correct? - out.write(f'{pre}pc <- return address\n') - elif name in {'_PC', 'PSTATE_C', 'PSTATE_V', 'PSTATE_N', 'PSTATE_Z'}: - out.write(f'{pre}{name} <- {ppval}\n') - case {'reg': name}: + for reg in reg_op['map']: + pprint_reg(reg, pre, out, prune_zero) + + +def pprint_reg(reg: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False): + val: dict = reg['val'] + ppval = get_value_id(val) + key: dict = reg['key'] + if (not isinstance(val, dict) + or not prune_zero + or not ppval.startswith('0x0:')): + match key: + case {'arch_reg': name}: + if name == '_PC' and ppval.startswith('0x0:'): + # TODO: is this correct? + out.write(f'{pre}pc <- return address\n') + elif name in {'_PC', 'PSTATE_C', 'PSTATE_V', 'PSTATE_N', 'PSTATE_Z'}: out.write(f'{pre}{name} <- {ppval}\n') - case {'hidden_reg': name}: - # drop for now - #out.write(f'{pre}Hidden Reg: {name}') - pass - case _: - out.write(f'{pre}{key} <- {ppval}\n') + case {'reg': name}: + out.write(f'{pre}{name} <- {ppval}\n') + case {'hidden_reg': name}: + # drop for now + #out.write(f'{pre}Hidden Reg: {name}') + pass + case _: + out.write(f'{pre}{key} <- {ppval}\n') -def pprint_mem_op(mem_op: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False): - if mem_op.get('mem_op'): - mem_op = mem_op['mem_op'] - out.write(f'{pre}{mem_op["direction"]} {get_value_id(mem_op["addr"])} ') - match mem_op["direction"]: - case 'Read': - out.write('->') - case 'Write': - out.write('<-') - case _: - out.write('??') - out.write(f' {get_value_id(mem_op["value"])}') - #out.write(f' {mem_op["endianness"]}[{mem_op["size"]}]') - if mem_op['condition'] != '"unconditional"': - out.write(f' condition: {mem_op["condition"]}') - out.write('\n') - elif mem_op.get('external_call'): - out.write(f'{pre}{mem_op["external_call"]}({",".join(map(pretty_call_arg, mem_op["args"]))})\n') +def pprint_memory_op(memory_op: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False): + if memory_op.get('mem_op'): + pprint_mem_op(memory_op['mem_op'], pre, out, prune_zero) + elif memory_op.get('external_call'): + out.write(f'{pre}{memory_op["external_call"]}({",".join(map(pretty_call_arg, memory_op["args"]))})\n') else: - out.write(f'{pre}Unknown mem op: {mem_op}') + out.write(f'{pre}Unknown mem op: {memory_op}') + + +def pprint_mem_op(mem_op: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False): + out.write(f'{pre}{mem_op["direction"]} {get_value_id(mem_op["addr"])} ') + match mem_op["direction"]: + case 'Read': + out.write('->') + case 'Write': + out.write('<-') + case _: + out.write('??') + out.write(f' {get_value_id(mem_op["value"])}') + #out.write(f' {mem_op["endianness"]}[{mem_op["size"]}]') + if mem_op['condition'] != '"unconditional"': + out.write(f' condition: {mem_op["condition"]}') + out.write('\n') def pretty_call_arg(arg): @@ -1247,7 +1426,7 @@ def pprint_event_trace_instructions(events: dict, pre: str = '', out: IO = sys.s out.write(f'{pre} {get_addr_id(e["instruction_addr"])}\n') for op in e['events']: if op.get('memory_op'): - pprint_mem_op(op['memory_op'], pre + ' ', out) + pprint_memory_op(op['memory_op'], pre + ' ', out) elif op.get('register_op'): pprint_reg_op(op['register_op']['reg_op'], pre + ' ', out) @@ -1579,13 +1758,6 @@ def show_cfar_graph(self, graph: CFARGraph) -> None: print('Prompt Node:', promptNode.id) -def run_replay(file: str) -> Popen: - return Popen( - ['cat', file], - stdin=None, stdout=PIPE, text=True, encoding='utf-8' - ) - - def load_run_config(file: os.PathLike) -> Optional[dict]: try: with open(file, 'r') as f: @@ -1595,31 +1767,6 @@ def load_run_config(file: os.PathLike) -> Optional[dict]: return None -def run_config(config: dict): - cwd = config.get('cwd') - original = config.get('original') - patched = config.get('patched') - rawargs = config.get('args') - args = shlex.split(' '.join(rawargs)) - # TODO: Error checking - return run_pate(cwd, original, patched, args) - - -def run_pate(cwd: str, original: str, patched: str, args: list[str]) -> Popen: - # We use a helper script to run logic in the user's shell environment. - script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run-pate.sh") - # Need -l to make sure user's env is fully setup (e.g. access to docker and ghc tools). - return Popen(['/bin/bash', '-l', script, '-o', original, '-p', patched, '--json-toplevel'] + args, - cwd=cwd, - stdin=PIPE, stdout=PIPE, - stderr=STDOUT, - text=True, encoding='utf-8', - close_fds=True, - # Create a new process group, so we can kill it cleanly - preexec_fn=os.setsid - ) - - def get_demo_files(): files = [] demoDir = pathlib.Path(os.getenv('PATE_BINJA_DEMOS')) @@ -1642,3 +1789,5 @@ def run_pate_demo(): user = TtyUserInteraction(not replay) pate = PateWrapper(file, user) pate.run() + + diff --git a/pate_binja/view.py b/pate_binja/view.py index a376fe40..f465ab2b 100644 --- a/pate_binja/view.py +++ b/pate_binja/view.py @@ -21,7 +21,8 @@ from PySide6.QtCore import Qt, QCoreApplication from PySide6.QtGui import QMouseEvent, QAction, QColor, QPaintEvent from PySide6.QtWidgets import QHBoxLayout, QLabel, QVBoxLayout, QLineEdit, QPlainTextEdit, QDialog, QWidget, \ - QSplitter, QMenu, QTextEdit + QSplitter, QMenu, QTextEdit, QComboBox, QPushButton, QListWidget, QListWidgetItem, QAbstractItemView, \ + QDialogButtonBox, QMessageBox from .mcad.PateMcad import PateMcad, CycleCount from . import pate @@ -238,9 +239,11 @@ def show_cfar_graph(self, graph: pate.CFARGraph) -> None: promptNode = graph.getPromptNode() eqCondNodes = graph.getEqCondNodes() if promptNode: + #print('promptNode:', promptNode) self.pate_widget.flow_graph_widget.flowGraph.layout_and_wait() execute_on_main_thread_and_wait(lambda: self.pate_widget.flow_graph_widget.showCfars([promptNode])) elif eqCondNodes: + #print('eqCondNodes:', eqCondNodes) self.pate_widget.flow_graph_widget.flowGraph.layout_and_wait() execute_on_main_thread_and_wait(lambda: self.pate_widget.flow_graph_widget.showCfars(eqCondNodes)) @@ -335,7 +338,7 @@ def __init__(self, parent=None): self.setLayout(main_layout) def clear(self, msg): - self.linesA = None + self.linesA = [msg] self.labelA = None self.linesB = None self.labelB = None @@ -367,10 +370,12 @@ def redisplay(self): # Show diff html = generateHtmlDiff(self.linesA, self.labelA, self.linesB, self.labelB) self.diffField.setHtml(html) - elif self.linesA is None and self.linesB is not None: + elif self.linesA is not None and self.linesB is None: # Just linesA, no diff - text = self.labelA + "\n" - text += '\n'.join(self.lineA) + text = '' + if self.labelA is not None: + text += self.labelA + "\n" + text += '\n'.join(self.linesA) self.diffField.setText(text) else: # Nothing to show @@ -396,6 +401,12 @@ def __init__(self, parent): self.setLayout(main_layout) def setTrace(self, trace: dict, label: str = None): + if not trace: + # Unsat + self.domainField.setPlainText('Unsatisfiable') + self.traceDiff.clear('Unsatisfiable') + return + with io.StringIO() as out: pate.pprint_node_event_trace_domain(trace, out=out) self.domainField.setPlainText(out.getvalue()) @@ -455,11 +466,17 @@ def setTrace(self, trace: dict, label: str = None): class PateCfarEqCondDialog(QDialog): - def __init__(self, parent=None): + def __init__(self, cfarNode, parent=None): super().__init__(parent) + pw: Optional[PateWidget] = getAncestorInstanceOf(self, PateWidget) + + self.cfarNode = cfarNode + self.traceConstraints = None + self.resize(1500, 800) - self.setWindowTitle("Equivalence Condition") + self.setWindowTitle("") + self.setWindowTitle(f'Equivalence Condition - {self.cfarNode.id}') # Equivalence Condition Box self.eqCondField = QPlainTextEdit() @@ -471,6 +488,15 @@ def __init__(self, parent=None): eqCondBox = QWidget() eqCondBox.setLayout(eqCondBoxLayout) + # Constrain True Trace Button + if pw.pate_thread.pate_wrapper.trace_file is None: + # Replay mode + trueTraceConstraintButton = QPushButton("Constrain Trace (replay)") + else: + # Live Mode + trueTraceConstraintButton = QPushButton("Constrain Trace") + trueTraceConstraintButton.clicked.connect(lambda _: self.showTrueTraceConstraintDialog()) + # True Trace Box self.trueTraceWidget = TraceWidget(self) trueTraceBoxLayout = QVBoxLayout() @@ -500,15 +526,157 @@ def __init__(self, parent=None): mainSplitter.addWidget(trueFalseSplitter) # Main Layout - main_layout = QHBoxLayout() + main_layout = QVBoxLayout() main_layout.addWidget(mainSplitter) + main_layout.addWidget(trueTraceConstraintButton) self.setLayout(main_layout) - def setTrueTrace(self, trace: dict, label: str = None): - self.trueTraceWidget.setTrace(trace, label) + self.updateFromCfarNode() + + def updateFromCfarNode(self): + self.eqCondField.clear() + with io.StringIO() as out: + pate.pprint_symbolic(out, self.cfarNode.unconstrainedPredicate) + out.write('\n') + if self.cfarNode.traceConstraints: + #print(self.cfarNode.traceConstraints) + out.write('\nUser-supplied trace constraints:\n') + for tc in self.cfarNode.traceConstraints: + out.write(f'{tc[0].pretty} {tc[1]} {tc[2]}\n') + if self.cfarNode.trace_true or self.cfarNode.trace_false: + out.write('\nEffective equivalence condition after adding user-provided constraints::\n') + pate.pprint_symbolic(out, self.cfarNode.predicate) + else: + out.write('\nNo user-supplied trace constraints.\n') + self.eqCondField.appendPlainText(out.getvalue()) + self.trueTraceWidget.setTrace(self.cfarNode.trace_true) + self.falseTraceWidget.setTrace(self.cfarNode.trace_false) + + def showTrueTraceConstraintDialog(self): + pw: Optional[PateWidget] = getAncestorInstanceOf(self, PateWidget) + replayTraceConstraints = pw.pate_thread.pate_wrapper.getReplayTraceConstraints() + if replayTraceConstraints is None: + # Live - show dialog + d = PateTraceConstraintDialog(self.cfarNode, parent=self) + #d.setWindowTitle(f'{d.windowTitle()} - {cfarNode.id}') + if d.exec(): + self.traceConstraints = d.getConstraints() + #print(self.traceConstraints) + # TODO: Better way to do this? + pw.pate_thread.pate_wrapper.processTraceConstraints(self.traceConstraints, self.cfarNode) + self.updateFromCfarNode() + # TODO: report failed constraint? + else: + # Replay - skip dialog and replay constraints + pw.pate_thread.pate_wrapper.processTraceConstraints(replayTraceConstraints, self.cfarNode) + self.updateFromCfarNode() + +traceConstraintRelations = ["EQ", "NEQ", "LTs", "LTu", "GTs", "GTu", "LEs", "LEu", "GEs", "GEu"] + + +class PateTraceConstraintDialog(QDialog): + def __init__(self, cfarNode: pate.CFARNode, parent=None): + super().__init__(parent) + + self.cfarNode = cfarNode + + self.traceVars = pate.extractTraceVars(self.cfarNode.trace_footprint) + + # Prune TraceVars with no symbolic_ident + self.traceVars = [tv for tv in self.traceVars if tv.symbolic_ident is not None] - def setFalseTrace(self, trace: dict, label: str = None): - self.falseTraceWidget.setTrace(trace, label) + #self.resize(1500, 800) + self.setWindowTitle("Trace Constraint") + + self.varComboBox = QComboBox() + for tv in self.traceVars: + self.varComboBox.addItem(tv.pretty, userData=tv) + varLabel = QLabel("Variable:") + varLabel.setBuddy(self.varComboBox) + + self.relComboBox = QComboBox() + self.relComboBox.addItems(traceConstraintRelations) + relLabel = QLabel("Relation:") + relLabel.setBuddy(self.relComboBox) + + self.intTextLine = QLineEdit() + intLabel = QLabel("Integer:") + intLabel.setBuddy(self.intTextLine) + + addButton = QPushButton("Add") + addButton.clicked.connect(lambda _: self.addConstraint()) + + addLayout = QHBoxLayout() + addLayout.addSpacing(15) + addLayout.addWidget(varLabel) + addLayout.addWidget(self.varComboBox) + addLayout.addSpacing(15) + addLayout.addWidget(relLabel) + addLayout.addWidget(self.relComboBox) + addLayout.addSpacing(15) + addLayout.addWidget(intLabel) + addLayout.addWidget(self.intTextLine) + addLayout.addSpacing(40) + addLayout.addStretch(1) + addLayout.addWidget(addButton) + + self.constraintList = QListWidget() + self.constraintList.setSelectionMode(QAbstractItemView.SelectionMode.MultiSelection) + + removeButton = QPushButton("Remove Selected") + removeButton.clicked.connect(lambda _: self.removeSelectedConstraints()) + + cancelButton = QPushButton("Cancel") + cancelButton.clicked.connect(lambda _: self.cancel()) + + applyButton = QPushButton("Apply") + applyButton.clicked.connect(lambda _: self.apply()) + + buttonBox = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) + buttonBox.accepted.connect(self.accept) + buttonBox.rejected.connect(self.reject) + + # Main Layout + main_layout = QVBoxLayout() + main_layout.addLayout(addLayout) + main_layout.addWidget(self.constraintList) + main_layout.addWidget(removeButton) + main_layout.addWidget(buttonBox) + self.setLayout(main_layout) + + def addConstraint(self): + var = self.varComboBox.currentText() + traceVar = self.varComboBox.currentData() + rel = self.relComboBox.currentText() + intStr = self.intTextLine.text() + if not intStr: + QMessageBox.critical(self, "Trace Constraint Error", "No integer specified.") + return + try: + intVal = int(intStr, 0) + except ValueError: + QMessageBox.critical(self, "Trace Constraint Error", f'Can\'t parse "{intStr}" as an integer.') + return + + # TODO: Make sure intVal is in range for var type + # TODO: Prevent duplicates (wont hurt anything, but not useful to do and may mask entry error) + # TODO: Need data for constraint, associate with QListWidgetItem or subclass? Wait for var rep? + + constraint = f'{traceVar.pretty} {rel} {intVal}' + item = QListWidgetItem(constraint, self.constraintList) + item.setData(Qt.UserRole, (traceVar, rel, intVal)) + + def removeSelectedConstraints(self): + clist = self.constraintList + listItems = clist.selectedItems() + if not listItems: return + for item in listItems: + itemRow = clist.row(item) + clist.takeItem(itemRow) + + def getConstraints(self) -> list[tuple[pate.TraceVar, str, str]]: + lw = self.constraintList + return [lw.item(x).data(Qt.UserRole) for x in range(lw.count())] class InstTreeDiffWidget(QWidget): @@ -835,6 +1003,7 @@ def showCfars(self, cfars: list[pate.CFARNode]): #print('focusCfar.id', focusCfar.id) #print('focusFlowNode', focusFlow) if focusFlow: + #print('focus:', focusFlow) self.showNode(focusFlow) def mousePressEvent(self, event: QMouseEvent): @@ -874,7 +1043,7 @@ def nodePopupMenu(self, event: QMouseEvent, node: FlowGraphNode): action.triggered.connect(lambda _: self.pate_widget.gotoPatchedAddress(cfarNode.patched_addr)) menu.addAction(action) - if cfarNode.predicate: + if cfarNode.unconstrainedPredicate: action = QAction('Show Equivalence Condition', self) action.triggered.connect(lambda _: self.showCfarEqCondDialog(cfarNode)) menu.addAction(action) @@ -889,13 +1058,7 @@ def nodePopupMenu(self, event: QMouseEvent, node: FlowGraphNode): menu.exec_(event.globalPos()) def showCfarEqCondDialog(self, cfarNode: pate.CFARNode): - d = PateCfarEqCondDialog(parent=self) - d.setWindowTitle(f'{d.windowTitle()} - {cfarNode.id}') - with io.StringIO() as out: - pate.pprint_symbolic(out, cfarNode.predicate) - d.eqCondField.appendPlainText(out.getvalue()) - d.setTrueTrace(cfarNode.trace_true) - d.setFalseTrace(cfarNode.trace_false) + d = PateCfarEqCondDialog(cfarNode, parent=self) d.show() def edgePopupMenu(self, event: QMouseEvent, edgeTuple: tuple[FlowGraphEdge, bool]):