Replace HexRaysCodeXplorer and HexraysInvertIf with HexRaysPyTools

This commit is contained in:
ecx86
2018-08-05 01:08:36 -04:00
parent 035b4ae117
commit 454ae02a36
22 changed files with 5639 additions and 0 deletions

View File

@@ -0,0 +1,583 @@
import logging
import idaapi
import idc
from Core.Helper import to_hex, get_member_name, get_func_argument_info, get_funcs_calling_address, is_imported_ea
logger = logging.getLogger(__name__)
SETTING_START_FROM_CURRENT_EXPR = True
def decompile_function(address):
try:
cfunc = idaapi.decompile(address)
if cfunc:
return cfunc
except idaapi.DecompilationFailure:
pass
logger.warn("IDA failed to decompile function at 0x{address:08X}".format(address=address))
class ScanObject(object):
def __init__(self):
self.ea = idaapi.BADADDR
self.name = None
self.tinfo = None
self.id = 0
@staticmethod
def create(cfunc, arg):
# Creates object suitable for scaning either from cexpr_t or ctree_item_t
if isinstance(arg, idaapi.ctree_item_t):
lvar = arg.get_lvar()
if lvar:
index = list(cfunc.get_lvars()).index(lvar)
result = VariableObject(lvar, index)
if arg.e:
result.ea = ScanObject.get_expression_address(cfunc, arg.e)
return result
cexpr = arg.e
else:
cexpr = arg
if cexpr.op == idaapi.cot_var:
lvar = cfunc.get_lvars()[cexpr.v.idx]
result = VariableObject(lvar, cexpr.v.idx)
result.ea = ScanObject.get_expression_address(cfunc, cexpr)
return result
elif cexpr.op == idaapi.cot_memptr:
t = cexpr.x.type.get_pointed_object()
result = StructPtrObject(t.dstr(), cexpr.m)
result.name = get_member_name(t, cexpr.m)
elif cexpr.op == idaapi.cot_memref:
t = cexpr.x.type
result = StructRefObject(t.dstr(), cexpr.m)
result.name = get_member_name(t, cexpr.m)
elif cexpr.op == idaapi.cot_obj:
result = GlobalVariableObject(cexpr.obj_ea)
result.name = idaapi.get_short_name(cexpr.obj_ea)
else:
return
result.tinfo = cexpr.type
result.ea = ScanObject.get_expression_address(cfunc, cexpr)
return result
@staticmethod
def get_expression_address(cfunc, cexpr):
expr = cexpr
while expr and expr.ea == idaapi.BADADDR:
expr = expr.to_specific_type
expr = cfunc.body.find_parent_of(expr)
assert expr is not None
return expr.ea
def __hash__(self):
return hash((self.id, self.name))
def __eq__(self, other):
return self.id == other.id and self.name == other.name
def __repr__(self):
return self.name
SO_LOCAL_VARIABLE = 1 # cexpr.op == idaapi.cot_var
SO_STRUCT_POINTER = 2 # cexpr.op == idaapi.cot_memptr
SO_STRUCT_REFERENCE = 3 # cexpr.op == idaapi.cot_memref
SO_GLOBAL_OBJECT = 4 # cexpr.op == idaapi.cot_obj
SO_CALL_ARGUMENT = 5 # cexpr.op == idaapi.cot_call
SO_MEMORY_ALLOCATOR = 6
SO_RETURNED_OBJECT = 7
class VariableObject(ScanObject):
# Represents `var` expression
def __init__(self, lvar, index):
super(VariableObject, self).__init__()
self.lvar = lvar
self.tinfo = lvar.type()
self.name = lvar.name
self.index = index
self.id = SO_LOCAL_VARIABLE
def is_target(self, cexpr):
return cexpr.op == idaapi.cot_var and cexpr.v.idx == self.index
class StructPtrObject(ScanObject):
# Represents `x->m` expression
def __init__(self, struct_name, offset):
super(StructPtrObject, self).__init__()
self.struct_name = struct_name
self.offset = offset
self.id = SO_STRUCT_POINTER
def is_target(self, cexpr):
return cexpr.op == idaapi.cot_memptr and cexpr.m == self.offset and \
cexpr.x.type.get_pointed_object().dstr() == self.struct_name
class StructRefObject(ScanObject):
# Represents `x.m` expression
def __init__(self, struct_name, offset):
super(StructRefObject, self).__init__()
self.__struct_name = struct_name
self.__offset = offset
self.id = SO_STRUCT_REFERENCE
def is_target(self, cexpr):
return cexpr.op == idaapi.cot_memref and cexpr.m == self.__offset and cexpr.x.type.dstr() == self.__struct_name
class GlobalVariableObject(ScanObject):
# Represents global object
def __init__(self, object_address):
super(GlobalVariableObject, self).__init__()
self.obj_ea = object_address
self.id = SO_GLOBAL_OBJECT
def is_target(self, cexpr):
return cexpr.op == idaapi.cot_obj and self.obj_ea == cexpr.obj_ea
class CallArgObject(ScanObject):
# Represents call of function and specified argument
def __init__(self, func_address, arg_idx):
super(CallArgObject, self).__init__()
self.__func_ea = func_address
self.__arg_idx = arg_idx
self.id = SO_CALL_ARGUMENT
def is_target(self, cexpr):
return cexpr.op == idaapi.cot_call and cexpr.x.obj_ea == self.__func_ea
def create_scan_obj(self, cfunc, cexpr):
e = cexpr.a[self.__arg_idx]
while e.op in (idaapi.cot_cast, idaapi.cot_ref, idaapi.cot_add, idaapi.cot_sub, idaapi.cot_idx):
e = e.x
return ScanObject.create(cfunc, e)
@staticmethod
def create(cfunc, arg_idx):
result = CallArgObject(cfunc.entry_ea, arg_idx)
result.name = cfunc.get_lvars()[arg_idx].name
result.tinfo = cfunc.type
return result
def __repr__(self):
return "{}"
class ReturnedObject(ScanObject):
# Represents value returned by function
def __init__(self, func_address):
super(ReturnedObject, self).__init__()
self.__func_ea = func_address
self.id = SO_RETURNED_OBJECT
def is_target(self, cexpr):
return cexpr.op == idaapi.cot_call and cexpr.x.obj_ea == self.__func_ea
class MemoryAllocationObject(ScanObject):
# Represents `operator new()' or `malloc'
def __init__(self, name, size):
super(MemoryAllocationObject, self).__init__()
self.name = name
self.size = size
self.id = SO_MEMORY_ALLOCATOR
@staticmethod
def create(cfunc, cexpr):
if cexpr.op == idaapi.cot_call:
e = cexpr
elif cexpr.op == idaapi.cot_cast and cexpr.x.op == idaapi.cot_call:
e = cexpr.x
else:
return
func_name = idaapi.get_short_name(e.x.obj_ea)
if "malloc" in func_name or "operator new" in func_name:
carg = e.a[0]
if carg.op == idaapi.cot_num:
size = carg.numval()
else:
size = 0
result = MemoryAllocationObject(func_name, size)
result.ea = ScanObject.get_expression_address(cfunc, e)
return result
ASSIGNMENT_RIGHT = 1
ASSIGNMENT_LEFT = 2
class ObjectVisitor(idaapi.ctree_parentee_t):
def __init__(self, cfunc, obj, data, skip_until_object):
super(ObjectVisitor, self).__init__()
self._cfunc = cfunc
self._objects = [obj]
self._init_obj = obj
self._data = data
self._start_ea = obj.ea
self._skip = skip_until_object if self._start_ea != idaapi.BADADDR else False
self.crippled = False
def process(self):
self.apply_to(self._cfunc.body, None)
def set_callbacks(self, manipulate=None):
if manipulate:
self.__manipulate = manipulate.__get__(self, ObjectDownwardsVisitor)
def _get_line(self):
for p in reversed(self.parents):
if not p.is_expr():
return idaapi.tag_remove(p.print1(self._cfunc.__ref__()))
AssertionError("Parent instruction is not found")
def _manipulate(self, cexpr, obj):
"""
Method called for every object having assignment relationship with starter object. This method should be
reimplemented in order to do something useful
:param cexpr: idaapi_cexpr_t
:param id: one of the SO_* constants
:return: None
"""
self.__manipulate(cexpr, obj)
def __manipulate(self, cexpr, obj):
logger.debug("Expression {} at {} Id - {}".format(cexpr.opname, to_hex(self._find_asm_address(cexpr)), obj.id))
class ObjectDownwardsVisitor(ObjectVisitor):
def __init__(self, cfunc, obj, data=None, skip_until_object=False):
super(ObjectDownwardsVisitor, self).__init__(cfunc, obj, data, skip_until_object)
self.cv_flags |= idaapi.CV_POST
def visit_expr(self, cexpr):
if self._skip:
if self._is_initial_object(cexpr):
self._skip = False
else:
return 0
if cexpr.op != idaapi.cot_asg:
return 0
x_cexpr = cexpr.x
if cexpr.y.op == idaapi.cot_cast:
y_cexpr = cexpr.y.x
else:
y_cexpr = cexpr.y
for obj in self._objects:
if obj.is_target(x_cexpr):
if self.__is_object_overwritten(x_cexpr, obj, y_cexpr):
logger.info("Removed object {} from scanning at {}".format(
obj, to_hex(self._find_asm_address(x_cexpr))))
self._objects.remove(obj)
return 0
elif obj.is_target(y_cexpr):
new_obj = ScanObject.create(self._cfunc, x_cexpr)
if new_obj:
self._objects.append(new_obj)
return 0
return 0
def leave_expr(self, cexpr):
if self._skip:
return 0
for obj in self._objects:
if obj.is_target(cexpr) and obj.id != SO_RETURNED_OBJECT:
self._manipulate(cexpr, obj)
return 0
return 0
def _is_initial_object(self, cexpr):
if cexpr.op == idaapi.cot_asg:
cexpr = cexpr.y
if cexpr.op == idaapi.cot_cast:
cexpr = cexpr.x
return self._init_obj.is_target(cexpr) and self._find_asm_address(cexpr) == self._start_ea
def __is_object_overwritten(self, x_cexpr, obj, y_cexpr):
if len(self._objects) < 2:
return False
if y_cexpr.op == idaapi.cot_cast:
e = y_cexpr.x
else:
e = y_cexpr
if e.op != idaapi.cot_call or len(e.a) == 0:
return True
for obj in self._objects:
if obj.is_target(e. a[0]):
return False
return True
class ObjectUpwardsVisitor(ObjectVisitor):
STAGE_PREPARE = 1
STAGE_PARSING = 2
def __init__(self, cfunc, obj, data=None, skip_after_object=False):
super(ObjectUpwardsVisitor, self).__init__(cfunc, obj, data, skip_after_object)
self._stage = self.STAGE_PREPARE
self._tree = {}
self._call_obj = obj if obj.id == SO_CALL_ARGUMENT else None
def visit_expr(self, cexpr):
if self._stage == self.STAGE_PARSING:
return 0
if self._call_obj and self._call_obj.is_target(cexpr):
obj = self._call_obj.create_scan_obj(self._cfunc, cexpr)
if obj:
self._objects.append(obj)
return 0
if cexpr.op != idaapi.cot_asg:
return 0
x_cexpr = cexpr.x
if cexpr.y.op == idaapi.cot_cast:
y_cexpr = cexpr.y.x
else:
y_cexpr = cexpr.y
obj_left = ScanObject.create(self._cfunc, x_cexpr)
obj_right = ScanObject.create(self._cfunc, y_cexpr)
if obj_left and obj_right:
self.__add_object_assignment(obj_left, obj_right)
if self._skip and self._is_initial_object(cexpr):
return 1
return 0
def leave_expr(self, cexpr):
if self._stage == self.STAGE_PREPARE:
return 0
if self._skip and self._is_initial_object(cexpr):
self._manipulate(cexpr, self._init_obj)
return 1
for obj in self._objects:
if obj.is_target(cexpr):
self._manipulate(cexpr, obj)
return 0
return 0
def process(self):
self._stage = self.STAGE_PREPARE
self.cv_flags &= ~idaapi.CV_POST
super(ObjectUpwardsVisitor, self).process()
self._stage = self.STAGE_PARSING
self.cv_flags |= idaapi.CV_POST
self.__prepare()
super(ObjectUpwardsVisitor, self).process()
def _is_initial_object(self, cexpr):
return self._init_obj.is_target(cexpr) and self._find_asm_address(cexpr) == self._start_ea
def __add_object_assignment(self, from_obj, to_obj):
if from_obj in self._tree:
self._tree[from_obj].add(to_obj)
else:
self._tree[from_obj] = {to_obj}
def __prepare(self):
result = set()
todo = set(self._objects)
while todo:
obj = todo.pop()
result.add(obj)
if obj.id == SO_CALL_ARGUMENT or obj not in self._tree:
continue
o = self._tree[obj]
todo |= o - result
result |= o
self._objects = list(result)
self._tree.clear()
class RecursiveObjectVisitor(ObjectVisitor):
def __init__(self, cfunc, obj, data=None, skip_until_object=False, visited=None):
super(RecursiveObjectVisitor, self).__init__(cfunc, obj, data, skip_until_object)
self._visited = visited if visited else set()
self._new_for_visit = set()
self.crippled = False
self._arg_idx = -1
self._debug_scan_tree = {}
self.__debug_scan_tree_root = idc.Name(self._cfunc.entry_ea)
self.__debug_message = []
def visit_expr(self, cexpr):
return super(RecursiveObjectVisitor, self).visit_expr(cexpr)
def set_callbacks(self, manipulate=None, start=None, start_iteration=None, finish=None, finish_iteration=None):
super(RecursiveObjectVisitor, self).set_callbacks(manipulate)
if start:
self._start = start.__get__(self, RecursiveObjectDownwardsVisitor)
if start_iteration:
self._start_iteration = start_iteration.__get__(self, RecursiveObjectDownwardsVisitor)
if finish:
self._finish = finish.__get__(self, RecursiveObjectDownwardsVisitor)
if finish_iteration:
self._finish_iteration = finish_iteration.__get__(self, RecursiveObjectDownwardsVisitor)
def prepare_new_scan(self, cfunc, arg_idx, obj, skip=False):
self._cfunc = cfunc
self._arg_idx = arg_idx
self._objects = [obj]
self._init_obj = obj
self._skip = False
self.crippled = self.__is_func_crippled()
def process(self):
self._start()
self._recursive_process()
self._finish()
self.dump_scan_tree()
def dump_scan_tree(self):
self.__prepare_debug_message()
logger.info("{}\n---------------".format("\n".join(self.__debug_message)))
def __prepare_debug_message(self, key=None, level=1):
if key is None:
key = (self.__debug_scan_tree_root, -1)
self.__debug_message.append("--- Scan Tree---\n{}".format(self.__debug_scan_tree_root))
if key in self._debug_scan_tree:
for func_name, arg_idx in self._debug_scan_tree[key]:
prefix = " | " * (level - 1) + " |_ "
self.__debug_message.append("{}{} (idx: {})".format(prefix, func_name, arg_idx))
self.__prepare_debug_message((func_name, arg_idx), level + 1)
def _recursive_process(self):
self._start_iteration()
super(RecursiveObjectVisitor, self).process()
self._finish_iteration()
def _manipulate(self, cexpr, obj):
self._check_call(cexpr)
super(RecursiveObjectVisitor, self)._manipulate(cexpr, obj)
def _check_call(self, cexpr):
raise NotImplemented
def _add_visit(self, func_ea, arg_idx):
if (func_ea, arg_idx) not in self._visited:
self._visited.add((func_ea, arg_idx))
self._new_for_visit.add((func_ea, arg_idx))
return True
return False
def _add_scan_tree_info(self, func_ea, arg_idx):
head_node = (idc.Name(self._cfunc.entry_ea), self._arg_idx)
tail_node = (idc.Name(func_ea), arg_idx)
if head_node in self._debug_scan_tree:
self._debug_scan_tree[head_node].add(tail_node)
else:
self._debug_scan_tree[head_node] = {tail_node}
def _start(self):
""" Called at the beginning of visiting """
pass
def _start_iteration(self):
""" Called every time new function visiting started """
pass
def _finish(self):
""" Called after all visiting happened """
pass
def _finish_iteration(self):
""" Called every time new function visiting finished """
pass
def __is_func_crippled(self):
# Check if function is just call to another function
b = self._cfunc.body.cblock
if b.size() == 1:
e = b.at(0)
return e.op == idaapi.cit_return or (e.op == idaapi.cit_expr and e.cexpr.op == idaapi.cot_call)
return False
class RecursiveObjectDownwardsVisitor(RecursiveObjectVisitor, ObjectDownwardsVisitor):
def __init__(self, cfunc, obj, data=None, skip_until_object=False, visited=None):
super(RecursiveObjectDownwardsVisitor, self).__init__(cfunc, obj, data, skip_until_object, visited)
def _check_call(self, cexpr):
parent = self.parent_expr()
grandparent = self.parents.at(self.parents.size() - 2)
if parent.op == idaapi.cot_call:
call_cexpr = parent
arg_cexpr = cexpr
elif parent.op == idaapi.cot_cast and grandparent.op == idaapi.cot_call:
call_cexpr = grandparent.cexpr
arg_cexpr = parent
else:
return
idx, _ = get_func_argument_info(call_cexpr, arg_cexpr)
func_ea = call_cexpr.x.obj_ea
if func_ea == idaapi.BADADDR:
return
if self._add_visit(func_ea, idx):
self._add_scan_tree_info(func_ea, idx)
def _recursive_process(self):
super(RecursiveObjectDownwardsVisitor, self)._recursive_process()
while self._new_for_visit:
func_ea, arg_idx = self._new_for_visit.pop()
if is_imported_ea(func_ea):
continue
cfunc = decompile_function(func_ea)
if cfunc:
assert arg_idx < len(cfunc.get_lvars()), "Wrong argument at func {}".format(to_hex(func_ea))
obj = VariableObject(cfunc.get_lvars()[arg_idx], arg_idx)
self.prepare_new_scan(cfunc, arg_idx, obj)
self._recursive_process()
class RecursiveObjectUpwardsVisitor(RecursiveObjectVisitor, ObjectUpwardsVisitor):
def __init__(self, cfunc, obj, data=None, skip_after_object=False, visited=None):
super(RecursiveObjectUpwardsVisitor, self).__init__(cfunc, obj, data, skip_after_object, visited)
def prepare_new_scan(self, cfunc, arg_idx, obj, skip=False):
super(RecursiveObjectUpwardsVisitor, self).prepare_new_scan(cfunc, arg_idx, obj, skip)
self._call_obj = obj if obj.id == SO_CALL_ARGUMENT else None
def _check_call(self, cexpr):
if cexpr.op == idaapi.cot_var and self._cfunc.get_lvars()[cexpr.v.idx].is_arg_var:
func_ea = self._cfunc.entry_ea
arg_idx = cexpr.v.idx
if self._add_visit(func_ea, arg_idx):
for callee_ea in get_funcs_calling_address(func_ea):
self._add_scan_tree_info(callee_ea, arg_idx)
def _recursive_process(self):
super(RecursiveObjectUpwardsVisitor, self)._recursive_process()
while self._new_for_visit:
new_visit = list(self._new_for_visit)
self._new_for_visit.clear()
for func_ea, arg_idx in new_visit:
funcs = get_funcs_calling_address(func_ea)
obj = CallArgObject.create(idaapi.decompile(func_ea), arg_idx)
for callee_ea in funcs:
cfunc = decompile_function(callee_ea)
if cfunc:
self.prepare_new_scan(cfunc, arg_idx, obj, False)
super(RecursiveObjectUpwardsVisitor, self)._recursive_process()