Name substitution for reccmp asm output

This commit is contained in:
disinvite 2024-01-13 15:09:03 -05:00
parent 0edbd9dab9
commit 06160d429b
9 changed files with 721 additions and 347 deletions

View File

@ -0,0 +1,2 @@
from .parse import ParseAsm
from .swap import can_resolve_register_differences

View File

@ -0,0 +1,152 @@
"""Converts x86 machine code into text (i.e. assembly). The end goal is to
compare the code in the original and recomp binaries, using longest common
subsequence (LCS), i.e. difflib.SequenceMatcher.
The capstone library takes the raw bytes and gives us the mnemonic
and operand(s) for each instruction. We need to "sanitize" the text further
so that virtual addresses are replaced by symbol name or a generic
placeholder string."""
import re
from typing import Callable, List, Optional, Tuple
from collections import namedtuple
from capstone import Cs, CS_ARCH_X86, CS_MODE_32
disassembler = Cs(CS_ARCH_X86, CS_MODE_32)
ptr_replace_regex = re.compile(r"ptr \[(0x[0-9a-fA-F]+)\]")
DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str")
def from_hex(string: str) -> Optional[int]:
    """Parse *string* as a base-16 integer.

    Returns the integer value, or None when the text is not a valid
    hexadecimal number (for example a register name or a `ptr [...]` operand).
    """
    try:
        value = int(string, 16)
    except ValueError:
        return None
    return value
class ParseAsm:
    """Disassemble x86 code and rewrite addresses so two builds can be diffed.

    Addresses found in the operands are swapped for a symbol name (when
    `name_lookup` supplies one) or for a placeholder such as `<OFFSET1>`.
    Every substitution is remembered in `self.replacements`, so the same
    address always maps to the same text until `reset()` is called.
    """

    def __init__(
        self,
        relocate_lookup: Optional[Callable[[int], bool]] = None,
        name_lookup: Optional[Callable[[int], str]] = None,
    ) -> None:
        """Store the optional callbacks and start with an empty cache.

        relocate_lookup: answers whether an immediate value is a relocated
            (i.e. absolute) address.
        name_lookup: returns a display name for an address, or None.
        """
        self.relocate_lookup = relocate_lookup
        self.name_lookup = name_lookup
        # Maps address -> the name or placeholder already chosen for it.
        self.replacements = {}
        # When False, every placeholder is the unnumbered "<OFFSET>".
        self.number_placeholders = True

    def reset(self):
        """Forget every replacement made so far."""
        self.replacements = {}

    def is_relocated(self, addr: int) -> bool:
        """Ask the relocation callback about `addr`; False if none was given."""
        checker = self.relocate_lookup
        return checker(addr) if callable(checker) else False

    def lookup(self, addr: int) -> Optional[str]:
        """Return the cached or looked-up replacement for `addr`, else None.

        A name obtained from `name_lookup` is stored in the cache. Nothing is
        stored when no name is available.
        """
        known = self.replacements.get(addr)
        if known is not None:
            return known

        if not callable(self.name_lookup):
            return None

        name = self.name_lookup(addr)
        if name is not None:
            self.replacements[addr] = name
        return name

    def replace(self, addr: int) -> str:
        """Like `lookup`, but fall back to a freshly created placeholder."""
        name = self.lookup(addr)
        if name is not None:
            return name

        # Number the placeholder by how many addresses were replaced before it.
        # Names count too, so numbering stays aligned between orig and recomp
        # even when only one side could resolve a symbol name.
        idx = len(self.replacements) + 1
        if self.number_placeholders:
            placeholder = f"<OFFSET{idx}>"
        else:
            placeholder = "<OFFSET>"

        self.replacements[addr] = placeholder
        return placeholder

    def _swap_ptr_operand(self, match):
        """re.sub callback: replace the address inside `ptr [0x...]`."""
        offset = from_hex(match.group(1))
        if offset is None:
            # The regex only captures hex digits, so this is a safety net:
            # hand the text back unchanged.
            return match.group(0)
        # Anything addressed this way is assumed to be an address to replace.
        return f"ptr [{self.replace(offset)}]"

    def _swap_immediate(self, word: str) -> str:
        """Replace one comma-separated operand if it is a relocated address."""
        try:
            value = int(word, 16)
            # A virtual address used as an immediate is referenced absolutely,
            # so it has to appear in the relocation table.
            if self.is_relocated(value):
                return self.replace(value)
        except ValueError:
            pass
        return word

    def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
        """Return (mnemonic, op_str) for `inst` with addresses substituted."""
        mnemonic, operands = inst.mnemonic, inst.op_str
        if len(operands) == 0:
            # No operands means nothing to rewrite.
            return (mnemonic, "")

        # An operand that is nothing but a hex number on a call or jump is a
        # relative target. Capstone has already turned it into an absolute
        # address using the start address given to the disassembler.
        # Operands such as `dword ptr [address]` are absolute indirect
        # references and are handled further down.
        target = from_hex(operands)
        if target is not None:
            if mnemonic == "call":
                return (mnemonic, self.replace(target))

            if mnemonic == "jmp":
                # Unwind code jumps straight to other functions. Use a name
                # if one is available, but do not mint a new placeholder.
                # Without a name, fall through to the generic jump case.
                known = self.lookup(target)
                if known is not None:
                    return (mnemonic, known)

            if mnemonic.startswith("j"):
                # Any jump: show the displacement from the end of the
                # instruction instead of the absolute destination.
                return (mnemonic, hex(target - (inst.address + inst.size)))

        operands = ptr_replace_regex.sub(self._swap_ptr_operand, operands)

        # Performance hack: only scan for immediates if a hex literal remains.
        if "0x" in operands:
            operands = ", ".join(
                self._swap_immediate(word) for word in operands.split(", ")
            )

        return mnemonic, operands

    def parse_asm(self, data: bytes, start_addr: Optional[int] = 0) -> List[str]:
        """Disassemble `data` and return one sanitized text line per instruction.

        Each line is the mnemonic and the operand string joined by a space.
        """
        return [
            " ".join(self.sanitize(DisasmLiteInst(*raw)))
            for raw in disassembler.disasm_lite(data, start_addr)
        ]

View File

@ -0,0 +1,80 @@
import re
# Register names that the swap heuristic is allowed to exchange.
REGISTER_LIST = {
    # 32-bit
    "eax",
    "ebx",
    "ecx",
    "edx",
    "esi",
    "edi",
    "ebp",
    "esp",
    # 16-bit
    "ax",
    "bx",
    "cx",
    "dx",
    "si",
    "di",
    "bp",
    "sp",
}
# Matches each run of word characters (letters, digits, underscore).
WORDS = re.compile(r"\w+")
def get_registers(line: str):
    """List the registers named in `line`.

    Returns (register name, start index within `line`) tuples, in the order
    the registers appear.
    """
    return [
        (token.group(0), token.start())
        for token in WORDS.finditer(line)
        if token.group(0) in REGISTER_LIST
    ]
def replace_register(
    lines: list[str], start_line: int, reg: str, replacement: str
) -> list[str]:
    """Return a copy of `lines` with `reg` replaced from `start_line` onwards.

    Entries before index `start_line` are copied unchanged. In the rest, every
    occurrence of the substring `reg` becomes `replacement`.
    """
    result = []
    for index, text in enumerate(lines):
        if index < start_line:
            result.append(text)
        else:
            result.append(text.replace(reg, replacement))
    return result
def can_resolve_register_differences(original_asm, new_asm):
    """Decide whether swapping registers can make `new_asm` equal `original_asm`.

    Both listings are flattened into whitespace-separated tokens. This gives
    finer granularity and leaves the caller's lists untouched. Returns True
    only if, after the swaps, every token matches.
    """
    orig_tokens = [tok for line in original_asm for tok in line.split()]
    new_tokens = [tok for line in new_asm for tok in line.split()]

    # A different token count cannot be fixed by renaming registers.
    if len(orig_tokens) != len(new_tokens):
        return False

    for pos, expected in enumerate(orig_tokens):
        actual = new_tokens[pos]
        if actual == expected:
            continue

        # For each register in the original token, see what sits at the same
        # character position in the new token.
        for reg, offset in get_registers(expected):
            candidate = actual[offset : offset + len(reg)]
            if candidate not in REGISTER_LIST:
                # Not a register at that spot: the code itself differs.
                return False
            if candidate == reg:
                continue

            # Three-way swap via a scratch name, applied from this token on.
            scratch = "&" * len(reg)
            for old, new in ((candidate, scratch), (reg, candidate), (scratch, reg)):
                new_tokens = replace_register(new_tokens, pos, old, new)

    # Equal lengths were established above, so this compares token by token.
    return new_tokens == orig_tokens

View File

@ -1,11 +1,14 @@
import os import os
import logging import logging
from typing import List, Optional import difflib
from dataclasses import dataclass
from typing import Iterable, List, Optional
from isledecomp.cvdump.demangler import demangle_string_const from isledecomp.cvdump.demangler import demangle_string_const
from isledecomp.cvdump import Cvdump, CvdumpAnalysis from isledecomp.cvdump import Cvdump, CvdumpAnalysis
from isledecomp.parser import DecompCodebase from isledecomp.parser import DecompCodebase
from isledecomp.dir import walk_source_dir from isledecomp.dir import walk_source_dir
from isledecomp.types import SymbolType from isledecomp.types import SymbolType
from isledecomp.compare.asm import ParseAsm, can_resolve_register_differences
from .db import CompareDb, MatchInfo from .db import CompareDb, MatchInfo
from .lines import LinesDb from .lines import LinesDb
@ -13,6 +16,24 @@
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class DiffReport:
    """Result of comparing one function between the two binaries."""

    # Address of the function in the original binary.
    orig_addr: int
    # Address of the function in the recompiled binary.
    recomp_addr: int
    name: str
    # Unified diff lines. Omitted or None is normalised to an empty list.
    udiff: Optional[List[str]] = None
    # Similarity ratio in the range 0.0 - 1.0.
    ratio: float = 0.0
    # True when the only differences are in register allocation.
    is_effective_match: bool = False

    def __post_init__(self) -> None:
        """Guarantee `udiff` is always iterable.

        A report built without a diff (e.g. for an empty function) used to
        carry None here, which breaks consumers that join or iterate the diff
        lines. An empty list keeps those code paths working.
        """
        if self.udiff is None:
            self.udiff = []

    @property
    def effective_ratio(self) -> float:
        """The ratio, or 1.0 if this is an effective match."""
        return 1.0 if self.is_effective_match else self.ratio

    def __str__(self) -> str:
        """For debug purposes. Proper diff printing (with coloring) is in another module."""
        return f"{self.name} (0x{self.orig_addr:x}) {self.ratio*100:.02f}%{'*' if self.is_effective_match else ''}"
