feat: Reuse enums instead of recreating them every time

This commit is contained in:
jonschz 2024-06-16 11:24:53 +02:00
parent c8dc77cbf4
commit 56b8c96d6a

View File

@ -1,5 +1,5 @@
import logging
from typing import Any
from typing import Any, Callable, TypeVar
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
@ -29,6 +29,7 @@
CategoryPath,
DataType,
DataTypeConflictHandler,
Enum,
EnumDataType,
StructureDataType,
StructureInternal,
@ -47,7 +48,9 @@ def __init__(self, api: FlatProgramAPI, extraction: PdbFunctionExtractor):
self.extraction = extraction
# tracks the structs/classes we have already started to import, otherwise we run into infinite recursion
self.handled_structs: set[str] = set()
self.struct_call_stack: list[str] = []
# tracks the enums we have already handled for the sake of efficiency
self.handled_enums: dict[str, Enum] = {}
@property
def types(self):
@ -166,9 +169,13 @@ def _import_enum(self, type_pdb: dict[str, Any]) -> DataType:
field_list = self.extraction.compare.cv.types.keys.get(type_pdb["field_type"])
assert field_list is not None, f"Failed to find field list for enum {type_pdb}"
result = EnumDataType(
CategoryPath("/imported"), type_pdb["name"], underlying_type.getLength()
result = self._get_or_create_enum_data_type(
type_pdb["name"], underlying_type.getLength()
)
# clear existing variant if there are any
for existing_variant in result.getNames():
result.remove(existing_variant)
variants: list[dict[str, Any]] = field_list["variants"]
for variant in variants:
result.add(variant["name"], variant["value"])
@ -259,30 +266,73 @@ def _get_or_create_namespace(self, class_name_with_namespace: str):
parent_namespace = create_ghidra_namespace(self.api, colon_split)
self.api.createClass(parent_namespace, class_name)
def _get_or_create_enum_data_type(
self, enum_type_name: str, enum_type_size: int
) -> Enum:
if (known_enum := self.handled_enums.get(enum_type_name, None)) is not None:
return known_enum
result = self._get_or_create_data_type(
enum_type_name,
"enum",
Enum,
lambda: EnumDataType(
CategoryPath("/imported"), enum_type_name, enum_type_size
),
)
self.handled_enums[enum_type_name] = result
return result
def _get_or_create_struct_data_type(
self, class_name_with_namespace: str, class_size: int
) -> StructureInternal:
return self._get_or_create_data_type(
class_name_with_namespace,
"class/struct",
StructureInternal,
lambda: StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
),
)
T = TypeVar("T", bound=DataType)
def _get_or_create_data_type(
self,
type_name: str,
readable_name_of_type_category: str,
expected_type: type[T],
new_instance_callback: Callable[[], T],
) -> T:
"""
Checks if a data type provided under the given name exists in Ghidra.
Creates one using `new_instance_callback` if there is not.
Also verifies the data type.
Note that the return value of `addDataType()` is not the same instance as the input
even if there is no name collision.
"""
try:
data_type = get_ghidra_type(self.api, class_name_with_namespace)
data_type = get_ghidra_type(self.api, type_name)
logger.debug(
"Found existing data type %s under category path %s",
class_name_with_namespace,
"Found existing %s type %s under category path %s",
readable_name_of_type_category,
type_name,
data_type.getCategoryPath(),
)
except TypeNotFoundInGhidraError:
# Create a new struct data type
data_type = StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
)
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER)
.addDataType(new_instance_callback(), DataTypeConflictHandler.KEEP_HANDLER)
)
logger.info(
"Created new %s data type %s", readable_name_of_type_category, type_name
)
logger.info("Created new data type %s", class_name_with_namespace)
assert isinstance(
data_type, StructureInternal
), f"Found type sharing its name with a class/struct, but is not a struct: {class_name_with_namespace}"
data_type, expected_type
), f"Found existing type named {type_name} that is not a {readable_name_of_type_category}"
return data_type
def _delete_and_recreate_struct_data_type(