Implement struct imports [skip ci]

- This code is still in dire need of refactoring and tests
- There are only single-digit issues left, and 2600 functions can be imported
- The biggest remaining error is mismatched stacks
This commit is contained in:
jonschz 2024-05-26 21:58:19 +02:00
parent c6817527d1
commit a8f6e72b97
8 changed files with 448 additions and 59 deletions

View File

@ -15,6 +15,7 @@
import importlib import importlib
from dataclasses import dataclass, field from dataclasses import dataclass, field
import logging.handlers
import sys import sys
import logging import logging
from pathlib import Path from pathlib import Path
@ -44,18 +45,25 @@ def reload_module(module: str):
def setup_logging(): def setup_logging():
logging.basicConfig( logging.root.handlers.clear()
format="%(levelname)-8s %(message)s", formatter = logging.Formatter("%(levelname)-8s %(message)s")
stream=sys.stdout, # formatter = logging.Formatter("%(name)s %(levelname)-8s %(message)s") # use this to identify loggers
level=logging.INFO, stdout_handler = logging.StreamHandler(sys.stdout)
force=True, stdout_handler.setFormatter(formatter)
file_handler = logging.FileHandler(
Path(__file__).absolute().parent.joinpath("import.log"), mode="w"
) )
file_handler.setFormatter(formatter)
logging.root.setLevel(GLOBALS.loglevel)
logging.root.addHandler(stdout_handler)
logging.root.addHandler(file_handler)
logger.info("Starting...") logger.info("Starting...")
@dataclass @dataclass
class Globals: class Globals:
verbose: bool verbose: bool
loglevel: int
running_from_ghidra: bool = False running_from_ghidra: bool = False
make_changes: bool = False make_changes: bool = False
prompt_before_changes: bool = True prompt_before_changes: bool = True
@ -64,7 +72,11 @@ class Globals:
# hard-coded settings that we don't want to prompt in Ghidra every time # hard-coded settings that we don't want to prompt in Ghidra every time
GLOBALS = Globals(verbose=False) GLOBALS = Globals(
verbose=False,
# loglevel=logging.INFO,
loglevel=logging.DEBUG,
)
# Disable spurious warnings in vscode / pylance # Disable spurious warnings in vscode / pylance
@ -111,14 +123,19 @@ def add_python_path(path: str):
# We need to quote the types here because they might not exist when running without Ghidra # We need to quote the types here because they might not exist when running without Ghidra
def migrate_function_to_ghidra( def migrate_function_to_ghidra(
api: "FlatProgramAPI", match_info: "MatchInfo", signature: "FunctionSignature" api: "FlatProgramAPI",
match_info: "MatchInfo",
signature: "FunctionSignature",
type_importer: "PdbTypeImporter",
): ):
hex_original_address = f"{match_info.orig_addr:x}" hex_original_address = f"{match_info.orig_addr:x}"
# Find the Ghidra function at that address # Find the Ghidra function at that address
ghidra_address = getAddressFactory().getAddress(hex_original_address) ghidra_address = getAddressFactory().getAddress(hex_original_address)
typed_pdb_function = PdbFunctionWithGhidraObjects(api, match_info, signature) typed_pdb_function = PdbFunctionWithGhidraObjects(
api, match_info, signature, type_importer
)
if not GLOBALS.make_changes: if not GLOBALS.make_changes:
return return
@ -170,19 +187,20 @@ def migrate_function_to_ghidra(
askChoice("Continue", "Click 'OK' to continue", ["OK"], "OK") askChoice("Continue", "Click 'OK' to continue", ["OK"], "OK")
def process_functions(isle_compare: "IsleCompare"): def process_functions(extraction: "PdbExtractionForGhidraMigration"):
# try to acquire matched functions func_signatures = extraction.get_function_list()
migration = PdbExtractionForGhidraMigration(isle_compare)
func_signatures = migration.get_function_list()
if not GLOBALS.running_from_ghidra: if not GLOBALS.running_from_ghidra:
logger.info("Completed the dry run outside Ghidra.") logger.info("Completed the dry run outside Ghidra.")
return return
fpapi = FlatProgramAPI(currentProgram()) api = FlatProgramAPI(currentProgram())
# TODO: Implement a "no changes" mode
type_importer = PdbTypeImporter(api, extraction)
for match_info, signature in func_signatures: for match_info, signature in func_signatures:
try: try:
migrate_function_to_ghidra(fpapi, match_info, signature) migrate_function_to_ghidra(api, match_info, signature, type_importer)
GLOBALS.statistics.successes += 1 GLOBALS.statistics.successes += 1
except Lego1Exception as e: except Lego1Exception as e:
log_and_track_failure(e) log_and_track_failure(e)
@ -216,8 +234,11 @@ def main():
pdb_path = build_path.joinpath("LEGO1.pdb") pdb_path = build_path.joinpath("LEGO1.pdb")
if not GLOBALS.verbose: if not GLOBALS.verbose:
logging.getLogger("isledecomp.compare.db").setLevel(logging.CRITICAL) logging.getLogger("isledecomp.bin").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.lines").setLevel(logging.CRITICAL) logging.getLogger("isledecomp.compare.core").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.db").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.lines").setLevel(logging.WARNING)
logging.getLogger("isledecomp.cvdump.symbols").setLevel(logging.WARNING)
logger.info("Starting comparison") logger.info("Starting comparison")
with Bin(str(origfile_path), find_str=True) as origfile, Bin( with Bin(str(origfile_path), find_str=True) as origfile, Bin(
@ -227,8 +248,10 @@ def main():
logger.info("Comparison complete.") logger.info("Comparison complete.")
# try to acquire matched functions
migration = PdbExtractionForGhidraMigration(isle_compare)
try: try:
process_functions(isle_compare) process_functions(migration)
finally: finally:
if GLOBALS.running_from_ghidra: if GLOBALS.running_from_ghidra:
GLOBALS.statistics.log() GLOBALS.statistics.log()
@ -265,8 +288,13 @@ def main():
) )
if GLOBALS.running_from_ghidra: if GLOBALS.running_from_ghidra:
reload_module("lego_util.pdb_to_ghidra") reload_module("lego_util.ghidra_helper")
from lego_util.pdb_to_ghidra import PdbFunctionWithGhidraObjects
reload_module("lego_util.function_importer")
from lego_util.function_importer import PdbFunctionWithGhidraObjects
reload_module("lego_util.type_importer")
from lego_util.type_importer import PdbTypeImporter
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,5 +1,13 @@
class Lego1Exception(Exception): class Lego1Exception(Exception):
pass """
Our own base class for exceptions.
Makes it easier to distinguish expected and unexpected errors.
"""
class TypeNotFoundError(Lego1Exception):
def __str__(self):
return f"Type not found in PDB: {self.args[0]}"
class TypeNotFoundInGhidraError(Lego1Exception): class TypeNotFoundInGhidraError(Lego1Exception):
@ -7,6 +15,11 @@ def __str__(self):
return f"Type not found in Ghidra: {self.args[0]}" return f"Type not found in Ghidra: {self.args[0]}"
class TypeNotImplementedError(Lego1Exception):
def __str__(self):
return f"Import not implemented for type: {self.args[0]}"
class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception): class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception):
def __init__(self, namespaceHierachy: list[str]): def __init__(self, namespaceHierachy: list[str]):
super().__init__(namespaceHierachy) super().__init__(namespaceHierachy)