class Compare: class Compare:
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir): def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir):
@ -133,8 +154,84 @@ def get_one_function(self, addr: int) -> Optional[MatchInfo]:
def get_functions(self) -> List[MatchInfo]: def get_functions(self) -> List[MatchInfo]:
return self._db.get_matches(SymbolType.FUNCTION) return self._db.get_matches(SymbolType.FUNCTION)
def compare_functions(self): def _compare_function(self, match: MatchInfo) -> DiffReport:
pass if match.size == 0:
# Report a failed match to make the user aware of the empty function.
return DiffReport(
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
)
orig_raw = self.orig_bin.read(match.orig_addr, match.size)
recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size)
def orig_should_replace(addr: int) -> bool:
return addr > self.orig_bin.imagebase and self.orig_bin.is_relocated_addr(
addr
)
def recomp_should_replace(addr: int) -> bool:
return (
addr > self.recomp_bin.imagebase
and self.recomp_bin.is_relocated_addr(addr)
)
def orig_lookup(addr: int) -> Optional[str]:
m = self._db.get_by_orig(addr)
if m is None:
return None
return m.match_name()
def recomp_lookup(addr: int) -> Optional[str]:
m = self._db.get_by_recomp(addr)
if m is None:
return None
return m.match_name()
orig_parse = ParseAsm(
relocate_lookup=orig_should_replace, name_lookup=orig_lookup
)
recomp_parse = ParseAsm(
relocate_lookup=recomp_should_replace, name_lookup=recomp_lookup
)
orig_asm = orig_parse.parse_asm(orig_raw, match.orig_addr)
recomp_asm = recomp_parse.parse_asm(recomp_raw, match.recomp_addr)
diff = difflib.SequenceMatcher(None, orig_asm, recomp_asm)
ratio = diff.ratio()
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
is_effective_match = can_resolve_register_differences(orig_asm, recomp_asm)
unified_diff = difflib.unified_diff(orig_asm, recomp_asm, n=10)
else:
is_effective_match = False
unified_diff = []
return DiffReport(
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
udiff=unified_diff,
ratio=ratio,
is_effective_match=is_effective_match,
)
def compare_function(self, addr: int) -> Optional[DiffReport]:
    """Compare the single function found via `get_one_function(addr)`.

    Returns None when no such function is found, otherwise its DiffReport.
    """
    match = self.get_one_function(addr)
    return None if match is None else self._compare_function(match)
def compare_functions(self) -> Iterable[DiffReport]:
    """Lazily yield a DiffReport for every match from `get_functions()`."""
    yield from (self._compare_function(match) for match in self.get_functions())
def compare_variables(self): def compare_variables(self):
pass pass

View File

@ -2,7 +2,6 @@
addresses/symbols that we want to compare between the original and recompiled binaries.""" addresses/symbols that we want to compare between the original and recompiled binaries."""
import sqlite3 import sqlite3
import logging import logging
from collections import namedtuple
from typing import List, Optional from typing import List, Optional
from isledecomp.types import SymbolType from isledecomp.types import SymbolType
@ -16,12 +15,35 @@
size int, size int,
should_skip int default(FALSE) should_skip int default(FALSE)
); );
CREATE INDEX `symbols_or` ON `symbols` (orig_addr);
CREATE INDEX `symbols_re` ON `symbols` (recomp_addr); CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
CREATE INDEX `symbols_na` ON `symbols` (compare_type, name);
""" """
MatchInfo = namedtuple("MatchInfo", "orig_addr, recomp_addr, size, name") class MatchInfo:
def __init__(
    self,
    ctype: Optional[int],
    orig: Optional[int],
    recomp: Optional[int],
    name: Optional[str],
    size: Optional[int],
) -> None:
    """Populate the record from one row of the symbols table.

    `ctype` is converted to a SymbolType unless it is None. The remaining
    values are stored as given.
    """
    self.compare_type = None if ctype is None else SymbolType(ctype)
    self.orig_addr, self.recomp_addr = orig, recomp
    self.name, self.size = name, size
def match_name(self) -> Optional[str]:
    """Combination of the name and compare type.

    Intended for name substitution in the diff. If there is a diff,
    it will be more obvious what this symbol indicates.

    Returns None when the symbol has no name. Otherwise returns
    "<name> (<TYPE>)", where <TYPE> is the compare type's name, or "UNK"
    if the compare type is not set.
    """
    if self.name is None:
        return None

    if self.compare_type is None:
        ctype = "UNK"
    else:
        ctype = self.compare_type.name
    return f"{self.name} ({ctype})"
