From 06160d429bdb45c584478445f9d537f0dc4f8169 Mon Sep 17 00:00:00 2001 From: disinvite Date: Sat, 13 Jan 2024 15:09:03 -0500 Subject: [PATCH] Name substitution for reccmp asm output --- .../isledecomp/compare/asm/__init__.py | 2 + .../isledecomp/compare/asm/parse.py | 152 ++++++ .../isledecomp/isledecomp/compare/asm/swap.py | 80 +++ tools/isledecomp/isledecomp/compare/core.py | 103 +++- tools/isledecomp/isledecomp/compare/db.py | 54 +- tools/isledecomp/isledecomp/cvdump/parser.py | 2 +- tools/isledecomp/isledecomp/utils.py | 14 - tools/isledecomp/tests/test_sanitize.py | 179 +++++++ tools/reccmp/reccmp.py | 482 ++++++------------ 9 files changed, 721 insertions(+), 347 deletions(-) create mode 100644 tools/isledecomp/isledecomp/compare/asm/__init__.py create mode 100644 tools/isledecomp/isledecomp/compare/asm/parse.py create mode 100644 tools/isledecomp/isledecomp/compare/asm/swap.py create mode 100644 tools/isledecomp/tests/test_sanitize.py diff --git a/tools/isledecomp/isledecomp/compare/asm/__init__.py b/tools/isledecomp/isledecomp/compare/asm/__init__.py new file mode 100644 index 00000000..3fd22f6e --- /dev/null +++ b/tools/isledecomp/isledecomp/compare/asm/__init__.py @@ -0,0 +1,2 @@ +from .parse import ParseAsm +from .swap import can_resolve_register_differences diff --git a/tools/isledecomp/isledecomp/compare/asm/parse.py b/tools/isledecomp/isledecomp/compare/asm/parse.py new file mode 100644 index 00000000..ee9c9154 --- /dev/null +++ b/tools/isledecomp/isledecomp/compare/asm/parse.py @@ -0,0 +1,152 @@ +"""Converts x86 machine code into text (i.e. assembly). The end goal is to +compare the code in the original and recomp binaries, using longest common +subsequence (LCS), i.e. difflib.SequenceMatcher. +The capstone library takes the raw bytes and gives us the mnemnonic +and operand(s) for each instruction. We need to "sanitize" the text further +so that virtual addresses are replaced by symbol name or a generic +placeholder string.""" + +import re +from typing import Callable, List, Optional, Tuple +from collections import namedtuple +from capstone import Cs, CS_ARCH_X86, CS_MODE_32 + +disassembler = Cs(CS_ARCH_X86, CS_MODE_32) + +ptr_replace_regex = re.compile(r"ptr \[(0x[0-9a-fA-F]+)\]") + +DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str") + + +def from_hex(string: str) -> Optional[int]: + try: + return int(string, 16) + except ValueError: + pass + + return None + + +class ParseAsm: + def __init__( + self, + relocate_lookup: Optional[Callable[[int], bool]] = None, + name_lookup: Optional[Callable[[int], str]] = None, + ) -> None: + self.relocate_lookup = relocate_lookup + self.name_lookup = name_lookup + self.replacements = {} + self.number_placeholders = True + + def reset(self): + self.replacements = {} + + def is_relocated(self, addr: int) -> bool: + if callable(self.relocate_lookup): + return self.relocate_lookup(addr) + + return False + + def lookup(self, addr: int) -> Optional[str]: + """Return a replacement name for this address if we find one.""" + if (cached := self.replacements.get(addr, None)) is not None: + return cached + + if callable(self.name_lookup): + if (name := self.name_lookup(addr)) is not None: + self.replacements[addr] = name + return name + + return None + + def replace(self, addr: int) -> str: + """Same function as lookup above, but here we return a placeholder + if there is no better name to use.""" + if (name := self.lookup(addr)) is not None: + return name + + # The placeholder number corresponds to the number of addresses we have + # already replaced. This is so the number will be consistent across the diff + # if we can replace some symbols with actual names in recomp but not orig. + idx = len(self.replacements) + 1 + placeholder = f"" if self.number_placeholders else "" + self.replacements[addr] = placeholder + return placeholder + + def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]: + if len(inst.op_str) == 0: + # Nothing to sanitize + return (inst.mnemonic, "") + + # For jumps or calls, if the entire op_str is a hex number, the value + # is a relative offset. + # Otherwise (i.e. it looks like `dword ptr [address]`) it is an + # absolute indirect that we will handle below. + # Providing the starting address of the function to capstone.disasm has + # automatically resolved relative offsets to an absolute address. + # We will have to undo this for some of the jumps or they will not match. + op_str_address = from_hex(inst.op_str) + + if op_str_address is not None: + if inst.mnemonic == "call": + return (inst.mnemonic, self.replace(op_str_address)) + + if inst.mnemonic == "jmp": + # The unwind section contains JMPs to other functions. + # If we have a name for this address, use it. If not, + # do not create a new placeholder. We will instead + # fall through to generic jump handling below. + potential_name = self.lookup(op_str_address) + if potential_name is not None: + return (inst.mnemonic, potential_name) + + if inst.mnemonic.startswith("j"): + # i.e. if this is any jump + # Show the jump offset rather than the absolute address + jump_displacement = op_str_address - (inst.address + inst.size) + return (inst.mnemonic, hex(jump_displacement)) + + def filter_out_ptr(match): + """Helper for re.sub, see below""" + offset = from_hex(match.group(1)) + + if offset is not None: + # We assume this is always an address to replace + placeholder = self.replace(offset) + return f"ptr [{placeholder}]" + + # Strict regex should ensure we can read the hex number. + # But just in case: return the string with no changes + return match.group(0) + + op_str = ptr_replace_regex.sub(filter_out_ptr, inst.op_str) + + # Performance hack: + # Skip this step if there is nothing left to consider replacing. + if "0x" in op_str: + # Replace immediate values with name or placeholder (where appropriate) + words = op_str.split(", ") + for i, word in enumerate(words): + try: + inttest = int(word, 16) + # If this value is a virtual address, it is referenced absolutely, + # which means it must be in the relocation table. + if self.is_relocated(inttest): + words[i] = self.replace(inttest) + except ValueError: + pass + op_str = ", ".join(words) + + return inst.mnemonic, op_str + + def parse_asm(self, data: bytes, start_addr: Optional[int] = 0) -> List[str]: + asm = [] + + for inst in disassembler.disasm_lite(data, start_addr): + # Use heuristics to disregard some differences that aren't representative + # of the accuracy of a function (e.g. global offsets) + result = self.sanitize(DisasmLiteInst(*inst)) + # mnemonic + " " + op_str + asm.append(" ".join(result)) + + return asm diff --git a/tools/isledecomp/isledecomp/compare/asm/swap.py b/tools/isledecomp/isledecomp/compare/asm/swap.py new file mode 100644 index 00000000..599444cf --- /dev/null +++ b/tools/isledecomp/isledecomp/compare/asm/swap.py @@ -0,0 +1,80 @@ +import re + +REGISTER_LIST = set( + [ + "ax", + "bp", + "bx", + "cx", + "di", + "dx", + "eax", + "ebp", + "ebx", + "ecx", + "edi", + "edx", + "esi", + "esp", + "si", + "sp", + ] +) +WORDS = re.compile(r"\w+") + + +def get_registers(line: str): + to_replace = [] + # use words regex to find all matching positions: + for match in WORDS.finditer(line): + reg = match.group(0) + if reg in REGISTER_LIST: + to_replace.append((reg, match.start())) + return to_replace + + +def replace_register( + lines: list[str], start_line: int, reg: str, replacement: str +) -> list[str]: + return [ + line.replace(reg, replacement) if i >= start_line else line + for i, line in enumerate(lines) + ] + + +# Is it possible to make new_asm the same as original_asm by swapping registers? +def can_resolve_register_differences(original_asm, new_asm): + # Split the ASM on spaces to get more granularity, and so + # that we don't modify the original arrays passed in. + original_asm = [part for line in original_asm for part in line.split()] + new_asm = [part for line in new_asm for part in line.split()] + + # Swapping ain't gonna help if the lengths are different + if len(original_asm) != len(new_asm): + return False + + # Look for the mismatching lines + for i, original_line in enumerate(original_asm): + new_line = new_asm[i] + if new_line != original_line: + # Find all the registers to replace + to_replace = get_registers(original_line) + + for replace in to_replace: + (reg, reg_index) = replace + replacing_reg = new_line[reg_index : reg_index + len(reg)] + if replacing_reg in REGISTER_LIST: + if replacing_reg != reg: + # Do a three-way swap replacing in all the subsequent lines + temp_reg = "&" * len(reg) + new_asm = replace_register(new_asm, i, replacing_reg, temp_reg) + new_asm = replace_register(new_asm, i, reg, replacing_reg) + new_asm = replace_register(new_asm, i, temp_reg, reg) + else: + # No replacement to do, different code, bail out + return False + # Check if the lines are now the same + for i, original_line in enumerate(original_asm): + if new_asm[i] != original_line: + return False + return True diff --git a/tools/isledecomp/isledecomp/compare/core.py b/tools/isledecomp/isledecomp/compare/core.py index 07860203..1e58d2cb 100644 --- a/tools/isledecomp/isledecomp/compare/core.py +++ b/tools/isledecomp/isledecomp/compare/core.py @@ -1,11 +1,14 @@ import os import logging -from typing import List, Optional +import difflib +from dataclasses import dataclass +from typing import Iterable, List, Optional from isledecomp.cvdump.demangler import demangle_string_const from isledecomp.cvdump import Cvdump, CvdumpAnalysis from isledecomp.parser import DecompCodebase from isledecomp.dir import walk_source_dir from isledecomp.types import SymbolType +from isledecomp.compare.asm import ParseAsm, can_resolve_register_differences from .db import CompareDb, MatchInfo from .lines import LinesDb @@ -13,6 +16,24 @@ logger = logging.getLogger(__name__) +@dataclass +class DiffReport: + orig_addr: int + recomp_addr: int + name: str + udiff: Optional[List[str]] = None + ratio: float = 0.0 + is_effective_match: bool = False + + @property + def effective_ratio(self) -> float: + return 1.0 if self.is_effective_match else self.ratio + + def __str__(self) -> str: + """For debug purposes. Proper diff printing (with coloring) is in another module.""" + return f"{self.name} (0x{self.orig_addr:x}) {self.ratio*100:.02f}%{'*' if self.is_effective_match else ''}" + + class Compare: # pylint: disable=too-many-instance-attributes def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir): @@ -133,8 +154,84 @@ def get_one_function(self, addr: int) -> Optional[MatchInfo]: def get_functions(self) -> List[MatchInfo]: return self._db.get_matches(SymbolType.FUNCTION) - def compare_functions(self): - pass + def _compare_function(self, match: MatchInfo) -> DiffReport: + if match.size == 0: + # Report a failed match to make the user aware of the empty function. + return DiffReport( + orig_addr=match.orig_addr, + recomp_addr=match.recomp_addr, + name=match.name, + ) + + orig_raw = self.orig_bin.read(match.orig_addr, match.size) + recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size) + + def orig_should_replace(addr: int) -> bool: + return addr > self.orig_bin.imagebase and self.orig_bin.is_relocated_addr( + addr + ) + + def recomp_should_replace(addr: int) -> bool: + return ( + addr > self.recomp_bin.imagebase + and self.recomp_bin.is_relocated_addr(addr) + ) + + def orig_lookup(addr: int) -> Optional[str]: + m = self._db.get_by_orig(addr) + if m is None: + return None + + return m.match_name() + + def recomp_lookup(addr: int) -> Optional[str]: + m = self._db.get_by_recomp(addr) + if m is None: + return None + + return m.match_name() + + orig_parse = ParseAsm( + relocate_lookup=orig_should_replace, name_lookup=orig_lookup + ) + recomp_parse = ParseAsm( + relocate_lookup=recomp_should_replace, name_lookup=recomp_lookup + ) + + orig_asm = orig_parse.parse_asm(orig_raw, match.orig_addr) + recomp_asm = recomp_parse.parse_asm(recomp_raw, match.recomp_addr) + + diff = difflib.SequenceMatcher(None, orig_asm, recomp_asm) + ratio = diff.ratio() + + if ratio != 1.0: + # Check whether we can resolve register swaps which are actually + # perfect matches modulo compiler entropy. + is_effective_match = can_resolve_register_differences(orig_asm, recomp_asm) + unified_diff = difflib.unified_diff(orig_asm, recomp_asm, n=10) + else: + is_effective_match = False + unified_diff = [] + + return DiffReport( + orig_addr=match.orig_addr, + recomp_addr=match.recomp_addr, + name=match.name, + udiff=unified_diff, + ratio=ratio, + is_effective_match=is_effective_match, + ) + + def compare_function(self, addr: int) -> Optional[DiffReport]: + match = self.get_one_function(addr) + if match is None: + return None + + return self._compare_function(match) + + def compare_functions(self) -> Iterable[DiffReport]: + for match in self.get_functions(): + yield self._compare_function(match) def compare_variables(self): pass diff --git a/tools/isledecomp/isledecomp/compare/db.py b/tools/isledecomp/isledecomp/compare/db.py index 850f25fd..3cd25bf7 100644 --- a/tools/isledecomp/isledecomp/compare/db.py +++ b/tools/isledecomp/isledecomp/compare/db.py @@ -2,7 +2,6 @@ addresses/symbols that we want to compare between the original and recompiled binaries.""" import sqlite3 import logging -from collections import namedtuple from typing import List, Optional from isledecomp.types import SymbolType @@ -16,12 +15,35 @@ size int, should_skip int default(FALSE) ); + CREATE INDEX `symbols_or` ON `symbols` (orig_addr); CREATE INDEX `symbols_re` ON `symbols` (recomp_addr); - CREATE INDEX `symbols_na` ON `symbols` (compare_type, name); """ -MatchInfo = namedtuple("MatchInfo", "orig_addr, recomp_addr, size, name") +class MatchInfo: + def __init__( + self, + ctype: Optional[int], + orig: Optional[int], + recomp: Optional[int], + name: Optional[str], + size: Optional[int], + ) -> None: + self.compare_type = SymbolType(ctype) if ctype is not None else None + self.orig_addr = orig + self.recomp_addr = recomp + self.name = name + self.size = size + + def match_name(self) -> str: + """Combination of the name and compare type. + Intended for name substitution in the diff. If there is a diff, + it will be more obvious what this symbol indicates.""" + if self.name is None: + return None + + ctype = self.compare_type.name if self.compare_type is not None else "UNK" + return f"{self.name} ({ctype})" def matchinfo_factory(_, row): @@ -61,7 +83,7 @@ def get_unmatched_strings(self) -> List[str]: def get_one_function(self, addr: int) -> Optional[MatchInfo]: cur = self._db.execute( - """SELECT orig_addr, recomp_addr, size, name + """SELECT compare_type, orig_addr, recomp_addr, name, size FROM `symbols` WHERE compare_type = ? AND orig_addr = ? @@ -74,9 +96,31 @@ def get_one_function(self, addr: int) -> Optional[MatchInfo]: cur.row_factory = matchinfo_factory return cur.fetchone() + def get_by_orig(self, addr: int) -> Optional[MatchInfo]: + cur = self._db.execute( + """SELECT compare_type, orig_addr, recomp_addr, name, size + FROM `symbols` + WHERE orig_addr = ? + """, + (addr,), + ) + cur.row_factory = matchinfo_factory + return cur.fetchone() + + def get_by_recomp(self, addr: int) -> Optional[MatchInfo]: + cur = self._db.execute( + """SELECT compare_type, orig_addr, recomp_addr, name, size + FROM `symbols` + WHERE recomp_addr = ? + """, + (addr,), + ) + cur.row_factory = matchinfo_factory + return cur.fetchone() + def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]: cur = self._db.execute( - """SELECT orig_addr, recomp_addr, size, name + """SELECT compare_type, orig_addr, recomp_addr, name, size FROM `symbols` WHERE compare_type = ? AND orig_addr IS NOT NULL diff --git a/tools/isledecomp/isledecomp/cvdump/parser.py b/tools/isledecomp/isledecomp/cvdump/parser.py index 613cd4a4..d14abe88 100644 --- a/tools/isledecomp/isledecomp/cvdump/parser.py +++ b/tools/isledecomp/isledecomp/cvdump/parser.py @@ -36,7 +36,7 @@ # e.g. `S_GDATA32: [0003:000004A4], Type: T_32PRCHAR(0470), g_set` _gdata32_regex = re.compile( - r"S_GDATA32: \[(?P
\w{4}):(?P\w{8})\], Type:\s*(?P\S+), (?P\S+)" + r"S_GDATA32: \[(?P
\w{4}):(?P\w{8})\], Type:\s*(?P\S+), (?P.+)" ) diff --git a/tools/isledecomp/isledecomp/utils.py b/tools/isledecomp/isledecomp/utils.py index ce4896fd..637eee33 100644 --- a/tools/isledecomp/isledecomp/utils.py +++ b/tools/isledecomp/isledecomp/utils.py @@ -26,17 +26,3 @@ def print_diff(udiff, plain): def get_file_in_script_dir(fn): return os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), fn) - - -class OffsetPlaceholderGenerator: - def __init__(self): - self.counter = 0 - self.replacements = {} - - def get(self, replace_addr): - if replace_addr in self.replacements: - return self.replacements[replace_addr] - self.counter += 1 - replacement = f"" - self.replacements[replace_addr] = replacement - return replacement diff --git a/tools/isledecomp/tests/test_sanitize.py b/tools/isledecomp/tests/test_sanitize.py new file mode 100644 index 00000000..67d95b4f --- /dev/null +++ b/tools/isledecomp/tests/test_sanitize.py @@ -0,0 +1,179 @@ +from typing import Optional +import pytest +from isledecomp.compare.asm.parse import DisasmLiteInst, ParseAsm + + +def mock_inst(mnemonic: str, op_str: str) -> DisasmLiteInst: + """Mock up the named tuple DisasmLite from just a mnemonic and op_str. + To be used for tests on sanitize that do not require the instruction address + or size. i.e. any non-jump instruction.""" + return DisasmLiteInst(0, 0, mnemonic, op_str) + + +identity_cases = [ + ("", ""), + ("sti", ""), + ("push", "ebx"), + ("ret", ""), + ("ret", "4"), + ("mov", "eax, 0x1234"), +] + + +@pytest.mark.parametrize("mnemonic, op_str", identity_cases) +def test_identity(mnemonic, op_str): + """Confirm that nothing is substituted.""" + p = ParseAsm() + inst = mock_inst(mnemonic, op_str) + result = p.sanitize(inst) + assert result == (mnemonic, op_str) + + +ptr_replace_cases = [ + ("byte ptr [0x5555]", "byte ptr []"), + ("word ptr [0x5555]", "word ptr []"), + ("dword ptr [0x5555]", "dword ptr []"), + ("qword ptr [0x5555]", "qword ptr []"), + ("eax, dword ptr [0x5555]", "eax, dword ptr []"), + ("dword ptr [0x5555], eax", "dword ptr [], eax"), + ("dword ptr [0x5555], 0", "dword ptr [], 0"), + ("dword ptr [0x5555], 8", "dword ptr [], 8"), + # Same value, assumed to be an addr in the first appearance + # because it is designated as 'ptr', but we have not provided the + # relocation table lookup method so we do not replace the second appearance. + ("dword ptr [0x5555], 0x5555", "dword ptr [], 0x5555"), +] + + +@pytest.mark.parametrize("start, end", ptr_replace_cases) +def test_ptr_replace(start, end): + """Anything in square brackets (with the 'ptr' prefix) will always be replaced.""" + p = ParseAsm() + inst = mock_inst("", start) + (_, op_str) = p.sanitize(inst) + assert op_str == end + + +call_replace_cases = [ + ("ebx", "ebx"), + ("0x1234", ""), + ("dword ptr [0x1234]", "dword ptr []"), + ("dword ptr [ecx + 0x10]", "dword ptr [ecx + 0x10]"), +] + + +@pytest.mark.parametrize("start, end", call_replace_cases) +def test_call_replace(start, end): + """Call with hex operand is always replaced. + Otherwise, ptr replacement rules apply, but skip `this` calls.""" + p = ParseAsm() + inst = mock_inst("call", start) + (_, op_str) = p.sanitize(inst) + assert op_str == end + + +def test_jump_displacement(): + """Display jump displacement (offset from end of jump instruction) + instead of destination address.""" + p = ParseAsm() + inst = DisasmLiteInst(0x1000, 2, "je", "0x1000") + (_, op_str) = p.sanitize(inst) + assert op_str == "-0x2" + + +@pytest.mark.xfail(reason="Not implemented yet") +def test_jmp_table(): + """Should detect the characteristic jump table instruction + (for a switch statement) and use placeholder.""" + p = ParseAsm() + inst = mock_inst("jmp", "dword ptr [eax*4 + 0x5555]") + (_, op_str) = p.sanitize(inst) + assert op_str == "dword ptr [eax*4 + ]" + + +name_replace_cases = [ + ("byte ptr [0x5555]", "byte ptr [_substitute_]"), + ("word ptr [0x5555]", "word ptr [_substitute_]"), + ("dword ptr [0x5555]", "dword ptr [_substitute_]"), + ("qword ptr [0x5555]", "qword ptr [_substitute_]"), +] + + +@pytest.mark.parametrize("start, end", name_replace_cases) +def test_name_replace(start, end): + """Make sure the name lookup function is called if present""" + + def substitute(_: int) -> str: + return "_substitute_" + + p = ParseAsm(name_lookup=substitute) + inst = mock_inst("mov", start) + (_, op_str) = p.sanitize(inst) + assert op_str == end + + +def test_replacement_cache(): + p = ParseAsm() + inst = mock_inst("inc", "dword ptr [0x1234]") + + (_, op_str) = p.sanitize(inst) + assert op_str == "dword ptr []" + + (_, op_str) = p.sanitize(inst) + assert op_str == "dword ptr []" + + +def test_replacement_numbering(): + """If we can use the name lookup for the first address but not the second, + the second replacement should be not .""" + + def substitute_1234(addr: int) -> Optional[str]: + return "_substitute_" if addr == 0x1234 else None + + p = ParseAsm(name_lookup=substitute_1234) + + (_, op_str) = p.sanitize(mock_inst("inc", "dword ptr [0x1234]")) + assert op_str == "dword ptr [_substitute_]" + + (_, op_str) = p.sanitize(mock_inst("inc", "dword ptr [0x5555]")) + assert op_str == "dword ptr []" + + +def test_relocate_lookup(): + """Immediate values would be relocated if they are actually addresses. + So we can use the relocation table to check whether a given value is an + address or just some number.""" + + def relocate_lookup(addr: int) -> bool: + return addr == 0x1234 + + p = ParseAsm(relocate_lookup=relocate_lookup) + (_, op_str) = p.sanitize(mock_inst("mov", "eax, 0x1234")) + assert op_str == "eax, " + + (_, op_str) = p.sanitize(mock_inst("mov", "eax, 0x5555")) + assert op_str == "eax, 0x5555" + + +def test_jump_to_function(): + """A jmp instruction can lead us directly to a function. This can be found + in the unwind section at the end of a function. However: we do not want to + assume this is the case for all jumps. Only replace the jump with a name + if we can find it using our lookup.""" + + def substitute_1234(addr: int) -> Optional[str]: + return "_substitute_" if addr == 0x1234 else None + + p = ParseAsm(name_lookup=substitute_1234) + inst = DisasmLiteInst(0x1000, 2, "jmp", "0x1234") + (_, op_str) = p.sanitize(inst) + assert op_str == "_substitute_" + + # Should not replace this jump. + # 0x1000 (start addr) + # + 2 (size of jump instruction) + # + 0x5555 (displacement, the value we want) + # = 0x6557 + inst = DisasmLiteInst(0x1000, 2, "jmp", "0x6557") + (_, op_str) = p.sanitize(inst) + assert op_str == "0x5555" diff --git a/tools/reccmp/reccmp.py b/tools/reccmp/reccmp.py index b7027435..6b688890 100755 --- a/tools/reccmp/reccmp.py +++ b/tools/reccmp/reccmp.py @@ -2,167 +2,20 @@ import argparse import base64 -import difflib import json import logging import os -import re from isledecomp import ( Bin, get_file_in_script_dir, - OffsetPlaceholderGenerator, print_diff, ) from isledecomp.compare import Compare as IsleCompare - -from capstone import Cs, CS_ARCH_X86, CS_MODE_32 -import colorama from pystache import Renderer +import colorama - -REGISTER_LIST = set( - [ - "ax", - "bp", - "bx", - "cx", - "di", - "dx", - "eax", - "ebp", - "ebx", - "ecx", - "edi", - "edx", - "esi", - "esp", - "si", - "sp", - ] -) -WORDS = re.compile(r"\w+") - - -def sanitize(file, placeholder_generator, mnemonic, op_str): - op_str_is_number = False - try: - int(op_str, 16) - op_str_is_number = True - except ValueError: - pass - - if (mnemonic in ["call", "jmp"]) and op_str_is_number: - # Filter out "calls" because the offsets we're not currently trying to - # match offsets. As long as there's a call in the right place, it's - # probably accurate. - op_str = placeholder_generator.get(int(op_str, 16)) - else: - - def filter_out_ptr(ptype, op_str): - try: - ptrstr = ptype + " ptr [" - start = op_str.index(ptrstr) + len(ptrstr) - end = op_str.index("]", start) - - # This will throw ValueError if not hex - inttest = int(op_str[start:end], 16) - - return ( - op_str[0:start] + placeholder_generator.get(inttest) + op_str[end:] - ) - except ValueError: - return op_str - - # Filter out dword ptrs where the pointer is to an offset - op_str = filter_out_ptr("dword", op_str) - op_str = filter_out_ptr("word", op_str) - op_str = filter_out_ptr("byte", op_str) - - # Use heuristics to filter out any args that look like offsets - words = op_str.split(" ") - for i, word in enumerate(words): - try: - inttest = int(word, 16) - if file.is_relocated_addr(inttest): - words[i] = placeholder_generator.get(inttest) - except ValueError: - pass - op_str = " ".join(words) - - return mnemonic, op_str - - -def parse_asm(disassembler, file, asm_addr, size): - asm = [] - data = file.read(asm_addr, size) - placeholder_generator = OffsetPlaceholderGenerator() - for i in disassembler.disasm(data, 0): - # Use heuristics to disregard some differences that aren't representative - # of the accuracy of a function (e.g. global offsets) - mnemonic, op_str = sanitize(file, placeholder_generator, i.mnemonic, i.op_str) - if op_str is None: - asm.append(mnemonic) - else: - asm.append(f"{mnemonic} {op_str}") - return asm - - -def get_registers(line: str): - to_replace = [] - # use words regex to find all matching positions: - for match in WORDS.finditer(line): - reg = match.group(0) - if reg in REGISTER_LIST: - to_replace.append((reg, match.start())) - return to_replace - - -def replace_register( - lines: list[str], start_line: int, reg: str, replacement: str -) -> list[str]: - return [ - line.replace(reg, replacement) if i >= start_line else line - for i, line in enumerate(lines) - ] - - -# Is it possible to make new_asm the same as original_asm by swapping registers? -def can_resolve_register_differences(original_asm, new_asm): - # Split the ASM on spaces to get more granularity, and so - # that we don't modify the original arrays passed in. - original_asm = [part for line in original_asm for part in line.split()] - new_asm = [part for line in new_asm for part in line.split()] - - # Swapping ain't gonna help if the lengths are different - if len(original_asm) != len(new_asm): - return False - - # Look for the mismatching lines - for i, original_line in enumerate(original_asm): - new_line = new_asm[i] - if new_line != original_line: - # Find all the registers to replace - to_replace = get_registers(original_line) - - for replace in to_replace: - (reg, reg_index) = replace - replacing_reg = new_line[reg_index : reg_index + len(reg)] - if replacing_reg in REGISTER_LIST: - if replacing_reg != reg: - # Do a three-way swap replacing in all the subsequent lines - temp_reg = "&" * len(reg) - new_asm = replace_register(new_asm, i, replacing_reg, temp_reg) - new_asm = replace_register(new_asm, i, reg, replacing_reg) - new_asm = replace_register(new_asm, i, temp_reg, reg) - else: - # No replacement to do, different code, bail out - return False - # Check if the lines are now the same - for i, original_line in enumerate(original_asm): - if new_asm[i] != original_line: - return False - return True +colorama.init() def gen_html(html_file, data): @@ -197,9 +50,88 @@ def gen_svg(svg_file, name_svg, icon, svg_implemented_funcs, total_funcs, raw_ac svgfile.write(output_data) -# Do the actual work -def main(): - # pylint: disable=too-many-locals, too-many-nested-blocks, too-many-branches, too-many-statements +def get_percent_color(value: float) -> str: + """Return colorama ANSI escape character for the given decimal value.""" + if value == 1.0: + return colorama.Fore.GREEN + if value > 0.8: + return colorama.Fore.YELLOW + + return colorama.Fore.RED + + +def percent_string( + ratio: float, is_effective: bool = False, is_plain: bool = False +) -> str: + """Helper to construct a percentage string from the given ratio. + If is_effective (i.e. effective match), indicate that with the asterisk. + If is_plain, don't use colorama ANSI codes.""" + + percenttext = f"{(ratio * 100):.2f}%" + effective_star = "*" if is_effective else "" + + if is_plain: + return percenttext + effective_star + + return "".join( + [ + get_percent_color(ratio), + percenttext, + colorama.Fore.RED if is_effective else "", + effective_star, + colorama.Style.RESET_ALL, + ] + ) + + +def print_match_verbose(match, show_both_addrs: bool = False, is_plain: bool = False): + percenttext = percent_string( + match.effective_ratio, match.is_effective_match, is_plain + ) + + if show_both_addrs: + addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}" + else: + addrs = hex(match.orig_addr) + + if match.effective_ratio == 1.0: + ok_text = ( + "OK!" + if is_plain + else (colorama.Fore.GREEN + "✨ OK! ✨" + colorama.Style.RESET_ALL) + ) + if match.ratio == 1.0: + print(f"{addrs}: {match.name} 100% match.\n\n{ok_text}\n\n") + else: + print( + f"{addrs}: {match.name} Effective 100%% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n" + ) + else: + print_diff(match.udiff, is_plain) + + print( + f"\n{match.name} is only {percenttext} similar to the original, diff above" + ) + + +def print_match_oneline(match, show_both_addrs: bool = False, is_plain: bool = False): + percenttext = percent_string( + match.effective_ratio, match.is_effective_match, is_plain + ) + + if show_both_addrs: + addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}" + else: + addrs = hex(match.orig_addr) + + print(f" {match.name} ({addrs}) is {percenttext} similar to the original") + + +def parse_args() -> argparse.Namespace: + def virtual_address(value) -> int: + """Helper method for argparse, verbose parameter""" + return int(value, 16) + parser = argparse.ArgumentParser( allow_abbrev=False, description="Recompilation Compare: compare an original EXE with a recompiled EXE + PDB.", @@ -226,6 +158,7 @@ def main(): "--verbose", "-v", metavar="", + type=virtual_address, help="Print assembly diff for specific function (original file's offset)", ) parser.add_argument( @@ -258,198 +191,99 @@ def main(): args = parser.parse_args() + if not os.path.isfile(args.original): + parser.error(f"Original binary {args.original} does not exist") + + if not os.path.isfile(args.recompiled): + parser.error(f"Recompiled binary {args.recompiled} does not exist") + + if not os.path.isfile(args.pdb): + parser.error(f"Symbols PDB {args.pdb} does not exist") + + if not os.path.isdir(args.decomp_dir): + parser.error(f"Source directory {args.decomp_dir} does not exist") + + return args + + +def main(): + args = parse_args() logging.basicConfig(level=args.loglevel, format="[%(levelname)s] %(message)s") - colorama.init() - - verbose = None - found_verbose_target = False - if args.verbose: - try: - verbose = int(args.verbose, 16) - except ValueError: - parser.error("invalid verbose argument") - html_path = args.html - - plain = args.no_color - - original = args.original - if not os.path.isfile(original): - parser.error(f"Original binary {original} does not exist") - - recomp = args.recompiled - if not os.path.isfile(recomp): - parser.error(f"Recompiled binary {recomp} does not exist") - - syms = args.pdb - if not os.path.isfile(syms): - parser.error(f"Symbols PDB {syms} does not exist") - - source = args.decomp_dir - if not os.path.isdir(source): - parser.error(f"Source directory {source} does not exist") - - svg = args.svg - - with Bin(original, find_str=True) as origfile, Bin(recomp) as recompfile: - if verbose is not None: + with Bin(args.original, find_str=True) as origfile, Bin( + args.recompiled + ) as recompfile: + if args.verbose is not None: # Mute logger events from compare engine logging.getLogger("isledecomp.compare.db").setLevel(logging.CRITICAL) logging.getLogger("isledecomp.compare.lines").setLevel(logging.CRITICAL) - isle_compare = IsleCompare(origfile, recompfile, syms, source) + isle_compare = IsleCompare(origfile, recompfile, args.pdb, args.decomp_dir) print() - capstone_disassembler = Cs(CS_ARCH_X86, CS_MODE_32) + ### Compare one or none. + + if args.verbose is not None: + match = isle_compare.compare_function(args.verbose) + if match is None: + print(f"Failed to find the function with address 0x{args.verbose:x}") + return + + print_match_verbose(match, is_plain=args.no_color) + return + + ### Compare everything. function_count = 0 total_accuracy = 0 total_effective_accuracy = 0 htmlinsert = [] - matches = [] - if verbose is not None: - match = isle_compare.get_one_function(verbose) - if match is not None: - found_verbose_target = True - matches = [match] - else: - matches = isle_compare.get_functions() - - for match in matches: - # The effective_ratio is the ratio when ignoring differing register - # allocation vs the ratio is the true ratio. - ratio = 0.0 - effective_ratio = 0.0 - if match.size: - origasm = parse_asm( - capstone_disassembler, - origfile, - match.orig_addr, - match.size, - ) - recompasm = parse_asm( - capstone_disassembler, - recompfile, - match.recomp_addr, - match.size, - ) - - diff = difflib.SequenceMatcher(None, origasm, recompasm) - ratio = diff.ratio() - effective_ratio = ratio - - if ratio != 1.0: - # Check whether we can resolve register swaps which are actually - # perfect matches modulo compiler entropy. - if can_resolve_register_differences(origasm, recompasm): - effective_ratio = 1.0 - else: - ratio = 0 - - percenttext = f"{(effective_ratio * 100):.2f}%" - if not plain: - if effective_ratio == 1.0: - percenttext = ( - colorama.Fore.GREEN + percenttext + colorama.Style.RESET_ALL - ) - elif effective_ratio > 0.8: - percenttext = ( - colorama.Fore.YELLOW + percenttext + colorama.Style.RESET_ALL - ) - else: - percenttext = ( - colorama.Fore.RED + percenttext + colorama.Style.RESET_ALL - ) - - if effective_ratio == 1.0 and ratio != 1.0: - if plain: - percenttext += "*" - else: - percenttext += colorama.Fore.RED + "*" + colorama.Style.RESET_ALL - - if args.print_rec_addr: - addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}" - else: - addrs = hex(match.orig_addr) - - if not verbose: - print( - f" {match.name} ({addrs}) is {percenttext} similar to the original" - ) + for match in isle_compare.compare_functions(): + print_match_oneline(match, is_plain=args.no_color) function_count += 1 - total_accuracy += ratio - total_effective_accuracy += effective_ratio + total_accuracy += match.ratio + total_effective_accuracy += match.effective_ratio - if match.size: - udiff = difflib.unified_diff(origasm, recompasm, n=10) - - # If verbose, print the diff for that function to the output - if verbose: - if effective_ratio == 1.0: - ok_text = ( - "OK!" - if plain - else ( - colorama.Fore.GREEN - + "✨ OK! ✨" - + colorama.Style.RESET_ALL - ) - ) - if ratio == 1.0: - print(f"{addrs}: {match.name} 100% match.\n\n{ok_text}\n\n") - else: - print( - f"{addrs}: {match.name} Effective 100%% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n" - ) - else: - print_diff(udiff, plain) - - print( - f"\n{match.name} is only {percenttext} similar to the original, diff above" - ) - - # If html, record the diffs to an HTML file - if html_path: - htmlinsert.append( - { - "address": f"0x{match.orig_addr:x}", - "name": match.name, - "matching": effective_ratio, - "diff": "\n".join(udiff), - } - ) - - if html_path: - gen_html(html_path, json.dumps(htmlinsert)) - - if verbose: - if not found_verbose_target: - print(f"Failed to find the function with address 0x{verbose:x}") - else: - implemented_funcs = function_count - - if args.total: - function_count = int(args.total) - - if function_count > 0: - effective_accuracy = total_effective_accuracy / function_count * 100 - actual_accuracy = total_accuracy / function_count * 100 - print( - f"\nTotal effective accuracy {effective_accuracy:.2f}% across {function_count} functions ({actual_accuracy:.2f}% actual accuracy)" + # If html, record the diffs to an HTML file + if args.html is not None: + htmlinsert.append( + { + "address": f"0x{match.orig_addr:x}", + "name": match.name, + "matching": match.effective_ratio, + "diff": "\n".join(match.udiff), + } ) - if svg: - gen_svg( - svg, - os.path.basename(original), - args.svg_icon, - implemented_funcs, - function_count, - total_effective_accuracy, - ) + ## Generate files and show summary. + + if args.html is not None: + gen_html(args.html, json.dumps(htmlinsert)) + + implemented_funcs = function_count + + if args.total: + function_count = int(args.total) + + if function_count > 0: + effective_accuracy = total_effective_accuracy / function_count * 100 + actual_accuracy = total_accuracy / function_count * 100 + print( + f"\nTotal effective accuracy {effective_accuracy:.2f}% across {function_count} functions ({actual_accuracy:.2f}% actual accuracy)" + ) + + if args.svg is not None: + gen_svg( + args.svg, + os.path.basename(args.original), + args.svg_icon, + implemented_funcs, + function_count, + total_effective_accuracy, + ) if __name__ == "__main__":