View File

@ -20,9 +20,11 @@
) )
from lego_util.ghidra_helper import ( from lego_util.ghidra_helper import (
get_ghidra_namespace, get_ghidra_namespace,
get_ghidra_type, sanitize_class_name,
) )
from lego_util.exceptions import StackOffsetMismatchError from lego_util.exceptions import StackOffsetMismatchError
from lego_util.type_importer import PdbTypeImporter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,26 +35,36 @@ class PdbFunctionWithGhidraObjects:
def __init__( def __init__(
self, self,
fpapi: FlatProgramAPI, api: FlatProgramAPI,
match_info: MatchInfo, match_info: MatchInfo,
signature: FunctionSignature, signature: FunctionSignature,
type_importer: "PdbTypeImporter",
): ):
self.api = fpapi self.api = api
self.match_info = match_info self.match_info = match_info
self.signature = signature self.signature = signature
self.type_importer = type_importer
if signature.class_type is not None:
# Import the base class so the namespace exists
self.type_importer.pdb_to_ghidra_type(signature.class_type)
assert match_info.name is not None assert match_info.name is not None
colon_split = match_info.name.split("::")
colon_split = sanitize_class_name(match_info.name).split("::")
self.name = colon_split.pop() self.name = colon_split.pop()
namespace_hierachy = colon_split namespace_hierachy = colon_split
self.namespace = get_ghidra_namespace(fpapi, namespace_hierachy) self.namespace = get_ghidra_namespace(api, namespace_hierachy)
self.return_type = get_ghidra_type(fpapi, signature.return_type) self.return_type = type_importer.pdb_to_ghidra_type(
signature.return_type
)
self.arguments = [ self.arguments = [
ParameterImpl( ParameterImpl(
f"param{index}", f"param{index}",
get_ghidra_type(fpapi, type_name), # get_ghidra_type(api, type_name),
fpapi.getCurrentProgram(), type_importer.pdb_to_ghidra_type(type_name),
api.getCurrentProgram(),
) )
for (index, type_name) in enumerate(signature.arglist) for (index, type_name) in enumerate(signature.arglist)
] ]
@ -200,7 +212,13 @@ def _rename_stack_parameter(self, param: Parameter):
f"Could not find a matching symbol at offset {param.getStackOffset()} in {self.get_full_name()}" f"Could not find a matching symbol at offset {param.getStackOffset()} in {self.get_full_name()}"
) )
if param.getDataType() != get_ghidra_type(self.api, match.data_type): if match.data_type == "T_NOTYPE(0000)":
logger.warning("Skipping stack parameter of type NOTYPE")
return
if param.getDataType() != self.type_importer.pdb_to_ghidra_type(
match.data_type
):
logger.error( logger.error(
"Type mismatch for parameter: %s in Ghidra, %s in PDB", param, match "Type mismatch for parameter: %s in Ghidra, %s in PDB", param, match
) )