def matchinfo_factory(_, row): def matchinfo_factory(_, row):
@ -61,7 +83,7 @@ def get_unmatched_strings(self) -> List[str]:
def get_one_function(self, addr: int) -> Optional[MatchInfo]: def get_one_function(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute( cur = self._db.execute(
"""SELECT orig_addr, recomp_addr, size, name """SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols` FROM `symbols`
WHERE compare_type = ? WHERE compare_type = ?
AND orig_addr = ? AND orig_addr = ?
@ -74,9 +96,31 @@ def get_one_function(self, addr: int) -> Optional[MatchInfo]:
cur.row_factory = matchinfo_factory cur.row_factory = matchinfo_factory
return cur.fetchone() return cur.fetchone()
def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
    """Fetch one symbol row whose `orig_addr` equals `addr`.

    The row is built by `matchinfo_factory`. Returns None when no row matches.
    """
    query = """SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
WHERE orig_addr = ?
"""
    cursor = self._db.execute(query, (addr,))
    cursor.row_factory = matchinfo_factory
    return cursor.fetchone()
def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
    """Fetch one symbol row whose `recomp_addr` equals `addr`.

    The row is built by `matchinfo_factory`. Returns None when no row matches.
    """
    query = """SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
WHERE recomp_addr = ?
"""
    cursor = self._db.execute(query, (addr,))
    cursor.row_factory = matchinfo_factory
    return cursor.fetchone()
def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]: def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]:
cur = self._db.execute( cur = self._db.execute(
"""SELECT orig_addr, recomp_addr, size, name """SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols` FROM `symbols`
WHERE compare_type = ? WHERE compare_type = ?
AND orig_addr IS NOT NULL AND orig_addr IS NOT NULL

View File

@ -36,7 +36,7 @@
# e.g. `S_GDATA32: [0003:000004A4], Type: T_32PRCHAR(0470), g_set` # e.g. `S_GDATA32: [0003:000004A4], Type: T_32PRCHAR(0470), g_set`
_gdata32_regex = re.compile( _gdata32_regex = re.compile(
r"S_GDATA32: \[(?P<section>\w{4}):(?P<offset>\w{8})\], Type:\s*(?P<type>\S+), (?P<name>\S+)" r"S_GDATA32: \[(?P<section>\w{4}):(?P<offset>\w{8})\], Type:\s*(?P<type>\S+), (?P<name>.+)"
) )

View File

@ -26,17 +26,3 @@ def print_diff(udiff, plain):
def get_file_in_script_dir(fn): def get_file_in_script_dir(fn):
return os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), fn) return os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), fn)
class OffsetPlaceholderGenerator:
def __init__(self):
self.counter = 0
self.replacements = {}
def get(self, replace_addr):
if replace_addr in self.replacements:
return self.replacements[replace_addr]
self.counter += 1
replacement = f"<OFFSET{self.counter}>"
self.replacements[replace_addr] = replacement
return replacement

View File

@ -0,0 +1,179 @@
from typing import Optional
import pytest
from isledecomp.compare.asm.parse import DisasmLiteInst, ParseAsm
def mock_inst(mnemonic: str, op_str: str) -> DisasmLiteInst:
    """Build a DisasmLiteInst with address and size both set to zero.

    Suitable for sanitize tests that do not depend on where the instruction
    lives or how long it is, i.e. anything other than a jump.
    """
    return DisasmLiteInst(address=0, size=0, mnemonic=mnemonic, op_str=op_str)
# (mnemonic, op_str) pairs that sanitize must hand back untouched.
identity_cases = [
    ("", ""),
    ("sti", ""),
    ("push", "ebx"),
    ("ret", ""),
    ("ret", "4"),
    ("mov", "eax, 0x1234"),
]


@pytest.mark.parametrize("mnemonic, op_str", identity_cases)
def test_identity(mnemonic, op_str):
    """With no lookups configured, these instructions pass through as-is."""
    sanitized = ParseAsm().sanitize(mock_inst(mnemonic, op_str))
    assert sanitized == (mnemonic, op_str)
ptr_replace_cases = [
    # One case per operand size prefix.
    *(
        (f"{size} ptr [0x5555]", f"{size} ptr [<OFFSET1>]")
        for size in ("byte", "word", "dword", "qword")
    ),
    ("eax, dword ptr [0x5555]", "eax, dword ptr [<OFFSET1>]"),
    ("dword ptr [0x5555], eax", "dword ptr [<OFFSET1>], eax"),
    ("dword ptr [0x5555], 0", "dword ptr [<OFFSET1>], 0"),
    ("dword ptr [0x5555], 8", "dword ptr [<OFFSET1>], 8"),
    # The same value appears twice. The first is treated as an address
    # because of the 'ptr' designation. The second is an immediate, and
    # without a relocation lookup it is left alone.
    ("dword ptr [0x5555], 0x5555", "dword ptr [<OFFSET1>], 0x5555"),
]


