Find imports and thunk functions

This commit is contained in:
disinvite 2024-01-14 12:22:14 -05:00
parent 5abe2d615b
commit e59afb4825
3 changed files with 138 additions and 3 deletions

View File

@ -97,6 +97,8 @@ def __init__(self, filename: str, find_str: bool = False) -> None:
self.find_str = find_str
self._potential_strings = {}
self._relocated_addrs = set()
self.imports = []
self.thunks = []
def __enter__(self):
logger.debug("Bin %s Enter", self.filename)
@ -132,6 +134,8 @@ def __enter__(self):
sect.virtual_address += self.imagebase
self._populate_relocations()
self._populate_imports()
self._populate_thunks()
# This is a (semi) expensive lookup that is not necessary in every case.
# We can find strings in the original if we have coverage using STRING markers.
@ -238,6 +242,78 @@ def _populate_relocations(self):
(relocated_addr,) = struct.unpack("<I", self.read(addr, 4))
self._relocated_addrs.add(relocated_addr)
def _populate_imports(self):
    """Parse .idata to find imported DLLs and their functions.

    Reads the table of import descriptors beginning at the address returned
    by ``get_section_offset_by_name(".idata")``, then walks each DLL's
    Import Lookup Table (ILT) and Import Address Table (IAT) in lockstep.

    The result replaces ``self.imports`` with a list of
    ``(dll_name, symbol_name, iat_entry_addr)`` tuples, where
    ``iat_entry_addr`` is the absolute virtual address of the IAT slot.

    Imports by ordinal carry no name in the image, so they are reported
    with the synthetic name ``Ordinal_<n>``.
    """
    # PE32 lookup entries: MSB set means "import by ordinal" and the low
    # 16 bits hold the ordinal number. Otherwise the entry is the RVA of
    # a hint/name record.
    ordinal_flag = 0x80000000
    ordinal_mask = 0xFFFF

    idata_ofs = self.get_section_offset_by_name(".idata")

    def iter_image_import():
        ofs = idata_ofs
        while True:
            # Each descriptor is 5 dwords; an all-zero one ends the table.
            image_import_descriptor = struct.unpack("<5I", self.read(ofs, 20))
            ofs += 20
            if not any(image_import_descriptor):
                break

            (rva_ilt, _, __, dll_name, rva_iat) = image_import_descriptor
            # Convert relative virtual addresses into absolute
            yield (
                self.imagebase + rva_ilt,
                self.imagebase + dll_name,
                self.imagebase + rva_iat,
            )

    # Materialize the descriptors first so the second pass does not
    # interleave its reads with the descriptor table walk.
    image_import_descriptors = list(iter_image_import())

    def iter_imports():
        # ILT = Import Lookup Table
        # IAT = Import Address Table
        # ILT gives us the symbol name of the import.
        # IAT gives the address. The compiler generated a thunk function
        # that jumps to the value of this address.
        for start_ilt, dll_addr, start_iat in image_import_descriptors:
            dll_name = self.read_string(dll_addr).decode("ascii")
            ofs_ilt = start_ilt
            # Address of "__imp__*" symbols.
            ofs_iat = start_iat
            while True:
                (lookup_addr,) = struct.unpack("<L", self.read(ofs_ilt, 4))
                (import_addr,) = struct.unpack("<L", self.read(ofs_iat, 4))
                if lookup_addr == 0 or import_addr == 0:
                    break

                if lookup_addr & ordinal_flag:
                    # No hint/name record exists for ordinal imports;
                    # following the entry as an RVA would read garbage.
                    symbol_name = f"Ordinal_{lookup_addr & ordinal_mask}"
                else:
                    # Skip the "Hint" field, 2 bytes
                    name_ofs = lookup_addr + self.imagebase + 2
                    symbol_name = self.read_string(name_ofs).decode("ascii")

                yield (dll_name, symbol_name, ofs_iat)
                ofs_ilt += 4
                ofs_iat += 4

    self.imports = list(iter_imports())
def _populate_thunks(self):
    """Search .text for import thunk functions and record them.

    A thunk is recognised as the 6-byte sequence ``FF 25 <addr32>`` (an
    indirect jmp) whose 32-bit little-endian operand lies inside .idata.
    Every hit is appended to ``self.thunks`` as
    ``(address_of_jmp, address_in_idata)``.

    The section is walked in 6-byte strides from start offsets 0, 2 and 4,
    so only candidates at even offsets from the section start are examined.
    """
    code_sect = self._get_section_by_name(".text")
    import_sect = self._get_section_by_name(".idata")

    text_base = code_sect.virtual_address
    code = self.read(text_base, code_sect.size_of_raw_data)

    # Two opcode bytes followed by the 4-byte absolute operand.
    jmp_indirect = struct.Struct("<2BL")
    stride = jmp_indirect.size
    # Last position at which a complete instruction still fits.
    last_start = len(code) - stride

    for shift in (0, 2, 4):
        for pos in range(shift, last_start + 1, stride):
            opcode, modrm, target = jmp_indirect.unpack_from(code, pos)
            if (opcode, modrm) != (0xFF, 0x25):
                continue
            if not import_sect.contains_vaddr(target):
                continue
            # Record the address of the jmp instruction and the
            # destination in .idata
            self.thunks.append((text_base + pos, target))
def _set_section_for_vaddr(self, vaddr: int):
if self.last_section is not None and self.last_section.contains_vaddr(vaddr):
return
@ -319,6 +395,18 @@ def is_valid_vaddr(self, vaddr: int) -> bool:
return section is not None
def read_string(self, offset: int, chunk_size: int = 1000) -> Optional[bytes]:
    """Return the bytes at ``offset`` up to, but excluding, the first zero byte.

    At most ``chunk_size`` bytes are requested from ``self.read``. If that
    chunk holds no terminator, the whole chunk is returned. Returns None
    when ``self.read`` returns None.
    """
    chunk = self.read(offset, chunk_size)
    if chunk is None:
        return None

    terminator = chunk.find(b"\x00")
    if terminator < 0:
        # No terminator found, just return what we have
        return chunk
    return chunk[:terminator]
def read(self, offset: int, size: int) -> Optional[bytes]:
"""Read (at most) the given number of bytes at the given virtual address.
If we return None, the given address points to uninitialized data."""

View File

@ -48,6 +48,7 @@ def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir):
self._load_cvdump()
self._load_markers()
self._find_original_strings()
self._match_thunks()
def _load_cvdump(self):
logger.info("Parsing %s ...", self.pdb_file)
@ -147,6 +148,46 @@ def _find_original_strings(self):
self._db.match_string(addr, string)
def _match_thunks(self):
    """Pair up imports and their thunk functions between the two binaries.

    Imports are matched on ``(upper-cased DLL name, symbol name)``. For
    every import of the original binary that the recompiled binary also
    has, the two import-table addresses are registered as a POINTER pair.
    If both binaries also contain a thunk that jumps through that entry,
    the thunks are registered as a function pair and excluded from
    comparison.
    """
    orig_byaddr = {
        addr: (dll.upper(), name) for (dll, name, addr) in self.orig_bin.imports
    }
    recomp_byname = {
        (dll.upper(), name): addr for (dll, name, addr) in self.recomp_bin.imports
    }

    # Now: we have the IAT offset in each matched up, so we need to make
    # the connection between the thunk functions.
    # We already have the symbol name we need from the PDB.
    orig_thunks = {
        iat_ofs: func_ofs for (func_ofs, iat_ofs) in self.orig_bin.thunks
    }
    recomp_thunks = {
        iat_ofs: func_ofs for (func_ofs, iat_ofs) in self.recomp_bin.thunks
    }

    # We don't care about imports from recomp not found in orig because:
    # 1. They shouldn't be there
    # 2. They are already identified via cvdump
    for orig, pair in orig_byaddr.items():
        recomp = recomp_byname.get(pair)
        if recomp is None:
            # Import missing from recomp: there is no row to pair with, so
            # skip it instead of issuing an UPDATE keyed on NULL.
            continue

        self._db.set_pair(orig, recomp, SymbolType.POINTER)

        thunk_from_orig = orig_thunks.get(orig)
        thunk_from_recomp = recomp_thunks.get(recomp)
        if thunk_from_orig is None or thunk_from_recomp is None:
            continue

        self._db.set_function_pair(thunk_from_orig, thunk_from_recomp)
        # Don't compare thunk functions for now. The comparison isn't
        # "useful" in the usual sense. We are only looking at the 6
        # bytes of the jmp instruction and not the larger context of
        # where this function is. Also: these will always match 100%
        # because we are searching for a match to register this as a
        # function in the first place.
        self._db.skip_compare(thunk_from_orig)
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
    """Fetch the match record for a single address (reccmp verbose mode).

    Delegates to the database's ``get_one_function`` and returns its
    result unchanged.
    """
    match = self._db.get_one_function(addr)
    return match

View File

@ -134,14 +134,20 @@ def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]:
return cur.fetchall()
def set_function_pair(self, orig: int, recomp: int) -> bool:
"""For lineref match or _entry"""
def set_pair(
    self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None
) -> bool:
    """Attach an original address and compare type to a recomp symbol.

    Updates the rows of `symbols` whose ``recomp_addr`` equals ``recomp``,
    setting ``orig_addr`` to ``orig`` and ``compare_type`` to the value of
    ``compare_type`` (NULL when ``compare_type`` is None).

    Returns True if at least one row was updated.
    """
    compare_value = compare_type.value if compare_type is not None else None
    # Exactly one parameter tuple: the compare type comes from the caller
    # and is no longer hard-coded to FUNCTION.
    cur = self._db.execute(
        "UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?",
        (orig, compare_value, recomp),
    )

    return cur.rowcount > 0
def set_function_pair(self, orig: int, recomp: int) -> bool:
    """For lineref match or _entry.

    Registers ``orig``/``recomp`` as a FUNCTION pair via ``set_pair`` and
    returns its result, so callers can tell whether a row was updated.
    """
    # TODO: Both ways required?
    return self.set_pair(orig, recomp, SymbolType.FUNCTION)
def skip_compare(self, orig: int):