View File

@ -16,6 +16,8 @@
from ghidra.program.model.data import DataType from ghidra.program.model.data import DataType
from ghidra.program.model.symbol import Namespace from ghidra.program.model.symbol import Namespace
logger = logging.getLogger(__name__)
def get_ghidra_type(api: FlatProgramAPI, type_name: str): def get_ghidra_type(api: FlatProgramAPI, type_name: str):
""" """
@ -44,14 +46,21 @@ def get_ghidra_type(api: FlatProgramAPI, type_name: str):
raise MultipleTypesFoundInGhidraError(type_name, result) raise MultipleTypesFoundInGhidraError(type_name, result)
def add_pointer_type(api: FlatProgramAPI, pointee: DataType): def add_pointer_type(api: FlatProgramAPI, pointee: DataType) -> DataType:
data_type = PointerDataType(pointee) new_data_type = PointerDataType(pointee)
data_type.setCategoryPath(pointee.getCategoryPath()) new_data_type.setCategoryPath(pointee.getCategoryPath())
api.getCurrentProgram().getDataTypeManager().addDataType( result_data_type = (
data_type, DataTypeConflictHandler.KEEP_HANDLER api.getCurrentProgram()
.getDataTypeManager()
.addDataType(new_data_type, DataTypeConflictHandler.KEEP_HANDLER)
) )
logging.info("Created new pointer type %s", data_type) if result_data_type is not new_data_type:
return data_type logger.debug(
"New pointer replaced by existing one. Fresh pointer: %s (class: %s)",
result_data_type,
result_data_type.__class__,
)
return result_data_type
def get_ghidra_namespace( def get_ghidra_namespace(
@ -63,3 +72,38 @@ def get_ghidra_namespace(
if namespace is None: if namespace is None:
raise ClassOrNamespaceNotFoundInGhidraError(namespace_hierachy) raise ClassOrNamespaceNotFoundInGhidraError(namespace_hierachy)
return namespace return namespace
def create_ghidra_namespace(
api: FlatProgramAPI, namespace_hierachy: list[str]
) -> Namespace:
namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part)
if namespace is None:
namespace = api.createNamespace(namespace, part)
return namespace
def sanitize_class_name(name: str) -> str:
"""
Takes a full class or function name and replaces characters not accepted by Ghidra.
Applies mostly to templates.
"""
if "<" in name:
new_class_name = (
"_template_" +
name
.replace("<", "[")
.replace(">", "]")
.replace("*", "#")
.replace(" ", "")
)
logger.warning(
"Changing possible template class name from '%s' to '%s'",
name,
new_class_name,
)
return new_class_name
return name

View File

@ -8,13 +8,11 @@
from isledecomp.compare import Compare as IsleCompare from isledecomp.compare import Compare as IsleCompare
from isledecomp.compare.db import MatchInfo from isledecomp.compare.db import MatchInfo
from lego_util.exceptions import TypeNotFoundError
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
class TypeNotFoundError(Exception):
pass
@dataclass @dataclass
class CppStackOrRegisterSymbol: class CppStackOrRegisterSymbol:
name: str name: str
@ -38,7 +36,7 @@ class FunctionSignature:
call_type: str call_type: str
arglist: list[str] arglist: list[str]
return_type: str return_type: str
class_type: Optional[dict[str, Any]] class_type: Optional[str]
stack_symbols: list[CppStackOrRegisterSymbol] stack_symbols: list[CppStackOrRegisterSymbol]
@ -46,7 +44,7 @@ class PdbExtractionForGhidraMigration:
def __init__(self, compare: IsleCompare): def __init__(self, compare: IsleCompare):
self.compare = compare self.compare = compare
_scalar_type_regex = re.compile(r"t_(?P<typename>\w+)(?:\((?P<type_id>\d+)\))?") scalar_type_regex = re.compile(r"t_(?P<typename>\w+)(?:\((?P<type_id>\d+)\))?")
_scalar_type_map = { _scalar_type_map = {
"rchar": "char", "rchar": "char",
@ -62,10 +60,11 @@ def __init__(self, compare: IsleCompare):
"STD Near": "__stdcall", "STD Near": "__stdcall",
} }
def scalar_type_to_cpp(self, scalar_type: str) -> str: @classmethod
def scalar_type_to_cpp(cls, scalar_type: str) -> str:
if scalar_type.startswith("32p"): if scalar_type.startswith("32p"):
return f"{self.scalar_type_to_cpp(scalar_type[3:])} *" return f"{cls.scalar_type_to_cpp(scalar_type[3:])} *"
return self._scalar_type_map.get(scalar_type, scalar_type) return cls._scalar_type_map.get(scalar_type, scalar_type)
def lookup_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]: def lookup_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]:
return ( return (
@ -74,11 +73,12 @@ def lookup_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]:
else self.compare.cv.types.keys.get(type_name.lower()) else self.compare.cv.types.keys.get(type_name.lower())
) )
# TODO: This is mostly legacy code now, we may be able to remove it
def type_to_cpp_type_name(self, type_name: str) -> str: def type_to_cpp_type_name(self, type_name: str) -> str:
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
type_lower = type_name.lower() type_lower = type_name.lower()
if type_lower.startswith("t_"): if type_lower.startswith("t_"):
if (match := self._scalar_type_regex.match(type_lower)) is None: if (match := self.scalar_type_regex.match(type_lower)) is None:
raise TypeNotFoundError(f"Type has unexpected format: {type_name}") raise TypeNotFoundError(f"Type has unexpected format: {type_name}")
return self.scalar_type_to_cpp(match.group("typename")) return self.scalar_type_to_cpp(match.group("typename"))
@ -131,16 +131,12 @@ def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
) )
return None return None
return_type = self.type_to_cpp_type_name(function_type["return_type"]) class_type = function_type.get("class_type")
class_type = self.lookup_type(function_type.get("class_type"))
arg_list_type = self.lookup_type(function_type.get("arg_list_type")) arg_list_type = self.lookup_type(function_type.get("arg_list_type"))
assert arg_list_type is not None assert arg_list_type is not None
arg_list_pdb_types = arg_list_type.get("args", []) arg_list_pdb_types = arg_list_type.get("args", [])
assert arg_list_type["argcount"] == len(arg_list_pdb_types) assert arg_list_type["argcount"] == len(arg_list_pdb_types)
arglist = [
self.type_to_cpp_type_name(argtype) for argtype in arg_list_pdb_types
]
stack_symbols: list[CppStackOrRegisterSymbol] = [] stack_symbols: list[CppStackOrRegisterSymbol] = []
for symbol in fn.stack_symbols: for symbol in fn.stack_symbols:
@ -157,7 +153,7 @@ def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
stack_symbols.append( stack_symbols.append(
CppStackSymbol( CppStackSymbol(
symbol.name, symbol.name,
self.type_to_cpp_type_name(symbol.data_type), symbol.data_type,
stack_offset, stack_offset,
) )
) )
@ -166,8 +162,8 @@ def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
return FunctionSignature( return FunctionSignature(
call_type=call_type, call_type=call_type,
arglist=arglist, arglist=arg_list_pdb_types,
return_type=return_type, return_type=function_type["return_type"],
class_type=class_type, class_type=class_type,
stack_symbols=stack_symbols, stack_symbols=stack_symbols,
) )
@ -175,7 +171,7 @@ def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
def get_function_list(self) -> list[tuple[MatchInfo, FunctionSignature]]: def get_function_list(self) -> list[tuple[MatchInfo, FunctionSignature]]:
handled = ( handled = (
self.handle_matched_function(match) self.handle_matched_function(match)
for match in self.compare._db.get_matches_by_type(SymbolType.FUNCTION) for match in self.compare.db.get_matches_by_type(SymbolType.FUNCTION)
) )
return [signature for signature in handled if signature is not None] return [signature for signature in handled if signature is not None]
@ -183,7 +179,7 @@ def handle_matched_function(
self, match_info: MatchInfo self, match_info: MatchInfo
) -> Optional[tuple[MatchInfo, FunctionSignature]]: ) -> Optional[tuple[MatchInfo, FunctionSignature]]:
assert match_info.orig_addr is not None assert match_info.orig_addr is not None
match_options = self.compare._db.get_match_options(match_info.orig_addr) match_options = self.compare.db.get_match_options(match_info.orig_addr)
assert match_options is not None assert match_options is not None
if match_options.get("skip", False) or match_options.get("stub", False): if match_options.get("skip", False) or match_options.get("stub", False):
return None return None

