refactor types and make them Python 3.9 compatible

This commit is contained in:
jonschz 2024-05-23 20:26:54 +02:00
parent 17b101d8fc
commit 8678ad72c4
7 changed files with 30 additions and 31 deletions

View File

@ -188,7 +188,9 @@ def process_functions(isle_compare: "IsleCompare"):
cause = e.args[0]
if CancelledException is not None and isinstance(cause, CancelledException):
# let Ghidra's CancelledException pass through
raise
logging.critical("Import aborted by the user.")
return
log_and_track_failure(cause, unexpected=True)
except Exception as e: # pylint: disable=broad-exception-caught
log_and_track_failure(e, unexpected=True)

View File

@ -8,10 +8,10 @@ def __str__(self):
class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception):
def __init__(self, namespaceHierachy): # type: (list[str]) -> None
def __init__(self, namespaceHierachy: list[str]):
super().__init__(namespaceHierachy)
def get_namespace_str(self): # type: () -> str
def get_namespace_str(self) -> str:
return "::".join(self.args[0])
def __str__(self):

View File

@ -17,7 +17,7 @@
from ghidra.program.model.symbol import Namespace
def get_ghidra_type(api, type_name): # type: (FlatProgramAPI, str) -> DataType
def get_ghidra_type(api: FlatProgramAPI, type_name: str):
"""
Searches for the type named `typeName` in Ghidra.
@ -44,7 +44,7 @@ def get_ghidra_type(api, type_name): # type: (FlatProgramAPI, str) -> DataType
raise MultipleTypesFoundInGhidraError(type_name, result)
def add_pointer_type(api, pointee): # type: (FlatProgramAPI, DataType) -> DataType
def add_pointer_type(api: FlatProgramAPI, pointee: DataType):
data_type = PointerDataType(pointee)
data_type.setCategoryPath(pointee.getCategoryPath())
api.getCurrentProgram().getDataTypeManager().addDataType(
@ -54,9 +54,7 @@ def add_pointer_type(api, pointee): # type: (FlatProgramAPI, DataType) -> DataT
return data_type
def get_ghidra_namespace(
api, namespace_hierachy
): # type: (FlatProgramAPI, list[str]) -> Namespace
def get_ghidra_namespace(api: FlatProgramAPI, namespace_hierachy: list[str]) -> Namespace:
namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part)

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
import re
from typing import Any
from typing import Any, Optional
import logging
from isledecomp.cvdump.symbols import SymbolsEntry
@ -38,7 +38,7 @@ class FunctionSignature:
call_type: str
arglist: list[str]
return_type: str
class_type: dict[str, Any] | None
class_type: Optional[dict[str, Any]]
stack_symbols: list[CppStackOrRegisterSymbol]
@ -67,7 +67,7 @@ def scalar_type_to_cpp(self, scalar_type: str) -> str:
return f"{self.scalar_type_to_cpp(scalar_type[3:])} *"
return self._scalar_type_map.get(scalar_type, scalar_type)
def lookup_type(self, type_name: str | None) -> dict[str, Any] | None:
def lookup_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]:
return (
None
if type_name is None
@ -114,7 +114,7 @@ def type_to_cpp_type_name(self, type_name: str) -> str:
logger.error("Unknown type: %s", dereferenced)
return "<<parsing error>>"
def get_func_signature(self, fn: "SymbolsEntry") -> FunctionSignature | None:
def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
function_type_str = fn.func_type
if function_type_str == "T_NOTYPE(0000)":
logger.debug(
@ -181,7 +181,7 @@ def get_function_list(self) -> list[tuple[MatchInfo, FunctionSignature]]:
def handle_matched_function(
self, match_info: MatchInfo
) -> tuple[MatchInfo, FunctionSignature] | None:
) -> Optional[tuple[MatchInfo, FunctionSignature]]:
assert match_info.orig_addr is not None
match_options = self.compare._db.get_match_options(match_info.orig_addr)
assert match_options is not None

View File

@ -4,6 +4,7 @@
# pyright: reportMissingModuleSource=false
import logging
from typing import Optional
from ghidra.program.model.listing import Function, Parameter
from ghidra.program.flatapi import FlatProgramAPI
@ -33,9 +34,9 @@ class PdbFunctionWithGhidraObjects:
def __init__(
self,
fpapi: "FlatProgramAPI",
match_info: "MatchInfo",
signature: "FunctionSignature",
fpapi: FlatProgramAPI,
match_info: MatchInfo,
signature: FunctionSignature,
):
self.api = fpapi
self.match_info = match_info
@ -74,7 +75,7 @@ def format_proposed_change(self) -> str:
+ f"({', '.join(self.signature.arglist)})"
)
def matches_ghidra_function(self, ghidra_function): # type: (Function) -> bool
def matches_ghidra_function(self, ghidra_function: Function) -> bool:
"""Checks whether this function declaration already matches the description in Ghidra"""
name_match = self.name == ghidra_function.getName(False)
namespace_match = self.namespace == ghidra_function.getParentNamespace()
@ -109,12 +110,10 @@ def matches_ghidra_function(self, ghidra_function): # type: (Function) -> bool
and args_match
)
def _matches_non_thiscall_parameters(
self, ghidra_function
): # type: (Function) -> bool
def _matches_non_thiscall_parameters(self, ghidra_function: Function) -> bool:
return self._parameter_lists_match(ghidra_function.getParameters())
def _matches_thiscall_parameters(self, ghidra_function: "Function") -> bool:
def _matches_thiscall_parameters(self, ghidra_function: Function) -> bool:
ghidra_params = list(ghidra_function.getParameters())
# remove the `this` argument which we don't generate ourselves
@ -151,7 +150,7 @@ def _parameter_lists_match(self, ghidra_params: "list[Parameter]") -> bool:
return False
return True
def overwrite_ghidra_function(self, ghidra_function): # type: (Function) -> None
def overwrite_ghidra_function(self, ghidra_function: Function):
"""Replace the function declaration in Ghidra by the one derived from C++."""
ghidra_function.setName(self.name, SourceType.USER_DEFINED)
ghidra_function.setParentNamespace(self.namespace)
@ -168,7 +167,7 @@ def overwrite_ghidra_function(self, ghidra_function): # type: (Function) -> Non
# When we set the parameters, Ghidra will generate the layout.
# Now we read them again and match them against the stack layout in the PDB,
# both to verify and to set the parameter names.
ghidra_parameters: "list[ghidra.program.model.listing.Parameter]" = ghidra_function.getParameters() # type: ignore
ghidra_parameters: list[Parameter] = ghidra_function.getParameters()
# Try to add Ghidra function names
for param in ghidra_parameters:
@ -195,7 +194,7 @@ def overwrite_ghidra_function(self, ghidra_function): # type: (Function) -> Non
# )
# continue
def _rename_stack_parameter(self, param: "Parameter"):
def _rename_stack_parameter(self, param: Parameter):
match = self.get_matching_stack_symbol(param.getStackOffset())
if match is None:
raise StackOffsetMismatchError(
@ -210,7 +209,7 @@ def _rename_stack_parameter(self, param: "Parameter"):
param.setName(match.name, SourceType.USER_DEFINED)
def get_matching_stack_symbol(self, stack_offset: int) -> "CppStackSymbol | None":
def get_matching_stack_symbol(self, stack_offset: int) -> Optional[CppStackSymbol]:
return next(
(
symbol
@ -221,7 +220,7 @@ def get_matching_stack_symbol(self, stack_offset: int) -> "CppStackSymbol | None
None,
)
def get_matching_register_symbol(self, register: str) -> "CppRegisterSymbol | None":
def get_matching_register_symbol(self, register: str) -> Optional[CppRegisterSymbol]:
return next(
(
symbol

View File

@ -33,8 +33,8 @@ class CvdumpNode:
# Size as reported by SECTION CONTRIBUTIONS section. Not guaranteed to be
# accurate.
section_contribution: Optional[int] = None
addr: int | None = None
symbol_entry: SymbolsEntry | None = None
addr: Optional[int] = None
symbol_entry: Optional[SymbolsEntry] = None
def __init__(self, section: int, offset: int) -> None:
self.section = section

View File

@ -1,6 +1,6 @@
import logging
import re
from typing import NamedTuple
from typing import NamedTuple, Optional
logger = logging.getLogger(__name__)
@ -23,7 +23,7 @@ class SymbolsEntry(NamedTuple):
func_type: str
name: str
stack_symbols: list[StackOrRegisterSymbol]
addr: int | None # absolute address, to be set later
addr: Optional[int] # absolute address, to be set later
class CvdumpSymbolsParser:
@ -81,7 +81,7 @@ def read_line(self, line: str):
return
symbol_type: str = match.group("symbol_type")
second_part: str | None = match.group("second_part")
second_part: Optional[str] = match.group("second_part")
if symbol_type == "S_GPROC32":
assert second_part is not None