mirror of
https://github.com/isledecomp/isle.git
synced 2026-01-28 10:41:15 +00:00
Add vtable comparison to reccmp
This commit is contained in:
parent
99917ca765
commit
eeb91b35e9
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import difflib
|
import difflib
|
||||||
|
import struct
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable, List, Optional
|
from typing import Iterable, List, Optional
|
||||||
from isledecomp.cvdump.demangler import demangle_string_const
|
from isledecomp.cvdump.demangler import demangle_string_const
|
||||||
@ -18,6 +19,7 @@
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DiffReport:
|
class DiffReport:
|
||||||
|
match_type: SymbolType
|
||||||
orig_addr: int
|
orig_addr: int
|
||||||
recomp_addr: int
|
recomp_addr: int
|
||||||
name: str
|
name: str
|
||||||
@ -214,17 +216,11 @@ def _match_thunks(self):
|
|||||||
# function in the first place.
|
# function in the first place.
|
||||||
self._db.skip_compare(thunk_from_orig)
|
self._db.skip_compare(thunk_from_orig)
|
||||||
|
|
||||||
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
|
|
||||||
"""i.e. verbose mode for reccmp"""
|
|
||||||
return self._db.get_one_function(addr)
|
|
||||||
|
|
||||||
def get_functions(self) -> List[MatchInfo]:
|
|
||||||
return self._db.get_matches(SymbolType.FUNCTION)
|
|
||||||
|
|
||||||
def _compare_function(self, match: MatchInfo) -> DiffReport:
|
def _compare_function(self, match: MatchInfo) -> DiffReport:
|
||||||
if match.size == 0:
|
if match.size == 0:
|
||||||
# Report a failed match to make the user aware of the empty function.
|
# Report a failed match to make the user aware of the empty function.
|
||||||
return DiffReport(
|
return DiffReport(
|
||||||
|
match_type=SymbolType.FUNCTION,
|
||||||
orig_addr=match.orig_addr,
|
orig_addr=match.orig_addr,
|
||||||
recomp_addr=match.recomp_addr,
|
recomp_addr=match.recomp_addr,
|
||||||
name=match.name,
|
name=match.name,
|
||||||
@ -281,6 +277,7 @@ def recomp_lookup(addr: int) -> Optional[str]:
|
|||||||
unified_diff = []
|
unified_diff = []
|
||||||
|
|
||||||
return DiffReport(
|
return DiffReport(
|
||||||
|
match_type=SymbolType.FUNCTION,
|
||||||
orig_addr=match.orig_addr,
|
orig_addr=match.orig_addr,
|
||||||
recomp_addr=match.recomp_addr,
|
recomp_addr=match.recomp_addr,
|
||||||
name=match.name,
|
name=match.name,
|
||||||
@ -289,16 +286,121 @@ def recomp_lookup(addr: int) -> Optional[str]:
|
|||||||
is_effective_match=is_effective_match,
|
is_effective_match=is_effective_match,
|
||||||
)
|
)
|
||||||
|
|
||||||
def compare_function(self, addr: int) -> Optional[DiffReport]:
|
def _compare_vtable(self, match: MatchInfo) -> DiffReport:
    """Compare the original and recompiled vtable for one matched symbol.

    ``match.size`` bytes are read from each binary at the match's addresses
    and unpacked as little-endian 32-bit values ("<L"). Each pair of values
    is looked up with ``self._db.get_by_orig`` / ``self._db.get_by_recomp``.
    An entry counts as matching when both lookups succeed and report the
    same ``recomp_addr``.

    Returns a DiffReport of type SymbolType.VTABLE whose ratio is the
    fraction of matching entries (0 when the table has no entries) and
    whose udiff is the unified diff of the two textual tables.
    """
    table_bytes = match.size

    # Entries are unpacked four bytes at a time. A size that is not a
    # multiple of 4 would make iter_unpack raise, so drop the trailing
    # partial entry after warning about it.
    leftover = table_bytes % 4
    if leftover:
        logger.warning(
            "Vtable for class %s has irregular size %d", match.name, table_bytes
        )
        table_bytes -= leftover

    orig_blob = self.orig_bin.read(match.orig_addr, table_bytes)
    recomp_blob = self.recomp_bin.read(match.recomp_addr, table_bytes)

    orig_ptrs = [ptr for (ptr,) in struct.iter_unpack("<L", orig_blob)]
    recomp_ptrs = [ptr for (ptr,) in struct.iter_unpack("<L", recomp_blob)]

    def describe(
        slot: int, info: Optional[MatchInfo], raw_addr: Optional[int] = None
    ) -> str:
        """Render one vtable slot as a line of text.

        When the slot's function is unknown (``info`` is None), the raw
        address is shown instead if the caller supplied one. The caller
        supplies it only for the original side.
        """
        label = f"vtable0x{slot*4:02x}"

        if info is None:
            if raw_addr is None:
                return f"{label:>12} : (no match)"
            return f"{label:>12} : 0x{raw_addr:x} from orig not annotated."

        orig_hex = "no orig" if info.orig_addr is None else hex(info.orig_addr)
        recomp_hex = (
            "no recomp" if info.recomp_addr is None else hex(info.recomp_addr)
        )
        return f"{label:>12} : ({orig_hex:10} / {recomp_hex:10}) : {info.name}"

    orig_lines = []
    recomp_lines = []
    same_count = 0

    # Walk both tables in lockstep, one pointer pair per slot.
    for slot, (orig_ptr, recomp_ptr) in enumerate(zip(orig_ptrs, recomp_ptrs)):
        orig_info = self._db.get_by_orig(orig_ptr)
        recomp_info = self._db.get_by_recomp(recomp_ptr)

        both_known = orig_info is not None and recomp_info is not None
        if both_known and orig_info.recomp_addr == recomp_info.recomp_addr:
            same_count += 1

        orig_lines.append(describe(slot, orig_info, orig_ptr))
        recomp_lines.append(describe(slot, recomp_info))

    total = len(orig_lines)
    ratio = same_count / float(total) if total > 0 else 0

    # n=100 context lines: if there is any difference, show the whole table
    # rather than a cut-off excerpt.
    table_diff = difflib.unified_diff(orig_lines, recomp_lines, n=100)

    return DiffReport(
        match_type=SymbolType.VTABLE,
        orig_addr=match.orig_addr,
        recomp_addr=match.recomp_addr,
        name=f"{match.name}::`vftable'",
        udiff=table_diff,
        ratio=ratio,
    )
|
||||||
|
|
||||||
|
def _compare_match(self, match: MatchInfo) -> Optional[DiffReport]:
    """Dispatch ``match`` to the comparison routine for its ``compare_type``.

    Returns the report from ``_compare_function`` or ``_compare_vtable``,
    or None when the type is neither SymbolType.FUNCTION nor
    SymbolType.VTABLE.
    """
    kind = match.compare_type

    if kind == SymbolType.FUNCTION:
        report = self._compare_function(match)
    elif kind == SymbolType.VTABLE:
        report = self._compare_vtable(match)
    else:
        report = None

    return report
|
||||||
|
|
||||||
|
## Public API
|
||||||
|
|
||||||
|
def get_functions(self) -> List[MatchInfo]:
    """Return the database's matches of type SymbolType.FUNCTION."""
    function_matches = self._db.get_matches_by_type(SymbolType.FUNCTION)
    return function_matches
|
||||||
|
|
||||||
|
def get_vtables(self) -> List[MatchInfo]:
    """Return the database's matches of type SymbolType.VTABLE."""
    vtable_matches = self._db.get_matches_by_type(SymbolType.VTABLE)
    return vtable_matches
|
||||||
|
|
||||||
|
def compare_address(self, addr: int) -> Optional[DiffReport]:
    """Compare the single match that the database returns for ``addr``.

    Returns None when ``self._db.get_one_match`` finds nothing; otherwise
    returns whatever ``_compare_match`` produces for that match.
    """
    found = self._db.get_one_match(addr)
    return None if found is None else self._compare_match(found)
|
||||||
|
|
||||||
|
def compare_all(self) -> Iterable[DiffReport]:
    """Yield a report for every match in the database that can be compared.

    Matches for which ``_compare_match`` returns None are skipped.
    """
    reports = map(self._compare_match, self._db.get_matches())
    yield from (report for report in reports if report is not None)
|
||||||
|
|
||||||
def compare_functions(self) -> Iterable[DiffReport]:
    """Yield the ``_compare_match`` result for each entry of ``get_functions()``."""
    yield from map(self._compare_match, self.get_functions())
|
||||||
|
|
||||||
def compare_variables(self):
    """Placeholder with no implementation yet: does nothing and returns None."""
|
||||||
@ -309,5 +411,6 @@ def compare_pointers(self):
|
|||||||
def compare_strings(self):
    """Placeholder with no implementation yet: does nothing and returns None."""
|
||||||
|
|
||||||
def compare_vtables(self) -> Iterable[DiffReport]:
    """Yield the ``_compare_match`` result for each entry of ``get_vtables()``."""
    yield from map(self._compare_match, self.get_vtables())
|
||||||
|
|||||||
@ -82,17 +82,29 @@ def get_unmatched_strings(self) -> List[str]:
|
|||||||
|
|
||||||
return [string for (string,) in cur.fetchall()]
|
return [string for (string,) in cur.fetchall()]
|
||||||
|
|
||||||
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
|
def get_matches(self) -> List[MatchInfo]:
    """Return every row of `symbols` that can be compared.

    A row qualifies when it has both an original and a recompiled address
    and is not flagged ``should_skip``. Rows are ordered by original
    address, and each one is converted by ``matchinfo_factory``.

    The result is a list (``fetchall``), so the annotation is
    ``List[MatchInfo]`` rather than ``Optional[MatchInfo]``.
    """
    cur = self._db.execute(
        """SELECT compare_type, orig_addr, recomp_addr, name, size
        FROM `symbols`
        WHERE orig_addr IS NOT NULL
        AND recomp_addr IS NOT NULL
        AND should_skip IS FALSE
        ORDER BY orig_addr
        """,
    )
    cur.row_factory = matchinfo_factory

    return cur.fetchall()
|
||||||
|
|
||||||
|
def get_one_match(self, addr: int) -> Optional[MatchInfo]:
    """Look up the comparable symbol whose original address is ``addr``.

    Only rows that also have a recompiled address and are not flagged
    ``should_skip`` are considered. The row is converted by
    ``matchinfo_factory``. Returns the result of ``fetchone()``, which is
    presumably None when no row qualifies (the connection type is not
    visible here).
    """
    query = """SELECT compare_type, orig_addr, recomp_addr, name, size
        FROM `symbols`
        WHERE orig_addr = ?
        AND recomp_addr IS NOT NULL
        AND should_skip IS FALSE
        """
    cursor = self._db.execute(query, (addr,))
    cursor.row_factory = matchinfo_factory
    return cursor.fetchone()
|
||||||
@ -119,7 +131,7 @@ def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
|
|||||||
cur.row_factory = matchinfo_factory
|
cur.row_factory = matchinfo_factory
|
||||||
return cur.fetchone()
|
return cur.fetchone()
|
||||||
|
|
||||||
def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]:
|
def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]:
|
||||||
cur = self._db.execute(
|
cur = self._db.execute(
|
||||||
"""SELECT compare_type, orig_addr, recomp_addr, name, size
|
"""SELECT compare_type, orig_addr, recomp_addr, name, size
|
||||||
FROM `symbols`
|
FROM `symbols`
|
||||||
|
|||||||
@ -12,6 +12,7 @@
|
|||||||
print_diff,
|
print_diff,
|
||||||
)
|
)
|
||||||
from isledecomp.compare import Compare as IsleCompare
|
from isledecomp.compare import Compare as IsleCompare
|
||||||
|
from isledecomp.types import SymbolType
|
||||||
from pystache import Renderer
|
from pystache import Renderer
|
||||||
import colorama
|
import colorama
|
||||||
|
|
||||||
@ -225,9 +226,9 @@ def main():
|
|||||||
### Compare one or none.
|
### Compare one or none.
|
||||||
|
|
||||||
if args.verbose is not None:
|
if args.verbose is not None:
|
||||||
match = isle_compare.compare_function(args.verbose)
|
match = isle_compare.compare_address(args.verbose)
|
||||||
if match is None:
|
if match is None:
|
||||||
print(f"Failed to find the function with address 0x{args.verbose:x}")
|
print(f"Failed to find a match at address 0x{args.verbose:x}")
|
||||||
return
|
return
|
||||||
|
|
||||||
print_match_verbose(
|
print_match_verbose(
|
||||||
@ -242,14 +243,15 @@ def main():
|
|||||||
total_effective_accuracy = 0
|
total_effective_accuracy = 0
|
||||||
htmlinsert = []
|
htmlinsert = []
|
||||||
|
|
||||||
for match in isle_compare.compare_functions():
|
for match in isle_compare.compare_all():
|
||||||
print_match_oneline(
|
print_match_oneline(
|
||||||
match, show_both_addrs=args.print_rec_addr, is_plain=args.no_color
|
match, show_both_addrs=args.print_rec_addr, is_plain=args.no_color
|
||||||
)
|
)
|
||||||
|
|
||||||
function_count += 1
|
if match.match_type == SymbolType.FUNCTION:
|
||||||
total_accuracy += match.ratio
|
function_count += 1
|
||||||
total_effective_accuracy += match.effective_ratio
|
total_accuracy += match.ratio
|
||||||
|
total_effective_accuracy += match.effective_ratio
|
||||||
|
|
||||||
# If html, record the diffs to an HTML file
|
# If html, record the diffs to an HTML file
|
||||||
if args.html is not None:
|
if args.html is not None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user