refactor types and make them Python 3.9 compatible

This commit is contained in:
jonschz 2024-05-23 20:26:54 +02:00
parent 17b101d8fc
commit 8678ad72c4
7 changed files with 30 additions and 31 deletions

View File

@ -188,7 +188,9 @@ def process_functions(isle_compare: "IsleCompare"):
cause = e.args[0] cause = e.args[0]
if CancelledException is not None and isinstance(cause, CancelledException): if CancelledException is not None and isinstance(cause, CancelledException):
# let Ghidra's CancelledException pass through # let Ghidra's CancelledException pass through
raise logging.critical("Import aborted by the user.")
return
log_and_track_failure(cause, unexpected=True) log_and_track_failure(cause, unexpected=True)
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e: # pylint: disable=broad-exception-caught
log_and_track_failure(e, unexpected=True) log_and_track_failure(e, unexpected=True)

View File

@ -8,10 +8,10 @@ def __str__(self):
class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception): class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception):
def __init__(self, namespaceHierachy): # type: (list[str]) -> None def __init__(self, namespaceHierachy: list[str]):
super().__init__(namespaceHierachy) super().__init__(namespaceHierachy)
def get_namespace_str(self): # type: () -> str def get_namespace_str(self) -> str:
return "::".join(self.args[0]) return "::".join(self.args[0])
def __str__(self): def __str__(self):

View File