@pytest.mark.parametrize("start, end", ptr_replace_cases)
def test_ptr_replace(start, end):
    """A hex address in square brackets after 'ptr' is always replaced."""
    parser = ParseAsm()
    _, sanitized = parser.sanitize(mock_inst("", start))
    assert sanitized == end
call_replace_cases = [
    ("ebx", "ebx"),
    ("0x1234", "<OFFSET1>"),
    ("dword ptr [0x1234]", "dword ptr [<OFFSET1>]"),
    ("dword ptr [ecx + 0x10]", "dword ptr [ecx + 0x10]"),
]


@pytest.mark.parametrize("start, end", call_replace_cases)
def test_call_replace(start, end):
    """A call whose operand is a bare hex number is always replaced.

    Other operands follow the ptr replacement rules, so register-relative
    calls (such as `this` calls) are left as they are.
    """
    parser = ParseAsm()
    _, sanitized = parser.sanitize(mock_inst("call", start))
    assert sanitized == end
def test_jump_displacement():
    """Jumps show the offset from the end of the instruction, not the target."""
    # A 2-byte `je` at 0x1000 that targets itself: displacement is -2.
    jump = DisasmLiteInst(0x1000, 2, "je", "0x1000")
    _, sanitized = ParseAsm().sanitize(jump)
    assert sanitized == "-0x2"
@pytest.mark.xfail(reason="Not implemented yet")
def test_jmp_table():
    """The jump-table form used by switch statements should get a placeholder."""
    parser = ParseAsm()
    _, sanitized = parser.sanitize(mock_inst("jmp", "dword ptr [eax*4 + 0x5555]"))
    assert sanitized == "dword ptr [eax*4 + <OFFSET1>]"
# One case per operand size prefix.
name_replace_cases = [
    (f"{size} ptr [0x5555]", f"{size} ptr [_substitute_]")
    for size in ("byte", "word", "dword", "qword")
]


@pytest.mark.parametrize("start, end", name_replace_cases)
def test_name_replace(start, end):
    """When a name lookup is supplied, its result is used for the address."""
    parser = ParseAsm(name_lookup=lambda _: "_substitute_")
    _, sanitized = parser.sanitize(mock_inst("mov", start))
    assert sanitized == end
def test_replacement_cache():
    """Sanitizing the same address twice yields the same placeholder."""
    parser = ParseAsm()
    inst = mock_inst("inc", "dword ptr [0x1234]")

    for _ in range(2):
        _, sanitized = parser.sanitize(inst)
        assert sanitized == "dword ptr [<OFFSET1>]"
def test_replacement_numbering():
    """A named replacement still advances the placeholder counter.

    The first address resolves to a name. The second has no name and must
    become <OFFSET2>, not <OFFSET1>.
    """
    known_names = {0x1234: "_substitute_"}
    # dict.get returns None for any address without an entry.
    parser = ParseAsm(name_lookup=known_names.get)

    _, first = parser.sanitize(mock_inst("inc", "dword ptr [0x1234]"))
    assert first == "dword ptr [_substitute_]"

    _, second = parser.sanitize(mock_inst("inc", "dword ptr [0x5555]"))
    assert second == "dword ptr [<OFFSET2>]"
def test_relocate_lookup():
    """Immediates are replaced only when the relocation lookup says so.

    An immediate that is really an address would be relocated, so the
    relocation check tells addresses apart from plain numbers.
    """
    parser = ParseAsm(relocate_lookup=lambda addr: addr == 0x1234)

    _, relocated = parser.sanitize(mock_inst("mov", "eax, 0x1234"))
    assert relocated == "eax, <OFFSET1>"

    _, plain = parser.sanitize(mock_inst("mov", "eax, 0x5555"))
    assert plain == "eax, 0x5555"
def test_jump_to_function():
    """A jmp is replaced by a name only when the lookup knows the target.

    Such direct jumps to functions occur in the unwind section at the end of
    a function, but not every jmp should be assumed to be one.
    """
    known_names = {0x1234: "_substitute_"}
    # dict.get returns None for any address without an entry.
    parser = ParseAsm(name_lookup=known_names.get)

    _, named = parser.sanitize(DisasmLiteInst(0x1000, 2, "jmp", "0x1234"))
    assert named == "_substitute_"

    # No name for this target, so the displacement is shown instead:
    #   0x1000 (start addr)
    # + 2      (size of jump instruction)
    # + 0x5555 (displacement, the value we want)
    # = 0x6557
    _, unnamed = parser.sanitize(DisasmLiteInst(0x1000, 2, "jmp", "0x6557"))
    assert unnamed == "0x5555"

View File