View File

@ -0,0 +1,260 @@
from typing import Any
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
from lego_util.exceptions import (
ClassOrNamespaceNotFoundInGhidraError,
TypeNotFoundError,
TypeNotFoundInGhidraError,
TypeNotImplementedError,
)
from lego_util.ghidra_helper import (
add_pointer_type,
create_ghidra_namespace,
get_ghidra_namespace,
get_ghidra_type,
sanitize_class_name,
)
from lego_util.pdb_extraction import PdbExtractionForGhidraMigration
from lego_util.function_importer import logger
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.data import (
ArrayDataType,
CategoryPath,
DataType,
DataTypeConflictHandler,
StructureDataType,
StructureInternal,
)
from ghidra.util.task import ConsoleTaskMonitor
class PdbTypeImporter:
def __init__(
self, api: FlatProgramAPI, extraction: PdbExtractionForGhidraMigration
):
self.api = api
self.extraction = extraction
self.handled_structs: set[str] = (
set()
) # tracks the types we have already imported, otherwise we keep overwriting finished work
@property
def types(self):
return self.extraction.compare.cv.types
def _import_class_or_struct(self, type_in_pdb: dict[str, Any]) -> DataType:
field_list_type = type_in_pdb.get("field_list_type")
if field_list_type is None:
raise TypeNotFoundError(
f"Found a referenced missing type that is not a class or lacks a field_list_type: {type_in_pdb}"
)
field_list = self.types.keys[field_list_type.lower()]
logger.debug("Found class: %s", type_in_pdb)
class_size: int = type_in_pdb["size"]
class_name_with_namespace: str = sanitize_class_name(type_in_pdb["name"])
if class_name_with_namespace in self.handled_structs:
logger.debug(
"Class has been handled or is being handled: %s",
class_name_with_namespace,
)
return get_ghidra_type(self.api, class_name_with_namespace)
# Add as soon as we start to avoid infinite recursion
self.handled_structs.add(class_name_with_namespace)
# Create class / namespace if it does not exist
colon_split = class_name_with_namespace.split("::")
class_name = colon_split[-1]
try:
get_ghidra_namespace(self.api, colon_split)
logger.debug("Found existing class/namespace %s", class_name_with_namespace)
except ClassOrNamespaceNotFoundInGhidraError:
logger.info("Creating class/namespace %s", class_name_with_namespace)
class_name = colon_split.pop()
parent_namespace = create_ghidra_namespace(self.api, colon_split)
self.api.createClass(parent_namespace, class_name)
# Create type if it does not exist
try:
data_type = get_ghidra_type(self.api, class_name_with_namespace)
logger.debug(
"Found existing data type %s under category path %s",
class_name_with_namespace,
data_type.getCategoryPath(),
)
except TypeNotFoundInGhidraError:
# Create a new struct data type
data_type = StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
)
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
logger.info("Created new data type %s", class_name_with_namespace)
assert isinstance(
data_type, StructureInternal
), f"Found type sharing its name with a class/struct, but is not a struct: {class_name_with_namespace}"
if (old_size := data_type.getLength()) != class_size:
logger.warning(
"Existing class %s had incorrect size %d. Setting to %d...",
class_name_with_namespace,
old_size,
class_size,
)
# TODO: Implement comparison to expected layout
# We might not need that, but it helps to not break stuff if we run into an error
logger.info("Adding class data type %s", class_name_with_namespace)
logger.debug("Class information: %s", type_in_pdb)
data_type.deleteAll()
data_type.growStructure(class_size)
# this case happened for IUnknown, which linked to an (incorrect) existing library, and some other types as well.
# Unfortunately, we don't get proper error handling for read-only types
if data_type.getLength() != class_size:
logger.warning(
"Failed to modify data type %s. Please remove the existing one by hand and try again.",
class_name_with_namespace,
)
assert (
self.api.getCurrentProgram()
.getDataTypeManager()
.remove(data_type, ConsoleTaskMonitor())
), f"Failed to delete and re-create data type {class_name_with_namespace}"
data_type = StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
)
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
assert isinstance(data_type, StructureInternal) # for type checking
# Delete existing components - likely not needed when using replaceAtOffset exhaustively
# for component in data_type.getComponents():
# data_type.deleteAtOffset(component.getOffset())
# can be missing when no new fields are declared
components: list[dict[str, Any]] = field_list.get("members") or []
super_type = field_list.get("super")
if super_type is not None:
components.insert(0, {"type": super_type, "offset": 0, "name": "base"})
for component in components:
ghidra_type = self.pdb_to_ghidra_type(component["type"])
logger.debug("Adding component to class: %s", component)
# XXX: temporary exception handling to get better logs
try:
data_type.replaceAtOffset(
component["offset"], ghidra_type, -1, component["name"], None
)
except Exception as e:
raise Exception(f"Error importing {type_in_pdb}") from e
logger.info("Finished importing class %s", class_name_with_namespace)
return data_type
def pdb_to_ghidra_type(self, type_index: str) -> DataType:
"""
Experimental new type converter to get rid of the intermediate step PDB -> C++ -> Ghidra
@param type_index Either a scalar type like `T_INT4(...)` or a PDB reference like `0x10ba`
"""
# scalar type
type_index_lower = type_index.lower()
if type_index_lower.startswith("t_"):
if (
match := self.extraction.scalar_type_regex.match(type_index_lower)
) is None:
raise TypeNotFoundError(f"Type has unexpected format: {type_index}")
scalar_cpp_type = self.extraction.scalar_type_to_cpp(
match.group("typename")
)
return get_ghidra_type(self.api, scalar_cpp_type)
try:
type_pdb = self.extraction.compare.cv.types.keys[type_index_lower]
except KeyError as e:
raise TypeNotFoundError(
f"Failed to find referenced type {type_index_lower}"
) from e
type_category = type_pdb["type"]
if type_category == "LF_POINTER":
return add_pointer_type(
self.api, self.pdb_to_ghidra_type(type_pdb["element_type"])
)
if type_category in ["LF_CLASS", "LF_STRUCTURE"]:
if type_pdb.get("is_forward_ref", False):
logger.debug(
"Following forward reference from %s to %s",
type_index,
type_pdb["udt"],
)
return self.pdb_to_ghidra_type(type_pdb["udt"])
return self._import_class_or_struct(type_pdb)
if type_category == "LF_ARRAY":
# TODO: See how well this interacts with arrays in functions
# We treat arrays like pointers because we don't distinguish them in Ghidra
logger.debug("Encountered array: %s", type_pdb)
inner_type = self.pdb_to_ghidra_type(type_pdb["array_type"])
# TODO: Insert size / consider switching to pointer if not applicable
return ArrayDataType(inner_type, 0, 0)
if type_category == "LF_ENUM":
logger.warning(
"Replacing enum by underlying type (not implemented yet): %s", type_pdb
)
return self.pdb_to_ghidra_type(type_pdb["underlying_type"])
if type_category == "LF_MODIFIER":
logger.warning("Not sure what a modifier is: %s", type_pdb)
# not sure what this actually is, take what it references
return self.pdb_to_ghidra_type(type_pdb["modifies"])
if type_category == "LF_PROCEDURE":
logger.info(
"Function-valued argument or return type will be replaced by void pointer: %s",
type_pdb,
)
return get_ghidra_type(self.api, "void")
if type_category == "LF_UNION":
if type_pdb.get("is_forward_ref", False):
return self.pdb_to_ghidra_type(type_pdb["udt"])
try:
logger.debug("Dereferencing union %s", type_pdb)
union_type = get_ghidra_type(self.api, type_pdb["name"])
assert (
union_type.getLength() == type_pdb["size"]
), f"Wrong size of existing union type '{type_pdb['name']}': expected {type_pdb["size"]}, got {union_type.getLength()}"
return union_type
except TypeNotFoundInGhidraError as e:
raise TypeNotImplementedError(
f"Writing union types is not supported. Please add by hand: {type_pdb}"
) from e
raise TypeNotImplementedError(type_pdb)

View File

@ -88,6 +88,11 @@ def __init__(
self._match_thunks() self._match_thunks()
self._find_vtordisp() self._find_vtordisp()
@property
def db(self):
"""Newer code needs to access this field, legacy code uses _db"""
return self._db
def _load_cvdump(self): def _load_cvdump(self):
logger.info("Parsing %s ...", self.pdb_file) logger.info("Parsing %s ...", self.pdb_file)
self.cv = ( self.cv = (
@ -161,7 +166,10 @@ def _load_cvdump(self):
addr, sym.node_type, sym.name(), sym.decorated_name, sym.size() addr, sym.node_type, sym.name(), sym.decorated_name, sym.size()
) )
for (section, offset), (filename, line_no) in self.cvdump_analysis.verified_lines.items(): for (section, offset), (
filename,
line_no,
) in self.cvdump_analysis.verified_lines.items():
addr = self.recomp_bin.get_abs_addr(section, offset) addr = self.recomp_bin.get_abs_addr(section, offset)
self._lines_db.add_line(filename, line_no, addr) self._lines_db.add_line(filename, line_no, addr)

View File

@ -216,6 +216,9 @@ class CvdumpTypesParser:
re.compile(r"^\s*enum name = (?P<name>.+)$"), re.compile(r"^\s*enum name = (?P<name>.+)$"),
re.compile(r"^\s*UDT\((?P<udt>0x\w+)\)$"), re.compile(r"^\s*UDT\((?P<udt>0x\w+)\)$"),
] ]
LF_UNION_LINE = re.compile(
r".*field list type (?P<field_type>0x\w+),.*Size = (?P<size>\d+)\s*,class name = (?P<name>(?:[^,]|,\S)+),\s.*UDT\((?P<udt>0x\w+)\)"
)
MODES_OF_INTEREST = { MODES_OF_INTEREST = {
"LF_ARRAY", "LF_ARRAY",
@ -228,6 +231,7 @@ class CvdumpTypesParser:
"LF_ARGLIST", "LF_ARGLIST",
"LF_MFUNCTION", "LF_MFUNCTION",
"LF_PROCEDURE", "LF_PROCEDURE",
"LF_UNION",
} }
def __init__(self) -> None: def __init__(self) -> None:
@ -298,7 +302,9 @@ def _mock_array_members(self, type_obj: Dict) -> List[FieldListItem]:
raise CvdumpIntegrityError("No array element type") raise CvdumpIntegrityError("No array element type")
array_element_size = self.get(array_type).size array_element_size = self.get(array_type).size
assert array_element_size is not None, "Encountered an array whose type has no size" assert (
array_element_size is not None
), "Encountered an array whose type has no size"
n_elements = type_obj["size"] // array_element_size n_elements = type_obj["size"] // array_element_size
@ -399,7 +405,9 @@ def get_scalars_gapless(self, type_key: str) -> List[ScalarType]:
obj = self.get(type_key) obj = self.get(type_key)
total_size = obj.size total_size = obj.size
assert total_size is not None, "Called get_scalar_gapless() on a type without size" assert (
total_size is not None
), "Called get_scalar_gapless() on a type without size"
scalars = self.get_scalars(type_key) scalars = self.get_scalars(type_key)
@ -506,6 +514,9 @@ def read_line(self, line: str):
elif self.mode == "LF_ENUM": elif self.mode == "LF_ENUM":
self.read_enum_line(line) self.read_enum_line(line)
elif self.mode == "LF_UNION":
self.read_union_line(line)
else: else:
# Check for exhaustiveness # Check for exhaustiveness
logger.error("Unhandled data in mode: %s", self.mode) logger.error("Unhandled data in mode: %s", self.mode)
@ -610,3 +621,14 @@ def parse_enum_attribute(self, attribute: str) -> dict[str, Any]:
return {"is_forward_ref": True} return {"is_forward_ref": True}
logger.error("Unknown attribute in enum: %s", attribute) logger.error("Unknown attribute in enum: %s", attribute)
return {} return {}
def read_union_line(self, line: str):
"""This is a rather barebones handler, only parsing the size"""
if (match := self.LF_UNION_LINE.match(line)) is None:
raise AssertionError(f"Unhandled in union: {line}")
self._set("name", match.group("name"))
if match.group("field_type") == "0x0000":
self._set("is_forward_ref", True)
self._set("size", int(match.group("size")))
self._set("udt", normalize_type_id(match.group("udt")))