@ -17,7 +17,7 @@
from ghidra.program.model.symbol import Namespace from ghidra.program.model.symbol import Namespace
def get_ghidra_type(api, type_name): # type: (FlatProgramAPI, str) -> DataType def get_ghidra_type(api: FlatProgramAPI, type_name: str):
""" """
Searches for the type named `typeName` in Ghidra. Searches for the type named `typeName` in Ghidra.
@ -44,7 +44,7 @@ def get_ghidra_type(api, type_name): # type: (FlatProgramAPI, str) -> DataType
raise MultipleTypesFoundInGhidraError(type_name, result) raise MultipleTypesFoundInGhidraError(type_name, result)
def add_pointer_type(api, pointee): # type: (FlatProgramAPI, DataType) -> DataType def add_pointer_type(api: FlatProgramAPI, pointee: DataType):
data_type = PointerDataType(pointee) data_type = PointerDataType(pointee)
data_type.setCategoryPath(pointee.getCategoryPath()) data_type.setCategoryPath(pointee.getCategoryPath())
api.getCurrentProgram().getDataTypeManager().addDataType( api.getCurrentProgram().getDataTypeManager().addDataType(
@ -54,9 +54,7 @@ def add_pointer_type(api, pointee): # type: (FlatProgramAPI, DataType) -> DataT
return data_type return data_type
def get_ghidra_namespace( def get_ghidra_namespace(api: FlatProgramAPI, namespace_hierachy: list[str]) -> Namespace:
api, namespace_hierachy
): # type: (FlatProgramAPI, list[str]) -> Namespace
namespace = api.getCurrentProgram().getGlobalNamespace() namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy: for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part) namespace = api.getNamespace(namespace, part)

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
import re import re
from typing import Any from typing import Any, Optional
import logging import logging
from isledecomp.cvdump.symbols import SymbolsEntry from isledecomp.cvdump.symbols import SymbolsEntry
@ -38,7 +38,7 @@ class FunctionSignature:
call_type: str call_type: str
arglist: list[str] arglist: list[str]
return_type: str return_type: str
class_type: dict[str, Any] | None class_type: Optional[dict[str, Any]]
stack_symbols: list[CppStackOrRegisterSymbol] stack_symbols: list[CppStackOrRegisterSymbol]
@ -67,7 +67,7 @@ def scalar_type_to_cpp(self, scalar_type: str) -> str:
return f"{self.scalar_type_to_cpp(scalar_type[3:])} *" return f"{self.scalar_type_to_cpp(scalar_type[3:])} *"
return self._scalar_type_map.get(scalar_type, scalar_type) return self._scalar_type_map.get(scalar_type, scalar_type)
def lookup_type(self, type_name: str | None) -> dict[str, Any] | None: def lookup_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]:
return ( return (
None None
if type_name is None if type_name is None
@ -114,7 +114,7 @@ def type_to_cpp_type_name(self, type_name: str) -> str:
logger.error("Unknown type: %s", dereferenced) logger.error("Unknown type: %s", dereferenced)
return "<<parsing error>>" return "<<parsing error>>"
def get_func_signature(self, fn: "SymbolsEntry") -> FunctionSignature | None: def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
function_type_str = fn.func_type function_type_str = fn.func_type
if function_type_str == "T_NOTYPE(0000)": if function_type_str == "T_NOTYPE(0000)":
logger.debug( logger.debug(
@ -181,7 +181,7 @@ def get_function_list(self) -> list[tuple[MatchInfo, FunctionSignature]]:
def handle_matched_function( def handle_matched_function(
self, match_info: MatchInfo self, match_info: MatchInfo
) -> tuple[MatchInfo, FunctionSignature] | None: ) -> Optional[tuple[MatchInfo, FunctionSignature]]:
assert match_info.orig_addr is not None assert match_info.orig_addr is not None
match_options = self.compare._db.get_match_options(match_info.orig_addr) match_options = self.compare._db.get_match_options(match_info.orig_addr)
assert match_options is not None assert match_options is not None

View File

@ -4,6 +4,7 @@
# pyright: reportMissingModuleSource=false # pyright: reportMissingModuleSource=false
import logging import logging
from typing import Optional
from ghidra.program.model.listing import Function, Parameter from ghidra.program.model.listing import Function, Parameter
from ghidra.program.flatapi import FlatProgramAPI from ghidra.program.flatapi import FlatProgramAPI
@ -33,9 +34,9 @@ class PdbFunctionWithGhidraObjects:
def __init__( def __init__(
self, self,
fpapi: "FlatProgramAPI", fpapi: FlatProgramAPI,
match_info: "MatchInfo", match_info: MatchInfo,
signature: "FunctionSignature", signature: FunctionSignature,
): ):
self.api = fpapi self.api = fpapi
self.match_info = match_info self.match_info = match_info
@ -74,7 +75,7 @@ def format_proposed_change(self) -> str:
+ f"({', '.join(self.signature.arglist)})" + f"({', '.join(self.signature.arglist)})"
) )
def matches_ghidra_function(self, ghidra_function): # type: (Function) -> bool def matches_ghidra_function(self, ghidra_function: Function) -> bool:
"""Checks whether this function declaration already matches the description in Ghidra""" """Checks whether this function declaration already matches the description in Ghidra"""
name_match = self.name == ghidra_function.getName(False) name_match = self.name == ghidra_function.getName(False)
namespace_match = self.namespace == ghidra_function.getParentNamespace() namespace_match = self.namespace == ghidra_function.getParentNamespace()
@ -109,12 +110,10 @@ def matches_ghidra_function(self, ghidra_function): # type: (Function) -> bool
and args_match and args_match
) )
def _matches_non_thiscall_parameters( def _matches_non_thiscall_parameters(self, ghidra_function: Function) -> bool:
self, ghidra_function
): # type: (Function) -> bool
return self._parameter_lists_match(ghidra_function.getParameters()) return self._parameter_lists_match(ghidra_function.getParameters())
def _matches_thiscall_parameters(self, ghidra_function: "Function") -> bool: def _matches_thiscall_parameters(self, ghidra_function: Function) -> bool:
ghidra_params = list(ghidra_function.getParameters()) ghidra_params = list(ghidra_function.getParameters())
# remove the `this` argument which we don't generate ourselves # remove the `this` argument which we don't generate ourselves
@ -151,7 +150,7 @@ def _parameter_lists_match(self, ghidra_params: "list[Parameter]") -> bool:
return False return False
return True return True
def overwrite_ghidra_function(self, ghidra_function): # type: (Function) -> None def overwrite_ghidra_function(self, ghidra_function: Function):
"""Replace the function declaration in Ghidra by the one derived from C++.""" """Replace the function declaration in Ghidra by the one derived from C++."""
ghidra_function.setName(self.name, SourceType.USER_DEFINED) ghidra_function.setName(self.name, SourceType.USER_DEFINED)
ghidra_function.setParentNamespace(self.namespace) ghidra_function.setParentNamespace(self.namespace)
@ -168,7 +167,7 @@ def overwrite_ghidra_function(self, ghidra_function): # type: (Function) -> Non
# When we set the parameters, Ghidra will generate the layout. # When we set the parameters, Ghidra will generate the layout.
# Now we read them again and match them against the stack layout in the PDB, # Now we read them again and match them against the stack layout in the PDB,
# both to verify and to set the parameter names. # both to verify and to set the parameter names.
ghidra_parameters: "list[ghidra.program.model.listing.Parameter]" = ghidra_function.getParameters() # type: ignore ghidra_parameters: list[Parameter] = ghidra_function.getParameters()
# Try to add Ghidra function names # Try to add Ghidra function names
for param in ghidra_parameters: for param in ghidra_parameters:
@ -195,7 +194,7 @@ def overwrite_ghidra_function(self, ghidra_function): # type: (Function) -> Non
# ) # )
# continue # continue
def _rename_stack_parameter(self, param: "Parameter"): def _rename_stack_parameter(self, param: Parameter):
match = self.get_matching_stack_symbol(param.getStackOffset()) match = self.get_matching_stack_symbol(param.getStackOffset())
if match is None: if match is None:
raise StackOffsetMismatchError( raise StackOffsetMismatchError(
@ -210,7 +209,7 @@ def _rename_stack_parameter(self, param: "Parameter"):
param.setName(match.name, SourceType.USER_DEFINED) param.setName(match.name, SourceType.USER_DEFINED)
def get_matching_stack_symbol(self, stack_offset: int) -> "CppStackSymbol | None": def get_matching_stack_symbol(self, stack_offset: int) -> Optional[CppStackSymbol]:
return next( return next(
( (
symbol symbol
@ -221,7 +220,7 @@ def get_matching_stack_symbol(self, stack_offset: int) -> "CppStackSymbol | None
None, None,
) )
def get_matching_register_symbol(self, register: str) -> "CppRegisterSymbol | None": def get_matching_register_symbol(self, register: str) -> Optional[CppRegisterSymbol]:
return next( return next(
( (
symbol symbol

View File

@ -33,8 +33,8 @@ class CvdumpNode:
# Size as reported by SECTION CONTRIBUTIONS section. Not guaranteed to be # Size as reported by SECTION CONTRIBUTIONS section. Not guaranteed to be
# accurate. # accurate.
section_contribution: Optional[int] = None section_contribution: Optional[int] = None
addr: int | None = None addr: Optional[int] = None
symbol_entry: SymbolsEntry | None = None symbol_entry: Optional[SymbolsEntry] = None
def __init__(self, section: int, offset: int) -> None: def __init__(self, section: int, offset: int) -> None:
self.section = section self.section = section

View File

@ -1,6 +1,6 @@
import logging import logging
import re import re
from typing import NamedTuple from typing import NamedTuple, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,7 +23,7 @@ class SymbolsEntry(NamedTuple):
func_type: str func_type: str
name: str name: str
stack_symbols: list[StackOrRegisterSymbol] stack_symbols: list[StackOrRegisterSymbol]
addr: int | None # absolute address, to be set later addr: Optional[int] # absolute address, to be set later
class CvdumpSymbolsParser: class CvdumpSymbolsParser:
@ -81,7 +81,7 @@ def read_line(self, line: str):
return return
symbol_type: str = match.group("symbol_type") symbol_type: str = match.group("symbol_type")
second_part: str | None = match.group("second_part") second_part: Optional[str] = match.group("second_part")
if symbol_type == "S_GPROC32": if symbol_type == "S_GPROC32":
assert second_part is not None assert second_part is not None