@ -2,167 +2,20 @@
import argparse import argparse
import base64 import base64
import difflib
import json import json
import logging import logging
import os import os
import re
from isledecomp import ( from isledecomp import (
Bin, Bin,
get_file_in_script_dir, get_file_in_script_dir,
OffsetPlaceholderGenerator,
print_diff, print_diff,
) )
from isledecomp.compare import Compare as IsleCompare from isledecomp.compare import Compare as IsleCompare
from capstone import Cs, CS_ARCH_X86, CS_MODE_32
import colorama
from pystache import Renderer from pystache import Renderer
import colorama
colorama.init()
REGISTER_LIST = set(
[
"ax",
"bp",
"bx",
"cx",
"di",
"dx",
"eax",
"ebp",
"ebx",
"ecx",
"edi",
"edx",
"esi",
"esp",
"si",
"sp",
]
)
WORDS = re.compile(r"\w+")
def sanitize(file, placeholder_generator, mnemonic, op_str):
op_str_is_number = False
try:
int(op_str, 16)
op_str_is_number = True
except ValueError:
pass
if (mnemonic in ["call", "jmp"]) and op_str_is_number:
# Filter out "calls" because the offsets we're not currently trying to
# match offsets. As long as there's a call in the right place, it's
# probably accurate.
op_str = placeholder_generator.get(int(op_str, 16))
else:
def filter_out_ptr(ptype, op_str):
try:
ptrstr = ptype + " ptr ["
start = op_str.index(ptrstr) + len(ptrstr)
end = op_str.index("]", start)
# This will throw ValueError if not hex
inttest = int(op_str[start:end], 16)
return (
op_str[0:start] + placeholder_generator.get(inttest) + op_str[end:]
)
except ValueError:
return op_str
# Filter out dword ptrs where the pointer is to an offset
op_str = filter_out_ptr("dword", op_str)
op_str = filter_out_ptr("word", op_str)
op_str = filter_out_ptr("byte", op_str)
# Use heuristics to filter out any args that look like offsets
words = op_str.split(" ")
for i, word in enumerate(words):
try:
inttest = int(word, 16)
if file.is_relocated_addr(inttest):
words[i] = placeholder_generator.get(inttest)
except ValueError:
pass
op_str = " ".join(words)
return mnemonic, op_str
def parse_asm(disassembler, file, asm_addr, size):
asm = []
data = file.read(asm_addr, size)
placeholder_generator = OffsetPlaceholderGenerator()
for i in disassembler.disasm(data, 0):
# Use heuristics to disregard some differences that aren't representative
# of the accuracy of a function (e.g. global offsets)
mnemonic, op_str = sanitize(file, placeholder_generator, i.mnemonic, i.op_str)
if op_str is None:
asm.append(mnemonic)
else:
asm.append(f"{mnemonic} {op_str}")
return asm
def get_registers(line: str):
to_replace = []
# use words regex to find all matching positions:
for match in WORDS.finditer(line):
reg = match.group(0)
if reg in REGISTER_LIST:
to_replace.append((reg, match.start()))
return to_replace
def replace_register(
lines: list[str], start_line: int, reg: str, replacement: str
) -> list[str]:
return [
line.replace(reg, replacement) if i >= start_line else line
for i, line in enumerate(lines)
]
# Is it possible to make new_asm the same as original_asm by swapping registers?
def can_resolve_register_differences(original_asm, new_asm):
# Split the ASM on spaces to get more granularity, and so
# that we don't modify the original arrays passed in.
original_asm = [part for line in original_asm for part in line.split()]
new_asm = [part for line in new_asm for part in line.split()]
# Swapping ain't gonna help if the lengths are different
if len(original_asm) != len(new_asm):
return False
# Look for the mismatching lines
for i, original_line in enumerate(original_asm):
new_line = new_asm[i]
if new_line != original_line:
# Find all the registers to replace
to_replace = get_registers(original_line)
for replace in to_replace:
(reg, reg_index) = replace
replacing_reg = new_line[reg_index : reg_index + len(reg)]
if replacing_reg in REGISTER_LIST:
if replacing_reg != reg:
# Do a three-way swap replacing in all the subsequent lines
temp_reg = "&" * len(reg)
new_asm = replace_register(new_asm, i, replacing_reg, temp_reg)
new_asm = replace_register(new_asm, i, reg, replacing_reg)
new_asm = replace_register(new_asm, i, temp_reg, reg)
else:
# No replacement to do, different code, bail out
return False
# Check if the lines are now the same
for i, original_line in enumerate(original_asm):
if new_asm[i] != original_line:
return False
return True
def gen_html(html_file, data): def gen_html(html_file, data):
@ -197,9 +50,88 @@ def gen_svg(svg_file, name_svg, icon, svg_implemented_funcs, total_funcs, raw_ac
svgfile.write(output_data) svgfile.write(output_data)
# Do the actual work def get_percent_color(value: float) -> str:
def main(): """Return colorama ANSI escape character for the given decimal value."""
# pylint: disable=too-many-locals, too-many-nested-blocks, too-many-branches, too-many-statements if value == 1.0:
return colorama.Fore.GREEN
if value > 0.8:
return colorama.Fore.YELLOW
return colorama.Fore.RED
def percent_string(
    ratio: float, is_effective: bool = False, is_plain: bool = False
) -> str:
    """Format `ratio` as a percentage with two decimals.

    An asterisk is appended when `is_effective` is set (effective match).
    Unless `is_plain` is set, the text is wrapped in colorama ANSI codes:
    the number is coloured by `get_percent_color` and the asterisk is red.
    """
    star = "*" if is_effective else ""
    number = f"{(ratio * 100):.2f}%"

    if is_plain:
        return number + star

    star_color = colorama.Fore.RED if is_effective else ""
    return (
        get_percent_color(ratio)
        + number
        + star_color
        + star
        + colorama.Style.RESET_ALL
    )
def print_match_verbose(match, show_both_addrs: bool = False, is_plain: bool = False):
    """Print the detailed result for one compared function.

    match: report object providing `orig_addr`, `recomp_addr`, `name`,
        `ratio`, `effective_ratio`, `is_effective_match` and `udiff`.
    show_both_addrs: print the recomp address next to the original one.
    is_plain: suppress colorama ANSI codes in the output.

    A function whose effective ratio is 1.0 gets an "OK" message. Anything
    else gets its diff printed, followed by the similarity percentage.
    """
    if show_both_addrs:
        addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}"
    else:
        addrs = hex(match.orig_addr)

    if match.effective_ratio == 1.0:
        ok_text = (
            "OK!"
            if is_plain
            else (colorama.Fore.GREEN + "✨ OK! ✨" + colorama.Style.RESET_ALL)
        )
        if match.ratio == 1.0:
            print(f"{addrs}: {match.name} 100% match.\n\n{ok_text}\n\n")
        else:
            # Inside an f-string "%" needs no escaping. The previous "100%%"
            # was printed literally with both percent signs.
            print(
                f"{addrs}: {match.name} Effective 100% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n"
            )
    else:
        print_diff(match.udiff, is_plain)

        # The percentage is only shown in this branch, so build it here.
        percenttext = percent_string(
            match.effective_ratio, match.is_effective_match, is_plain
        )
        print(
            f"\n{match.name} is only {percenttext} similar to the original, diff above"
        )
def print_match_oneline(match, show_both_addrs: bool = False, is_plain: bool = False):
    """Print a one-line similarity summary for one compared function.

    show_both_addrs: print the recomp address next to the original one.
    is_plain: suppress colorama ANSI codes in the percentage.
    """
    percenttext = percent_string(
        match.effective_ratio, match.is_effective_match, is_plain
    )
    addrs = (
        f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}"
        if show_both_addrs
        else hex(match.orig_addr)
    )
    print(f" {match.name} ({addrs}) is {percenttext} similar to the original")
def parse_args() -> argparse.Namespace:
def virtual_address(value) -> int:
"""Helper method for argparse, verbose parameter"""
return int(value, 16)
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
allow_abbrev=False, allow_abbrev=False,
description="Recompilation Compare: compare an original EXE with a recompiled EXE + PDB.", description="Recompilation Compare: compare an original EXE with a recompiled EXE + PDB.",
@ -226,6 +158,7 @@ def main():
"--verbose", "--verbose",
"-v", "-v",
metavar="<offset>", metavar="<offset>",
type=virtual_address,
help="Print assembly diff for specific function (original file's offset)", help="Print assembly diff for specific function (original file's offset)",
) )
parser.add_argument( parser.add_argument(
@ -258,198 +191,99 @@ def main():
args = parser.parse_args() args = parser.parse_args()
if not os.path.isfile(args.original):
parser.error(f"Original binary {args.original} does not exist")
if not os.path.isfile(args.recompiled):
parser.error(f"Recompiled binary {args.recompiled} does not exist")
if not os.path.isfile(args.pdb):
parser.error(f"Symbols PDB {args.pdb} does not exist")
if not os.path.isdir(args.decomp_dir):
parser.error(f"Source directory {args.decomp_dir} does not exist")
return args
def main():
args = parse_args()
logging.basicConfig(level=args.loglevel, format="[%(levelname)s] %(message)s") logging.basicConfig(level=args.loglevel, format="[%(levelname)s] %(message)s")
colorama.init() with Bin(args.original, find_str=True) as origfile, Bin(
args.recompiled
verbose = None ) as recompfile:
found_verbose_target = False if args.verbose is not None:
if args.verbose:
try:
verbose = int(args.verbose, 16)
except ValueError:
parser.error("invalid verbose argument")
html_path = args.html
plain = args.no_color
original = args.original
if not os.path.isfile(original):
parser.error(f"Original binary {original} does not exist")
recomp = args.recompiled
if not os.path.isfile(recomp):
parser.error(f"Recompiled binary {recomp} does not exist")
syms = args.pdb
if not os.path.isfile(syms):
parser.error(f"Symbols PDB {syms} does not exist")
source = args.decomp_dir
if not os.path.isdir(source):
parser.error(f"Source directory {source} does not exist")
svg = args.svg
with Bin(original, find_str=True) as origfile, Bin(recomp) as recompfile:
if verbose is not None:
# Mute logger events from compare engine # Mute logger events from compare engine
logging.getLogger("isledecomp.compare.db").setLevel(logging.CRITICAL) logging.getLogger("isledecomp.compare.db").setLevel(logging.CRITICAL)
logging.getLogger("isledecomp.compare.lines").setLevel(logging.CRITICAL) logging.getLogger("isledecomp.compare.lines").setLevel(logging.CRITICAL)
isle_compare = IsleCompare(origfile, recompfile, syms, source) isle_compare = IsleCompare(origfile, recompfile, args.pdb, args.decomp_dir)
print() print()
capstone_disassembler = Cs(CS_ARCH_X86, CS_MODE_32) ### Compare one or none.
if args.verbose is not None:
match = isle_compare.compare_function(args.verbose)
if match is None:
print(f"Failed to find the function with address 0x{args.verbose:x}")
return
print_match_verbose(match, is_plain=args.no_color)
return
### Compare everything.
function_count = 0 function_count = 0
total_accuracy = 0 total_accuracy = 0
total_effective_accuracy = 0 total_effective_accuracy = 0
htmlinsert = [] htmlinsert = []
matches = [] for match in isle_compare.compare_functions():
if verbose is not None: print_match_oneline(match, is_plain=args.no_color)
match = isle_compare.get_one_function(verbose)
if match is not None:
found_verbose_target = True
matches = [match]
else:
matches = isle_compare.get_functions()
for match in matches:
# The effective_ratio is the ratio when ignoring differing register
# allocation vs the ratio is the true ratio.
ratio = 0.0
effective_ratio = 0.0
if match.size:
origasm = parse_asm(
capstone_disassembler,
origfile,
match.orig_addr,
match.size,
)
recompasm = parse_asm(
capstone_disassembler,
recompfile,
match.recomp_addr,
match.size,
)
diff = difflib.SequenceMatcher(None, origasm, recompasm)
ratio = diff.ratio()
effective_ratio = ratio
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
if can_resolve_register_differences(origasm, recompasm):
effective_ratio = 1.0
else:
ratio = 0
percenttext = f"{(effective_ratio * 100):.2f}%"
if not plain:
if effective_ratio == 1.0:
percenttext = (
colorama.Fore.GREEN + percenttext + colorama.Style.RESET_ALL
)
elif effective_ratio > 0.8:
percenttext = (
colorama.Fore.YELLOW + percenttext + colorama.Style.RESET_ALL
)
else:
percenttext = (
colorama.Fore.RED + percenttext + colorama.Style.RESET_ALL
)
if effective_ratio == 1.0 and ratio != 1.0:
if plain:
percenttext += "*"
else:
percenttext += colorama.Fore.RED + "*" + colorama.Style.RESET_ALL
if args.print_rec_addr:
addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}"
else:
addrs = hex(match.orig_addr)
if not verbose:
print(
f" {match.name} ({addrs}) is {percenttext} similar to the original"
)
function_count += 1 function_count += 1
total_accuracy += ratio total_accuracy += match.ratio
total_effective_accuracy += effective_ratio total_effective_accuracy += match.effective_ratio
if match.size: # If html, record the diffs to an HTML file
udiff = difflib.unified_diff(origasm, recompasm, n=10) if args.html is not None:
htmlinsert.append(
# If verbose, print the diff for that function to the output {
if verbose: "address": f"0x{match.orig_addr:x}",
if effective_ratio == 1.0: "name": match.name,
ok_text = ( "matching": match.effective_ratio,
"OK!" "diff": "\n".join(match.udiff),
if plain }
else (
colorama.Fore.GREEN
+ "✨ OK! ✨"
+ colorama.Style.RESET_ALL
)
)
if ratio == 1.0:
print(f"{addrs}: {match.name} 100% match.\n\n{ok_text}\n\n")
else:
print(
f"{addrs}: {match.name} Effective 100%% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n"
)
else:
print_diff(udiff, plain)
print(
f"\n{match.name} is only {percenttext} similar to the original, diff above"
)
# If html, record the diffs to an HTML file
if html_path:
htmlinsert.append(
{
"address": f"0x{match.orig_addr:x}",
"name": match.name,
"matching": effective_ratio,
"diff": "\n".join(udiff),
}
)
if html_path:
gen_html(html_path, json.dumps(htmlinsert))
if verbose:
if not found_verbose_target:
print(f"Failed to find the function with address 0x{verbose:x}")
else:
implemented_funcs = function_count
if args.total:
function_count = int(args.total)
if function_count > 0:
effective_accuracy = total_effective_accuracy / function_count * 100
actual_accuracy = total_accuracy / function_count * 100
print(
f"\nTotal effective accuracy {effective_accuracy:.2f}% across {function_count} functions ({actual_accuracy:.2f}% actual accuracy)"
) )
if svg: ## Generate files and show summary.
gen_svg(
svg, if args.html is not None:
os.path.basename(original), gen_html(args.html, json.dumps(htmlinsert))
args.svg_icon,
implemented_funcs, implemented_funcs = function_count
function_count,
total_effective_accuracy, if args.total:
) function_count = int(args.total)
if function_count > 0:
effective_accuracy = total_effective_accuracy / function_count * 100
actual_accuracy = total_accuracy / function_count * 100
print(
f"\nTotal effective accuracy {effective_accuracy:.2f}% across {function_count} functions ({actual_accuracy:.2f}% actual accuracy)"
)
if args.svg is not None:
gen_svg(
args.svg,
os.path.basename(args.original),
args.svg_icon,
implemented_funcs,
function_count,
total_effective_accuracy,
)
if __name__ == "__main__": if __name__ == "__main__":