Add vtable comparison to reccmp

This commit is contained in:
disinvite 2024-01-17 18:33:18 -05:00
parent 99917ca765
commit eeb91b35e9
3 changed files with 141 additions and 24 deletions

View File

@ -1,6 +1,7 @@
import os import os
import logging import logging
import difflib import difflib
import struct
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, List, Optional from typing import Iterable, List, Optional
from isledecomp.cvdump.demangler import demangle_string_const from isledecomp.cvdump.demangler import demangle_string_const
@ -18,6 +19,7 @@
@dataclass @dataclass
class DiffReport: class DiffReport:
match_type: SymbolType
orig_addr: int orig_addr: int
recomp_addr: int recomp_addr: int
name: str name: str
@ -214,17 +216,11 @@ def _match_thunks(self):
# function in the first place. # function in the first place.
self._db.skip_compare(thunk_from_orig) self._db.skip_compare(thunk_from_orig)
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
"""i.e. verbose mode for reccmp"""
return self._db.get_one_function(addr)
def get_functions(self) -> List[MatchInfo]:
return self._db.get_matches(SymbolType.FUNCTION)
def _compare_function(self, match: MatchInfo) -> DiffReport: def _compare_function(self, match: MatchInfo) -> DiffReport:
if match.size == 0: if match.size == 0:
# Report a failed match to make the user aware of the empty function. # Report a failed match to make the user aware of the empty function.
return DiffReport( return DiffReport(
match_type=SymbolType.FUNCTION,
orig_addr=match.orig_addr, orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr, recomp_addr=match.recomp_addr,
name=match.name, name=match.name,
@ -281,6 +277,7 @@ def recomp_lookup(addr: int) -> Optional[str]:
unified_diff = [] unified_diff = []
return DiffReport( return DiffReport(
match_type=SymbolType.FUNCTION,
orig_addr=match.orig_addr, orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr, recomp_addr=match.recomp_addr,
name=match.name, name=match.name,
@ -289,16 +286,121 @@ def recomp_lookup(addr: int) -> Optional[str]:
is_effective_match=is_effective_match, is_effective_match=is_effective_match,
) )
def compare_function(self, addr: int) -> Optional[DiffReport]: def _compare_vtable(self, match: MatchInfo) -> DiffReport:
match = self.get_one_function(addr) vtable_size = match.size
# The vtable size should always be a multiple of 4 because that
# is the pointer size. If it is not (for whatever reason)
# it would cause iter_unpack to blow up so let's just fix it.
if vtable_size % 4 != 0:
logger.warning(
"Vtable for class %s has irregular size %d", match.name, vtable_size
)
vtable_size = 4 * (vtable_size // 4)
orig_table = self.orig_bin.read(match.orig_addr, vtable_size)
recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size)
raw_addrs = zip(
[t for (t,) in struct.iter_unpack("<L", orig_table)],
[t for (t,) in struct.iter_unpack("<L", recomp_table)],
)
def match_text(
i: int, m: Optional[MatchInfo], raw_addr: Optional[int] = None
) -> str:
"""Format the function reference at this vtable index as text.
If we have not identified this function, we have the option to
display the raw address. This is only worth doing for the original addr
because we should always be able to identify the recomp function.
If the original function is missing then this probably means that the class
should override the given function from the superclass, but we have not
implemented this yet.
"""
index = f"vtable0x{i*4:02x}"
if m is not None:
orig = hex(m.orig_addr) if m.orig_addr is not None else "no orig"
recomp = (
hex(m.recomp_addr) if m.recomp_addr is not None else "no recomp"
)
return f"{index:>12} : ({orig:10} / {recomp:10}) : {m.name}"
if raw_addr is not None:
return f"{index:>12} : 0x{raw_addr:x} from orig not annotated."
return f"{index:>12} : (no match)"
orig_text = []
recomp_text = []
ratio = 0
n_entries = 0
# Now compare each pointer from the two vtables.
for i, (raw_orig, raw_recomp) in enumerate(raw_addrs):
orig = self._db.get_by_orig(raw_orig)
recomp = self._db.get_by_recomp(raw_recomp)
if (
orig is not None
and recomp is not None
and orig.recomp_addr == recomp.recomp_addr
):
ratio += 1
n_entries += 1
orig_text.append(match_text(i, orig, raw_orig))
recomp_text.append(match_text(i, recomp))
ratio = ratio / float(n_entries) if n_entries > 0 else 0
# n=100: Show the entire table if there is a diff to display.
# Otherwise it would be confusing if the table got cut off.
unified_diff = difflib.unified_diff(orig_text, recomp_text, n=100)
return DiffReport(
match_type=SymbolType.VTABLE,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=f"{match.name}::`vftable'",
udiff=unified_diff,
ratio=ratio,
)
def _compare_match(self, match: MatchInfo) -> Optional[DiffReport]:
"""Router for comparison type"""
if match.compare_type == SymbolType.FUNCTION:
return self._compare_function(match)
if match.compare_type == SymbolType.VTABLE:
return self._compare_vtable(match)
return None
## Public API
def get_functions(self) -> List[MatchInfo]:
return self._db.get_matches_by_type(SymbolType.FUNCTION)
def get_vtables(self) -> List[MatchInfo]:
return self._db.get_matches_by_type(SymbolType.VTABLE)
def compare_address(self, addr: int) -> Optional[DiffReport]:
match = self._db.get_one_match(addr)
if match is None: if match is None:
return None return None
return self._compare_function(match) return self._compare_match(match)
def compare_all(self) -> Iterable[DiffReport]:
for match in self._db.get_matches():
diff = self._compare_match(match)
if diff is not None:
yield diff
def compare_functions(self) -> Iterable[DiffReport]: def compare_functions(self) -> Iterable[DiffReport]:
for match in self.get_functions(): for match in self.get_functions():
yield self._compare_function(match) yield self._compare_match(match)
def compare_variables(self): def compare_variables(self):
pass pass
@ -309,5 +411,6 @@ def compare_pointers(self):
def compare_strings(self): def compare_strings(self):
pass pass
def compare_vtables(self): def compare_vtables(self) -> Iterable[DiffReport]:
pass for match in self.get_vtables():
yield self._compare_match(match)

View File

@ -82,17 +82,29 @@ def get_unmatched_strings(self) -> List[str]:
return [string for (string,) in cur.fetchall()] return [string for (string,) in cur.fetchall()]
def get_one_function(self, addr: int) -> Optional[MatchInfo]: def get_matches(self) -> Optional[MatchInfo]:
cur = self._db.execute( cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size """SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols` FROM `symbols`
WHERE compare_type = ? WHERE orig_addr IS NOT NULL
AND orig_addr = ?
AND recomp_addr IS NOT NULL AND recomp_addr IS NOT NULL
AND should_skip IS FALSE AND should_skip IS FALSE
ORDER BY orig_addr ORDER BY orig_addr
""", """,
(SymbolType.FUNCTION.value, addr), )
cur.row_factory = matchinfo_factory
return cur.fetchall()
def get_one_match(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
WHERE orig_addr = ?
AND recomp_addr IS NOT NULL
AND should_skip IS FALSE
""",
(addr,),
) )
cur.row_factory = matchinfo_factory cur.row_factory = matchinfo_factory
return cur.fetchone() return cur.fetchone()
@ -119,7 +131,7 @@ def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
cur.row_factory = matchinfo_factory cur.row_factory = matchinfo_factory
return cur.fetchone() return cur.fetchone()
def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]: def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]:
cur = self._db.execute( cur = self._db.execute(
"""SELECT compare_type, orig_addr, recomp_addr, name, size """SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols` FROM `symbols`

View File

@ -12,6 +12,7 @@
print_diff, print_diff,
) )
from isledecomp.compare import Compare as IsleCompare from isledecomp.compare import Compare as IsleCompare
from isledecomp.types import SymbolType
from pystache import Renderer from pystache import Renderer
import colorama import colorama
@ -225,9 +226,9 @@ def main():
### Compare one or none. ### Compare one or none.
if args.verbose is not None: if args.verbose is not None:
match = isle_compare.compare_function(args.verbose) match = isle_compare.compare_address(args.verbose)
if match is None: if match is None:
print(f"Failed to find the function with address 0x{args.verbose:x}") print(f"Failed to find a match at address 0x{args.verbose:x}")
return return
print_match_verbose( print_match_verbose(
@ -242,14 +243,15 @@ def main():
total_effective_accuracy = 0 total_effective_accuracy = 0
htmlinsert = [] htmlinsert = []
for match in isle_compare.compare_functions(): for match in isle_compare.compare_all():
print_match_oneline( print_match_oneline(
match, show_both_addrs=args.print_rec_addr, is_plain=args.no_color match, show_both_addrs=args.print_rec_addr, is_plain=args.no_color
) )
function_count += 1 if match.match_type == SymbolType.FUNCTION:
total_accuracy += match.ratio function_count += 1
total_effective_accuracy += match.effective_ratio total_accuracy += match.ratio
total_effective_accuracy += match.effective_ratio
# If html, record the diffs to an HTML file # If html, record the diffs to an HTML file
if args.html is not None: if args.html is not None: