From eeb91b35e951ba5d530bc9224b3f1fdfb3167e07 Mon Sep 17 00:00:00 2001 From: disinvite Date: Wed, 17 Jan 2024 18:33:18 -0500 Subject: [PATCH] Add vtable comparison to reccmp --- tools/isledecomp/isledecomp/compare/core.py | 129 ++++++++++++++++++-- tools/isledecomp/isledecomp/compare/db.py | 22 +++- tools/reccmp/reccmp.py | 14 ++- 3 files changed, 141 insertions(+), 24 deletions(-) diff --git a/tools/isledecomp/isledecomp/compare/core.py b/tools/isledecomp/isledecomp/compare/core.py index c8ae7809..58f99d32 100644 --- a/tools/isledecomp/isledecomp/compare/core.py +++ b/tools/isledecomp/isledecomp/compare/core.py @@ -1,6 +1,7 @@ import os import logging import difflib +import struct from dataclasses import dataclass from typing import Iterable, List, Optional from isledecomp.cvdump.demangler import demangle_string_const @@ -18,6 +19,7 @@ @dataclass class DiffReport: + match_type: SymbolType orig_addr: int recomp_addr: int name: str @@ -214,17 +216,11 @@ def _match_thunks(self): # function in the first place. self._db.skip_compare(thunk_from_orig) - def get_one_function(self, addr: int) -> Optional[MatchInfo]: - """i.e. verbose mode for reccmp""" - return self._db.get_one_function(addr) - - def get_functions(self) -> List[MatchInfo]: - return self._db.get_matches(SymbolType.FUNCTION) - def _compare_function(self, match: MatchInfo) -> DiffReport: if match.size == 0: # Report a failed match to make the user aware of the empty function. 
return DiffReport( + match_type=SymbolType.FUNCTION, orig_addr=match.orig_addr, recomp_addr=match.recomp_addr, name=match.name, @@ -281,6 +277,7 @@ def recomp_lookup(addr: int) -> Optional[str]: unified_diff = [] return DiffReport( + match_type=SymbolType.FUNCTION, orig_addr=match.orig_addr, recomp_addr=match.recomp_addr, name=match.name, @@ -289,16 +286,121 @@ def recomp_lookup(addr: int) -> Optional[str]: is_effective_match=is_effective_match, ) - def compare_function(self, addr: int) -> Optional[DiffReport]: - match = self.get_one_function(addr) + def _compare_vtable(self, match: MatchInfo) -> DiffReport: + vtable_size = match.size + + # The vtable size should always be a multiple of 4 because that + # is the pointer size. If it is not (for whatever reason) + # it would cause iter_unpack to blow up so let's just fix it. + if vtable_size % 4 != 0: + logger.warning( + "Vtable for class %s has irregular size %d", match.name, vtable_size + ) + vtable_size = 4 * (vtable_size // 4) + + orig_table = self.orig_bin.read(match.orig_addr, vtable_size) + recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size) + + raw_addrs = zip( + [t for (t,) in struct.iter_unpack("<L", orig_table)], + [t for (t,) in struct.iter_unpack("<L", recomp_table)], + ) + + def match_text( + i: int, m: Optional[MatchInfo], raw_addr: Optional[int] = None + ) -> str: + """Format the function reference at this vtable index as text. + If we have not identified this function, we have the option to + display the raw address. This is only worth doing for the original addr + because we should always be able to identify the recomp function. + If the original function is missing then this probably means that the class + should override the given function from the superclass, but we have not + implemented this yet.
+ """ + index = f"vtable0x{i*4:02x}" + + if m is not None: + orig = hex(m.orig_addr) if m.orig_addr is not None else "no orig" + recomp = ( + hex(m.recomp_addr) if m.recomp_addr is not None else "no recomp" + ) + return f"{index:>12} : ({orig:10} / {recomp:10}) : {m.name}" + + if raw_addr is not None: + return f"{index:>12} : 0x{raw_addr:x} from orig not annotated." + + return f"{index:>12} : (no match)" + + orig_text = [] + recomp_text = [] + ratio = 0 + n_entries = 0 + + # Now compare each pointer from the two vtables. + for i, (raw_orig, raw_recomp) in enumerate(raw_addrs): + orig = self._db.get_by_orig(raw_orig) + recomp = self._db.get_by_recomp(raw_recomp) + + if ( + orig is not None + and recomp is not None + and orig.recomp_addr == recomp.recomp_addr + ): + ratio += 1 + + n_entries += 1 + orig_text.append(match_text(i, orig, raw_orig)) + recomp_text.append(match_text(i, recomp)) + + ratio = ratio / float(n_entries) if n_entries > 0 else 0 + + # n=100: Show the entire table if there is a diff to display. + # Otherwise it would be confusing if the table got cut off. 
+ unified_diff = difflib.unified_diff(orig_text, recomp_text, n=100) + + return DiffReport( + match_type=SymbolType.VTABLE, + orig_addr=match.orig_addr, + recomp_addr=match.recomp_addr, + name=f"{match.name}::`vftable'", + udiff=unified_diff, + ratio=ratio, + ) + + def _compare_match(self, match: MatchInfo) -> Optional[DiffReport]: + """Router for comparison type""" + if match.compare_type == SymbolType.FUNCTION: + return self._compare_function(match) + + if match.compare_type == SymbolType.VTABLE: + return self._compare_vtable(match) + + return None + + ## Public API + + def get_functions(self) -> List[MatchInfo]: + return self._db.get_matches_by_type(SymbolType.FUNCTION) + + def get_vtables(self) -> List[MatchInfo]: + return self._db.get_matches_by_type(SymbolType.VTABLE) + + def compare_address(self, addr: int) -> Optional[DiffReport]: + match = self._db.get_one_match(addr) if match is None: return None - return self._compare_function(match) + return self._compare_match(match) + + def compare_all(self) -> Iterable[DiffReport]: + for match in self._db.get_matches(): + diff = self._compare_match(match) + if diff is not None: + yield diff def compare_functions(self) -> Iterable[DiffReport]: for match in self.get_functions(): - yield self._compare_function(match) + yield self._compare_match(match) def compare_variables(self): pass @@ -309,5 +411,6 @@ def compare_pointers(self): def compare_strings(self): pass - def compare_vtables(self): - pass + def compare_vtables(self) -> Iterable[DiffReport]: + for match in self.get_vtables(): + yield self._compare_match(match) diff --git a/tools/isledecomp/isledecomp/compare/db.py b/tools/isledecomp/isledecomp/compare/db.py index d2c2cbf9..dea2b590 100644 --- a/tools/isledecomp/isledecomp/compare/db.py +++ b/tools/isledecomp/isledecomp/compare/db.py @@ -82,17 +82,29 @@ def get_unmatched_strings(self) -> List[str]: return [string for (string,) in cur.fetchall()] - def get_one_function(self, addr: int) -> Optional[MatchInfo]: + 
def get_matches(self) -> Optional[MatchInfo]: cur = self._db.execute( """SELECT compare_type, orig_addr, recomp_addr, name, size FROM `symbols` - WHERE compare_type = ? - AND orig_addr = ? + WHERE orig_addr IS NOT NULL AND recomp_addr IS NOT NULL AND should_skip IS FALSE ORDER BY orig_addr """, - (SymbolType.FUNCTION.value, addr), + ) + cur.row_factory = matchinfo_factory + + return cur.fetchall() + + def get_one_match(self, addr: int) -> Optional[MatchInfo]: + cur = self._db.execute( + """SELECT compare_type, orig_addr, recomp_addr, name, size + FROM `symbols` + WHERE orig_addr = ? + AND recomp_addr IS NOT NULL + AND should_skip IS FALSE + """, + (addr,), ) cur.row_factory = matchinfo_factory return cur.fetchone() @@ -119,7 +131,7 @@ def get_by_recomp(self, addr: int) -> Optional[MatchInfo]: cur.row_factory = matchinfo_factory return cur.fetchone() - def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]: + def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]: cur = self._db.execute( """SELECT compare_type, orig_addr, recomp_addr, name, size FROM `symbols` diff --git a/tools/reccmp/reccmp.py b/tools/reccmp/reccmp.py index 66bb0f66..69783cb2 100755 --- a/tools/reccmp/reccmp.py +++ b/tools/reccmp/reccmp.py @@ -12,6 +12,7 @@ print_diff, ) from isledecomp.compare import Compare as IsleCompare +from isledecomp.types import SymbolType from pystache import Renderer import colorama @@ -225,9 +226,9 @@ def main(): ### Compare one or none. 
if args.verbose is not None: - match = isle_compare.compare_function(args.verbose) + match = isle_compare.compare_address(args.verbose) if match is None: - print(f"Failed to find the function with address 0x{args.verbose:x}") + print(f"Failed to find a match at address 0x{args.verbose:x}") return print_match_verbose( @@ -242,14 +243,15 @@ def main(): total_effective_accuracy = 0 htmlinsert = [] - for match in isle_compare.compare_functions(): + for match in isle_compare.compare_all(): print_match_oneline( match, show_both_addrs=args.print_rec_addr, is_plain=args.no_color ) - function_count += 1 - total_accuracy += match.ratio - total_effective_accuracy += match.effective_ratio + if match.match_type == SymbolType.FUNCTION: + function_count += 1 + total_accuracy += match.ratio + total_effective_accuracy += match.effective_ratio # If html, record the diffs to an HTML file if args.html is not None: